diff --git a/ammico/test/test_text.py b/ammico/test/test_text.py
index da0d642..a7cedd6 100644
--- a/ammico/test/test_text.py
+++ b/ammico/test/test_text.py
@@ -55,15 +55,6 @@ def test_TextDetector(set_testdict, accepted):
         assert not test_obj.analyse_text
         assert not test_obj.skip_extraction
         assert test_obj.subdict["filename"] == set_testdict[item]["filename"]
-        assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
-        assert (
-            test_obj.model_sentiment
-            == "distilbert-base-uncased-finetuned-sst-2-english"
-        )
-        assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-        assert test_obj.revision_summary == "a4f8f3e"
-        assert test_obj.revision_sentiment == "af0f99b"
-        assert test_obj.revision_ner == "f2482bf"
     test_obj = tt.TextDetector(
         {}, analyse_text=True, skip_extraction=True, accept_privacy=accepted
     )
@@ -97,50 +88,6 @@ def test_clean_text(set_testdict, accepted):
     assert test_obj.subdict["text_clean"] == result
 
 
-def test_init_revision_numbers_and_models(accepted):
-    test_obj = tt.TextDetector({}, accept_privacy=accepted)
-    # check the default options
-    assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert test_obj.revision_summary == "a4f8f3e"
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    # provide non-default options
-    model_names = ["facebook/bart-large-cnn", None, None]
-    test_obj = tt.TextDetector({}, model_names=model_names, accept_privacy=accepted)
-    assert test_obj.model_summary == "facebook/bart-large-cnn"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert not test_obj.revision_summary
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    revision_numbers = ["3d22493", None, None]
-    test_obj = tt.TextDetector(
-        {},
-        model_names=model_names,
-        revision_numbers=revision_numbers,
-        accept_privacy=accepted,
-    )
-    assert test_obj.model_summary == "facebook/bart-large-cnn"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert test_obj.revision_summary == "3d22493"
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    # now test the exceptions
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, model_names=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, revision_numbers=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, model_names=["something"], accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, revision_numbers=["something"], accept_privacy=accepted)
-
-
 def test_check_add_space_after_full_stop(accepted):
     test_obj = tt.TextDetector({}, accept_privacy=accepted)
     test_obj.subdict["text"] = "I like cats. I like dogs."
@@ -153,7 +100,6 @@ def test_check_add_space_after_full_stop(accepted):
     test_obj._check_add_space_after_full_stop()
     assert test_obj.subdict["text"] == "www. icanhascheezburger. com"
 
-
 def test_truncate_text(accepted):
     test_obj = tt.TextDetector({}, accept_privacy=accepted)
     test_obj.subdict["text"] = "I like cats and dogs."
@@ -165,7 +111,6 @@ def test_truncate_text(accepted):
     assert test_obj.subdict["text_truncated"] == 5000 * "m"
     assert test_obj.subdict["text"] == 20000 * "m"
 
-
 @pytest.mark.gcv
 def test_analyse_image(set_testdict, set_environ, accepted):
     for item in set_testdict:
@@ -222,36 +167,6 @@ def test_remove_linebreaks(accepted):
     assert test_obj.subdict["text_english"] == "This is another test."
 
 
-def test_text_summary(get_path, accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    ref_file = get_path + "example_summary.txt"
-    with open(ref_file, "r", encoding="utf8") as file:
-        reference_text = file.read()
-    mydict["text_english"] = reference_text
-    test_obj.text_summary()
-    reference_summary = " I’m sorry, but I don’t want to be an emperor"
-    assert mydict["text_summary"] == reference_summary
-
-
-def test_text_sentiment_transformers(accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    mydict["text_english"] = "I am happy that the CI is working again."
-    test_obj.text_sentiment_transformers()
-    assert mydict["sentiment"] == "POSITIVE"
-    assert mydict["sentiment_score"] == pytest.approx(0.99, 0.02)
-
-
-def test_text_ner(accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    mydict["text_english"] = "Bill Gates was born in Seattle."
-    test_obj.text_ner()
-    assert mydict["entity"] == ["Bill Gates", "Seattle"]
-    assert mydict["entity_type"] == ["PER", "LOC"]
-
-
 def test_init_csv_option(get_path):
     test_obj = tt.TextAnalyzer(csv_path=get_path + "test.csv")
     assert test_obj.csv_path == get_path + "test.csv"
@@ -295,39 +210,3 @@ def test_read_csv(get_path):
         test_obj.mydict.items(), ref_dict.items()
     ):
         assert value_test["text"] == value_ref["text"]
-
-
-def test_PostprocessText(set_testdict, get_path):
-    reference_dict = "THE ALGEBRAIC EIGENVALUE PROBLEM"
-    reference_df = "Mathematische Formelsammlung\nfür Ingenieure und Naturwissenschaftler\nMit zahlreichen Abbildungen und Rechenbeispielen\nund einer ausführlichen Integraltafel\n3., verbesserte Auflage"
-    img_numbers = ["IMG_3755", "IMG_3756", "IMG_3757"]
-    for image_ref in img_numbers:
-        ref_file = get_path + "text_" + image_ref + ".txt"
-        with open(ref_file, "r") as file:
-            reference_text = file.read()
-        set_testdict[image_ref]["text_english"] = reference_text
-    obj = tt.PostprocessText(mydict=set_testdict)
-    test_dict = obj.list_text_english[2].replace("\r", "")
-    assert test_dict == reference_dict
-    for key in set_testdict.keys():
-        set_testdict[key].pop("text_english")
-    with pytest.raises(ValueError):
-        tt.PostprocessText(mydict=set_testdict)
-    obj = tt.PostprocessText(use_csv=True, csv_path=get_path + "test_data_out.csv")
-    # make sure test works on windows where end-of-line character is \r\n
-    test_df = obj.list_text_english[0].replace("\r", "")
-    assert test_df == reference_df
-    with pytest.raises(ValueError):
-        tt.PostprocessText(use_csv=True, csv_path=get_path + "test_data_out_nokey.csv")
-    with pytest.raises(ValueError):
-        tt.PostprocessText()
-
-
-def test_analyse_topic(get_path):
-    _, topic_df, most_frequent_topics = tt.PostprocessText(
-        use_csv=True, csv_path=get_path + "topic_analysis_test.csv"
-    ).analyse_topic()
-    # since this is not deterministic we cannot be sure we get the same result twice
-    assert len(topic_df) == 2
-    assert topic_df["Name"].iloc[0] == "0_the_feat_of_is"
-    assert most_frequent_topics[0][0][0] == "the"
diff --git a/ammico/text.py b/ammico/text.py
index c6c83fc..cbcb280 100644
--- a/ammico/text.py
+++ b/ammico/text.py
@@ -8,8 +8,6 @@ import re
 from ammico.utils import AnalysisMethod
 import grpc
 import pandas as pd
-from bertopic import BERTopic
-from transformers import pipeline
 
 PRIVACY_STATEMENT = """The Text Detector uses Google Cloud Vision
 and Google Translate. Detailed information about how information
@@ -71,8 +69,6 @@ class TextDetector(AnalysisMethod):
         subdict: dict,
         analyse_text: bool = False,
         skip_extraction: bool = False,
-        model_names: list = None,
-        revision_numbers: list = None,
         accept_privacy: str = "PRIVACY_AMMICO",
     ) -> None:
         """Init text detection class.
@@ -84,19 +80,6 @@ class TextDetector(AnalysisMethod):
                 to analysis. Defaults to False.
             skip_extraction (bool, optional): Decide if text will be extracted from images
                 or is already provided via a csv. Defaults to False.
-            model_names (list, optional): Provide model names for summary, sentiment and ner
-                analysis. Defaults to None, in which case the default model from transformers
-                are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
-                "distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
-                "dbmdz/bert-large-cased-finetuned-conll03-english".
-                To select other models, provide a list with three entries, the first for
-                summary, second for sentiment, third for NER, with the desired model names.
-                Set one of these to None to still use the default model.
-            revision_numbers (list, optional): Model revision (commit) numbers on the
-                Hugging Face hub. Provide this to make sure you are using the same model.
-                Defaults to None, except if the default models are used; then it defaults to
-                "a4f8f3e" (summary, distilbart), "af0f99b" (sentiment, distilbert),
-                "f2482bf" (NER, bert).
             accept_privacy (str, optional): Environment variable to accept the privacy
                 statement for the Google Cloud processing of the data. Defaults to
                 "PRIVACY_AMMICO".
@@ -124,90 +107,6 @@ class TextDetector(AnalysisMethod):
             print("Reading text directly from provided dictionary.")
         if self.analyse_text:
             self._initialize_spacy()
-            if model_names:
-                self._check_valid_models(model_names)
-            if revision_numbers:
-                self._check_revision_numbers(revision_numbers)
-            # initialize revision numbers and models
-            self._init_revision_numbers(model_names, revision_numbers)
-            self._init_model(model_names)
-
-    def _check_valid_models(self, model_names):
-        # check that model_names and revision_numbers are valid lists or None
-        # check that model names are a list
-        if not isinstance(model_names, list):
-            raise ValueError("Model names need to be provided as a list!")
-        # check that enough models are provided, one for each method
-        if len(model_names) != 3:
-            raise ValueError(
-                "Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
-            )
-
-    def _check_revision_numbers(self, revision_numbers):
-        # check that revision numbers are list
-        if not isinstance(revision_numbers, list):
-            raise ValueError("Revision numbers need to be provided as a list!")
-        # check that three revision numbers are provided, one for each method
-        if len(revision_numbers) != 3:
-            raise ValueError(
-                "Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
-            )
-
-    def _init_revision_numbers(self, model_names, revision_numbers):
-        """Helper method to set the revision (version) number for each model."""
-        revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
-        if model_names:
-            # if model_names is provided, set revision numbers for each of the methods
-            # either as the provided revision number or None or as the default revision number,
-            # if one of the methods uses the default model
-            self._init_revision_numbers_per_model(
-                model_names, revision_numbers, revision_numbers_default
-            )
-        else:
-            # model_names was not provided, revision numbers are the default revision numbers or None
-            self.revision_summary = revision_numbers_default[0]
-            self.revision_sentiment = revision_numbers_default[1]
-            self.revision_ner = revision_numbers_default[2]
-
-    def _init_revision_numbers_per_model(
-        self, model_names, revision_numbers, revision_numbers_default
-    ):
-        task_list = []
-        if not revision_numbers:
-            # no revision numbers for non-default models provided
-            revision_numbers = [None, None, None]
-        for model, revision, revision_default in zip(
-            model_names, revision_numbers, revision_numbers_default
-        ):
-            # a model was specified for this task, set specified revision number or None
-            # or: model for this task was set to None, so we take default version number for default model
-            task_list.append(revision if model else revision_default)
-        self.revision_summary = task_list[0]
-        self.revision_sentiment = task_list[1]
-        self.revision_ner = task_list[2]
-
-    def _init_model(self, model_names):
-        """Helper method to set the model name for each analysis method."""
-        # assign models for each of the text analysis methods
-        # and check that they are valid
-        model_names_default = [
-            "sshleifer/distilbart-cnn-12-6",
-            "distilbert-base-uncased-finetuned-sst-2-english",
-            "dbmdz/bert-large-cased-finetuned-conll03-english",
-        ]
-        # no model names provided, set the default
-        if not model_names:
-            model_names = model_names_default
-        # now assign model names for each of the methods
-        # either to the provided model name or the default if one of the
-        # task's models is set to None
-        self.model_summary = (
-            model_names[0] if model_names[0] else model_names_default[0]
-        )
-        self.model_sentiment = (
-            model_names[1] if model_names[1] else model_names_default[1]
-        )
-        self.model_ner = model_names[2] if model_names[2] else model_names_default[2]
 
     def set_keys(self) -> dict:
         """Set the default keys for text analysis.
@@ -285,10 +184,6 @@ class TextDetector(AnalysisMethod):
             self.remove_linebreaks()
         if self.analyse_text and self.subdict["text_english"]:
             self._run_spacy()
-            self.clean_text()
-            self.text_summary()
-            self.text_sentiment_transformers()
-            self.text_ner()
         return self.subdict
 
     def get_text_from_image(self):
@@ -362,73 +257,6 @@ class TextDetector(AnalysisMethod):
         """Generate Spacy doc object for further text analysis."""
         self.doc = self.nlp(self.subdict["text_english"])
 
-    def clean_text(self):
-        """Clean the text from unrecognized words and any numbers."""
-        templist = []
-        for token in self.doc:
-            (
-                templist.append(token.text)
-                if token.pos_ != "NUM" and token.has_vector
-                else None
-            )
-        self.subdict["text_clean"] = " ".join(templist).rstrip().lstrip()
-
-    def text_summary(self):
-        """Generate a summary of the text using the Transformers pipeline."""
-        # use the transformers pipeline to summarize the text
-        # use the current default model - 03/2023
-        max_number_of_characters = 3000
-        pipe = pipeline(
-            "summarization",
-            model=self.model_summary,
-            revision=self.revision_summary,
-            min_length=5,
-            max_length=20,
-            framework="pt",
-        )
-        try:
-            summary = pipe(self.subdict["text_english"][0:max_number_of_characters])
-            self.subdict["text_summary"] = summary[0]["summary_text"]
-        except IndexError:
-            print(
-                "Cannot provide summary for this object - please check that the text has been translated correctly."
-            )
-            print("Image: {}".format(self.subdict["filename"]))
-            self.subdict["text_summary"] = None
-
-    def text_sentiment_transformers(self):
-        """Perform text classification for sentiment using the Transformers pipeline."""
-        # use the transformers pipeline for text classification
-        # use the current default model - 03/2023
-        pipe = pipeline(
-            "text-classification",
-            model=self.model_sentiment,
-            revision=self.revision_sentiment,
-            truncation=True,
-            framework="pt",
-        )
-        result = pipe(self.subdict["text_english"])
-        self.subdict["sentiment"] = result[0]["label"]
-        self.subdict["sentiment_score"] = round(result[0]["score"], 2)
-
-    def text_ner(self):
-        """Perform named entity recognition on the text using the Transformers pipeline."""
-        # use the transformers pipeline for named entity recognition
-        # use the current default model - 03/2023
-        pipe = pipeline(
-            "token-classification",
-            model=self.model_ner,
-            revision=self.revision_ner,
-            aggregation_strategy="simple",
-            framework="pt",
-        )
-        result = pipe(self.subdict["text_english"])
-        self.subdict["entity"] = []
-        self.subdict["entity_type"] = []
-        for entity in result:
-            self.subdict["entity"].append(entity["word"])
-            self.subdict["entity_type"].append(entity["entity_group"])
-
 
 class TextAnalyzer:
     """Used to get text from a csv and then run the TextDetector on it."""
@@ -492,127 +320,3 @@ class TextAnalyzer:
             "filename": self.csv_path,
             "text": text,
         }
-
-
-class PostprocessText:
-    def __init__(
-        self,
-        mydict: dict = None,
-        use_csv: bool = False,
-        csv_path: str = None,
-        analyze_text: str = "text_english",
-    ) -> None:
-        """
-        Initializes the PostprocessText class that handles the topic analysis.
-
-        Args:
-            mydict (dict, optional): Dictionary with textual data. Defaults to None.
-            use_csv (bool, optional): Flag indicating whether to use a CSV file. Defaults to False.
-            csv_path (str, optional): Path to the CSV file. Required if `use_csv` is True. Defaults to None.
-            analyze_text (str, optional): Key for the text field to analyze. Defaults to "text_english".
-        """
-        self.use_csv = use_csv
-        if mydict:
-            print("Reading data from dict.")
-            self.mydict = mydict
-            self.list_text_english = self.get_text_dict(analyze_text)
-        elif self.use_csv:
-            print("Reading data from df.")
-            self.df = pd.read_csv(csv_path, encoding="utf8")
-            self.list_text_english = self.get_text_df(analyze_text)
-        else:
-            raise ValueError(
-                "Please provide either dictionary with textual data or \
-                a csv file by setting `use_csv` to True and providing a \
-                `csv_path`."
-            )
-        # initialize spacy
-        self._initialize_spacy()
-
-    def _initialize_spacy(self):
-        try:
-            self.nlp = spacy.load(
-                "en_core_web_md",
-                exclude=["tagger", "parser", "ner", "attribute_ruler", "lemmatizer"],
-            )
-        except Exception:
-            spacy.cli.download("en_core_web_md")
-            self.nlp = spacy.load(
-                "en_core_web_md",
-                exclude=["tagger", "parser", "ner", "attribute_ruler", "lemmatizer"],
-            )
-
-    def analyse_topic(self, return_topics: int = 3) -> tuple:
-        """
-        Performs topic analysis using BERTopic.
-
-        Args:
-            return_topics (int, optional): Number of topics to return. Defaults to 3.
-
-        Returns:
-            tuple: A tuple containing the topic model, topic dataframe, and most frequent topics.
-        """
-        try:
-            # unfortunately catching exceptions does not work here - need to figure out why
-            self.topic_model = BERTopic(embedding_model=self.nlp)
-        except TypeError:
-            print("BERTopic excited with an error - maybe your dataset is too small?")
-        self.topics, self.probs = self.topic_model.fit_transform(self.list_text_english)
-        # return the topic list
-        topic_df = self.topic_model.get_topic_info()
-        # return the most frequent return_topics
-        most_frequent_topics = []
-        if len(topic_df) < return_topics:
-            print("You requested more topics than are identified in your dataset -")
-            print(
-                "Returning only {} topics as these are all that have been found.".format(
-                    len(topic_df)
-                )
-            )
-        for i in range(min(return_topics, len(topic_df))):
-            most_frequent_topics.append(self.topic_model.get_topic(i))
-        return self.topic_model, topic_df, most_frequent_topics
-
-    def get_text_dict(self, analyze_text: str) -> list:
-        """
-        Extracts text from the provided dictionary.
-
-        Args:
-            analyze_text (str): Key for the text field to analyze.
-
-        Returns:
-            list: A list of text extracted from the dictionary.
-        """
-        # use dict to put text_english or text_summary in list
-        list_text_english = []
-        for key in self.mydict.keys():
-            if analyze_text not in self.mydict[key]:
-                raise ValueError(
-                    "Please check your provided dictionary - \
-                    no {} text data found.".format(
-                        analyze_text
-                    )
-                )
-            list_text_english.append(self.mydict[key][analyze_text])
-        return list_text_english
-
-    def get_text_df(self, analyze_text: str) -> list:
-        """
-        Extracts text from the provided dataframe.
-
-        Args:
-            analyze_text (str): Column name for the text field to analyze.
-
-        Returns:
-            list: A list of text extracted from the dataframe.
-        """
-        # use csv file to obtain dataframe and put text_english or text_summary in list
-        # check that "text_english" or "text_summary" is there
-        if analyze_text not in self.df:
-            raise ValueError(
-                "Please check your provided dataframe - \
-                no {} text data found.".format(
-                    analyze_text
-                )
-            )
-        return self.df[analyze_text].tolist()
diff --git a/pyproject.toml b/pyproject.toml
index 9e8b146..3ffe965 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,6 @@ classifiers = [
 ]
 
 dependencies = [
-    "bertopic<=0.14.1",
     "dash>=2.11.0",
     "datasets",
     "deepface<=0.0.93",
@@ -50,7 +49,6 @@ dependencies = [
     "spacy<=3.7.5",
     "tensorflow>=2.13.0",
    "torch<2.6.0",
-    "transformers",
     "google-cloud-vision",
     "dash_bootstrap_components",
     "colorgram.py",