Model name passing for summary, sentiment, ner (#125)

* pass model names in class init

* tests for model name and revision number passing

* add exception tests

* simplified selection logic

---------

Co-authored-by: Petr Andriushchenko <pitandmind@gmail.com>
Этот коммит содержится в:
Inga Ulusoy 2023-06-29 13:10:13 +02:00 коммит произвёл GitHub
родитель 2a87bc57fd
Коммит 9f025094a3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 157 добавлений и 16 удалений

Просмотреть файл

@ -1,5 +1,4 @@
import pytest
import spacy
import ammico.text as tt
@ -31,6 +30,49 @@ def test_TextDetector(set_testdict):
assert not test_obj.analyse_text
def test_init_revision_numbers_and_models():
test_obj = tt.TextDetector({})
# check the default options
assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
assert test_obj.revision_summary == "a4f8f3e"
assert test_obj.revision_sentiment == "af0f99b"
assert test_obj.revision_ner == "f2482bf"
# provide non-default options
model_names = ["facebook/bart-large-cnn", None, None]
test_obj = tt.TextDetector({}, model_names=model_names)
assert test_obj.model_summary == "facebook/bart-large-cnn"
assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
assert not test_obj.revision_summary
assert test_obj.revision_sentiment == "af0f99b"
assert test_obj.revision_ner == "f2482bf"
revision_numbers = ["3d22493", None, None]
test_obj = tt.TextDetector(
{},
model_names=model_names,
revision_numbers=revision_numbers,
)
assert test_obj.model_summary == "facebook/bart-large-cnn"
assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
assert test_obj.revision_summary == "3d22493"
assert test_obj.revision_sentiment == "af0f99b"
assert test_obj.revision_ner == "f2482bf"
# now test the exceptions
with pytest.raises(ValueError):
tt.TextDetector({}, analyse_text=1.0)
with pytest.raises(ValueError):
tt.TextDetector({}, model_names=1.0)
with pytest.raises(ValueError):
tt.TextDetector({}, revision_numbers=1.0)
with pytest.raises(ValueError):
tt.TextDetector({}, model_names=["something"])
with pytest.raises(ValueError):
tt.TextDetector({}, revision_numbers=["something"])
@pytest.mark.gcv
def test_analyse_image(set_testdict, set_environ):
for item in set_testdict:

Просмотреть файл

@ -11,19 +11,124 @@ from transformers import pipeline
class TextDetector(AnalysisMethod):
def __init__(self, subdict: dict, analyse_text: bool = False) -> None:
def __init__(
self,
subdict: dict,
analyse_text: bool = False,
model_names: list = None,
revision_numbers: list = None,
) -> None:
"""Init text detection class.
Args:
subdict (dict): Dictionary containing file name/path, and possibly previous
analysis results from other modules.
analysis results from other modules.
analyse_text (bool, optional): Decide if extracted text will be further subject
to analysis. Defaults to False.
to analysis. Defaults to False.
model_names (list, optional): Provide model names for summary, sentiment and ner
analysis. Defaults to None, in which case the default model from transformers
are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
"distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
"dbmdz/bert-large-cased-finetuned-conll03-english".
To select other models, provide a list with three entries, the first for
summary, second for sentiment, third for NER, with the desired model names.
Set one of these to None to still use the default model.
revision_numbers (list, optional): Model revision (commit) numbers on the
Hugging Face hub. Provide this to make sure you are using the same model.
Defaults to None, except if the default models are used; then it defaults to
"a4f8f3e" (summary, distilbart), "af0f99b" (sentiment, distilbert),
"f2482bf" (NER, bert).
"""
super().__init__(subdict)
self.subdict.update(self.set_keys())
self.translator = Translator()
if not isinstance(analyse_text, bool):
raise ValueError("analyse_text needs to be set to true or false")
self.analyse_text = analyse_text
if model_names:
self._check_valid_models(model_names)
if revision_numbers:
self._check_revision_numbers(revision_numbers)
# initialize revision numbers and models
self._init_revision_numbers(model_names, revision_numbers)
self._init_model(model_names)
def _check_valid_models(self, model_names):
# check that model_names and revision_numbers are valid lists or None
# check that model names are a list
if not isinstance(model_names, list):
raise ValueError("Model names need to be provided as a list!")
# check that enough models are provided, one for each method
if len(model_names) != 3:
raise ValueError(
"Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
)
def _check_revision_numbers(self, revision_numbers):
# check that revision numbers are list
if not isinstance(revision_numbers, list):
raise ValueError("Revision numbers need to be provided as a list!")
# check that three revision numbers are provided, one for each method
if len(revision_numbers) != 3:
raise ValueError(
"Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
)
def _init_revision_numbers(self, model_names, revision_numbers):
"""Helper method to set the revision (version) number for each model."""
revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
if model_names:
# if model_names is provided, set revision numbers for each of the methods
# either as the provided revision number or None or as the default revision number,
# if one of the methods uses the default model
self._init_revision_numbers_per_model(
model_names, revision_numbers, revision_numbers_default
)
else:
# model_names was not provided, revision numbers are the default revision numbers or None
self.revision_summary = revision_numbers_default[0]
self.revision_sentiment = revision_numbers_default[1]
self.revision_ner = revision_numbers_default[2]
def _init_revision_numbers_per_model(
self, model_names, revision_numbers, revision_numbers_default
):
task_list = []
if not revision_numbers:
# no revision numbers for non-default models provided
revision_numbers = [None, None, None]
for model, revision, revision_default in zip(
model_names, revision_numbers, revision_numbers_default
):
# a model was specified for this task, set specified revision number or None
# or: model for this task was set to None, so we take default version number for default model
task_list.append(revision if model else revision_default)
self.revision_summary = task_list[0]
self.revision_sentiment = task_list[1]
self.revision_ner = task_list[2]
def _init_model(self, model_names):
"""Helper method to set the model name for each analysis method."""
# assign models for each of the text analysis methods
# and check that they are valid
model_names_default = [
"sshleifer/distilbart-cnn-12-6",
"distilbert-base-uncased-finetuned-sst-2-english",
"dbmdz/bert-large-cased-finetuned-conll03-english",
]
# no model names provided, set the default
if not model_names:
model_names = model_names_default
# now assign model names for each of the methods
# either to the provided model name or the default if one of the
# task's models is set to None
self.model_summary = (
model_names[0] if model_names[0] else model_names_default[0]
)
self.model_sentiment = (
model_names[1] if model_names[1] else model_names_default[1]
)
self.model_ner = model_names[2] if model_names[2] else model_names_default[2]
def set_keys(self) -> dict:
"""Set the default keys for text analysis.
@ -99,13 +204,11 @@ class TextDetector(AnalysisMethod):
"""Generate a summary of the text using the Transformers pipeline."""
# use the transformers pipeline to summarize the text
# use the current default model - 03/2023
model_name = "sshleifer/distilbart-cnn-12-6"
model_revision = "a4f8f3e"
max_number_of_characters = 3000
pipe = pipeline(
"summarization",
model=model_name,
revision=model_revision,
model=self.model_summary,
revision=self.revision_summary,
min_length=5,
max_length=20,
)
@ -123,12 +226,10 @@ class TextDetector(AnalysisMethod):
"""Perform text classification for sentiment using the Transformers pipeline."""
# use the transformers pipeline for text classification
# use the current default model - 03/2023
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model_revision = "af0f99b"
pipe = pipeline(
"text-classification",
model=model_name,
revision=model_revision,
model=self.model_sentiment,
revision=self.revision_sentiment,
truncation=True,
)
result = pipe(self.subdict["text_english"])
@ -139,12 +240,10 @@ class TextDetector(AnalysisMethod):
"""Perform named entity recognition on the text using the Transformers pipeline."""
# use the transformers pipeline for named entity recognition
# use the current default model - 03/2023
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
model_revision = "f2482bf"
pipe = pipeline(
"token-classification",
model=model_name,
revision=model_revision,
model=self.model_ner,
revision=self.revision_ner,
aggregation_strategy="simple",
)
result = pipe(self.subdict["text_english"])