diff --git a/ammico/test/test_text.py b/ammico/test/test_text.py index cfa8e6b..dd0f9f1 100644 --- a/ammico/test/test_text.py +++ b/ammico/test/test_text.py @@ -1,5 +1,4 @@ import pytest -import spacy import ammico.text as tt @@ -31,6 +30,49 @@ def test_TextDetector(set_testdict): assert not test_obj.analyse_text +def test_init_revision_numbers_and_models(): + test_obj = tt.TextDetector({}) + # check the default options + assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6" + assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english" + assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english" + assert test_obj.revision_summary == "a4f8f3e" + assert test_obj.revision_sentiment == "af0f99b" + assert test_obj.revision_ner == "f2482bf" + # provide non-default options + model_names = ["facebook/bart-large-cnn", None, None] + test_obj = tt.TextDetector({}, model_names=model_names) + assert test_obj.model_summary == "facebook/bart-large-cnn" + assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english" + assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english" + assert not test_obj.revision_summary + assert test_obj.revision_sentiment == "af0f99b" + assert test_obj.revision_ner == "f2482bf" + revision_numbers = ["3d22493", None, None] + test_obj = tt.TextDetector( + {}, + model_names=model_names, + revision_numbers=revision_numbers, + ) + assert test_obj.model_summary == "facebook/bart-large-cnn" + assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english" + assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english" + assert test_obj.revision_summary == "3d22493" + assert test_obj.revision_sentiment == "af0f99b" + assert test_obj.revision_ner == "f2482bf" + # now test the exceptions + with pytest.raises(ValueError): + tt.TextDetector({}, analyse_text=1.0) + with pytest.raises(ValueError): + tt.TextDetector({}, 
def __init__(
    self,
    subdict: dict,
    analyse_text: bool = False,
    model_names: list = None,
    revision_numbers: list = None,
) -> None:
    """Init text detection class.

    Args:
        subdict (dict): Dictionary containing file name/path, and possibly previous
            analysis results from other modules.
        analyse_text (bool, optional): Decide if extracted text will be further subject
            to analysis. Defaults to False.
        model_names (list, optional): Provide model names for summary, sentiment and ner
            analysis. Defaults to None, in which case the default models from transformers
            are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
            "distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
            "dbmdz/bert-large-cased-finetuned-conll03-english" (NER).
            To select other models, provide a list with three entries, the first for
            summary, second for sentiment, third for NER, with the desired model names.
            Set one of these to None to still use the default model.
        revision_numbers (list, optional): Model revision (commit) numbers on the
            Hugging Face hub, as a list with three entries (summary, sentiment, NER).
            Provide this to make sure you are using the same model. An entry that is
            None means "no revision requested" for a user-selected model; for a
            default model it falls back to "a4f8f3e" (summary, distilbart),
            "af0f99b" (sentiment, distilbert), "f2482bf" (NER, bert).

    Raises:
        ValueError: If analyse_text is not a bool, or if model_names or
            revision_numbers is given but is not a list with exactly three entries.
    """
    super().__init__(subdict)
    self.subdict.update(self.set_keys())
    self.translator = Translator()
    if not isinstance(analyse_text, bool):
        raise ValueError("analyse_text needs to be set to true or false")
    self.analyse_text = analyse_text
    # Compare against None rather than relying on truthiness: falsy but invalid
    # values (0, "", [], {}) must be rejected, not silently treated as "use defaults".
    if model_names is not None:
        self._check_valid_models(model_names)
    if revision_numbers is not None:
        self._check_revision_numbers(revision_numbers)
    # initialize revision numbers and models
    self._init_revision_numbers(model_names, revision_numbers)
    self._init_model(model_names)


def _check_valid_models(self, model_names):
    """Validate that model_names is a list with one entry per analysis method.

    Raises:
        ValueError: If model_names is not a list or does not have three entries.
    """
    if not isinstance(model_names, list):
        raise ValueError("Model names need to be provided as a list!")
    # one model each for summary, sentiment, ner
    if len(model_names) != 3:
        raise ValueError(
            "Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
        )


def _check_revision_numbers(self, revision_numbers):
    """Validate that revision_numbers is a list with one entry per analysis method.

    Raises:
        ValueError: If revision_numbers is not a list or does not have three entries.
    """
    if not isinstance(revision_numbers, list):
        raise ValueError("Revision numbers need to be provided as a list!")
    # one revision each for summary, sentiment, ner
    if len(revision_numbers) != 3:
        raise ValueError(
            "Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
        )


def _init_revision_numbers(self, model_names, revision_numbers):
    """Helper method to set the revision (version) number for each model.

    Sets self.revision_summary, self.revision_sentiment and self.revision_ner.
    """
    revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
    if not model_names:
        # No models selected: every task uses its default model. Route through the
        # per-model helper anyway so that explicitly requested revisions are
        # honoured instead of being silently dropped.
        model_names = [None, None, None]
    self._init_revision_numbers_per_model(
        model_names, revision_numbers, revision_numbers_default
    )


def _init_revision_numbers_per_model(
    self, model_names, revision_numbers, revision_numbers_default
):
    """Resolve the revision number of each task from the models and revisions given.

    Per task (summary, sentiment, ner):
      - a model was specified: use the specified revision, which may be None;
      - the default model is used: use the specified revision if there is one,
        otherwise the default revision that belongs to the default model.
    """
    if not revision_numbers:
        # no revision numbers provided at all
        revision_numbers = [None, None, None]
    resolved = [
        revision if (model or revision) else revision_default
        for model, revision, revision_default in zip(
            model_names, revision_numbers, revision_numbers_default
        )
    ]
    self.revision_summary, self.revision_sentiment, self.revision_ner = resolved


def _init_model(self, model_names):
    """Helper method to set the model name for each analysis method.

    Sets self.model_summary, self.model_sentiment and self.model_ner. A missing
    list, or a None entry in the list, selects the default model for that task.
    """
    model_names_default = [
        "sshleifer/distilbart-cnn-12-6",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
    ]
    # no model names provided, set the default
    if not model_names:
        model_names = model_names_default
    # per task: the provided model name, or the default if that entry is unset
    chosen = [
        name if name else default
        for name, default in zip(model_names, model_names_default)
    ]
    self.model_summary, self.model_sentiment, self.model_ner = chosen