зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
Model name passing for summary, sentiment, ner (#125)
* pass model names in class init * tests for model name and revision number passing * add exception tests * simplified selection logic --------- Co-authored-by: Petr Andriushchenko <pitandmind@gmail.com>
Этот коммит содержится в:
родитель
2a87bc57fd
Коммит
9f025094a3
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
import spacy
|
||||
import ammico.text as tt
|
||||
|
||||
|
||||
@ -31,6 +30,49 @@ def test_TextDetector(set_testdict):
|
||||
assert not test_obj.analyse_text
|
||||
|
||||
|
||||
def test_init_revision_numbers_and_models():
    """Check the model names and revision numbers stored by ``TextDetector``.

    Covers the default configuration, a custom summary model without and with
    a revision number, and the invalid arguments that must raise ValueError.
    """
    default_sentiment_model = "distilbert-base-uncased-finetuned-sst-2-english"
    default_ner_model = "dbmdz/bert-large-cased-finetuned-conll03-english"

    # 1) nothing provided: default models pinned to their default revisions
    detector = tt.TextDetector({})
    assert detector.model_summary == "sshleifer/distilbart-cnn-12-6"
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert detector.revision_summary == "a4f8f3e"
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # 2) custom summary model, no revision numbers: the summary revision is
    # unset while the two default models keep their default revisions
    model_names = ["facebook/bart-large-cnn", None, None]
    detector = tt.TextDetector({}, model_names=model_names)
    assert detector.model_summary == "facebook/bart-large-cnn"
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert not detector.revision_summary
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # 3) custom summary model together with an explicit revision number
    revision_numbers = ["3d22493", None, None]
    detector = tt.TextDetector(
        {},
        model_names=model_names,
        revision_numbers=revision_numbers,
    )
    assert detector.model_summary == "facebook/bart-large-cnn"
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert detector.revision_summary == "3d22493"
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # 4) invalid arguments: wrong types and lists of the wrong length
    invalid_kwargs = [
        {"analyse_text": 1.0},
        {"model_names": 1.0},
        {"revision_numbers": 1.0},
        {"model_names": ["something"]},
        {"revision_numbers": ["something"]},
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            tt.TextDetector({}, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.gcv
|
||||
def test_analyse_image(set_testdict, set_environ):
|
||||
for item in set_testdict:
|
||||
|
||||
129
ammico/text.py
129
ammico/text.py
@ -11,19 +11,124 @@ from transformers import pipeline
|
||||
|
||||
|
||||
class TextDetector(AnalysisMethod):
|
||||
def __init__(self, subdict: dict, analyse_text: bool = False) -> None:
|
||||
def __init__(
    self,
    subdict: dict,
    analyse_text: bool = False,
    model_names: list = None,
    revision_numbers: list = None,
) -> None:
    """Init text detection class.

    Args:
        subdict (dict): Dictionary containing file name/path, and possibly previous
            analysis results from other modules.
        analyse_text (bool, optional): Decide if extracted text will be further subject
            to analysis. Defaults to False.
        model_names (list, optional): Provide model names for summary, sentiment and ner
            analysis. Defaults to None, in which case the default models from transformers
            are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
            "distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
            "dbmdz/bert-large-cased-finetuned-conll03-english" (NER).
            To select other models, provide a list with three entries, the first for
            summary, second for sentiment, third for NER, with the desired model names.
            Set one of these to None to still use the default model.
        revision_numbers (list, optional): Model revision (commit) numbers on the
            Hugging Face hub. Provide this to make sure you are using the same model.
            Must be a list with three entries in the same order as model_names.
            Defaults to None, except if the default models are used; then it defaults to
            "a4f8f3e" (summary, distilbart), "af0f99b" (sentiment, distilbert),
            "f2482bf" (NER, bert). A provided revision number is only applied to a
            task for which a model name was given as well; tasks that use the
            default model always use the default revision.

    Raises:
        ValueError: If analyse_text is not a bool, or if model_names or
            revision_numbers is given but is not a list with exactly three entries.
    """
    super().__init__(subdict)
    self.subdict.update(self.set_keys())
    self.translator = Translator()
    if not isinstance(analyse_text, bool):
        raise ValueError("analyse_text needs to be set to true or false")
    self.analyse_text = analyse_text
    # Compare against None instead of relying on truthiness: otherwise invalid
    # falsy arguments such as [] or 0 would skip validation and silently fall
    # back to the defaults instead of raising the documented ValueError.
    if model_names is not None:
        self._check_valid_models(model_names)
    if revision_numbers is not None:
        self._check_revision_numbers(revision_numbers)
    # initialize revision numbers and models
    self._init_revision_numbers(model_names, revision_numbers)
    self._init_model(model_names)
|
||||
|
||||
def _check_valid_models(self, model_names):
|
||||
# check that model_names and revision_numbers are valid lists or None
|
||||
# check that model names are a list
|
||||
if not isinstance(model_names, list):
|
||||
raise ValueError("Model names need to be provided as a list!")
|
||||
# check that enough models are provided, one for each method
|
||||
if len(model_names) != 3:
|
||||
raise ValueError(
|
||||
"Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
|
||||
)
|
||||
|
||||
def _check_revision_numbers(self, revision_numbers):
|
||||
# check that revision numbers are list
|
||||
if not isinstance(revision_numbers, list):
|
||||
raise ValueError("Revision numbers need to be provided as a list!")
|
||||
# check that three revision numbers are provided, one for each method
|
||||
if len(revision_numbers) != 3:
|
||||
raise ValueError(
|
||||
"Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
|
||||
)
|
||||
|
||||
def _init_revision_numbers(self, model_names, revision_numbers):
|
||||
"""Helper method to set the revision (version) number for each model."""
|
||||
revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
|
||||
if model_names:
|
||||
# if model_names is provided, set revision numbers for each of the methods
|
||||
# either as the provided revision number or None or as the default revision number,
|
||||
# if one of the methods uses the default model
|
||||
self._init_revision_numbers_per_model(
|
||||
model_names, revision_numbers, revision_numbers_default
|
||||
)
|
||||
else:
|
||||
# model_names was not provided, revision numbers are the default revision numbers or None
|
||||
self.revision_summary = revision_numbers_default[0]
|
||||
self.revision_sentiment = revision_numbers_default[1]
|
||||
self.revision_ner = revision_numbers_default[2]
|
||||
|
||||
def _init_revision_numbers_per_model(
|
||||
self, model_names, revision_numbers, revision_numbers_default
|
||||
):
|
||||
task_list = []
|
||||
if not revision_numbers:
|
||||
# no revision numbers for non-default models provided
|
||||
revision_numbers = [None, None, None]
|
||||
for model, revision, revision_default in zip(
|
||||
model_names, revision_numbers, revision_numbers_default
|
||||
):
|
||||
# a model was specified for this task, set specified revision number or None
|
||||
# or: model for this task was set to None, so we take default version number for default model
|
||||
task_list.append(revision if model else revision_default)
|
||||
self.revision_summary = task_list[0]
|
||||
self.revision_sentiment = task_list[1]
|
||||
self.revision_ner = task_list[2]
|
||||
|
||||
def _init_model(self, model_names):
|
||||
"""Helper method to set the model name for each analysis method."""
|
||||
# assign models for each of the text analysis methods
|
||||
# and check that they are valid
|
||||
model_names_default = [
|
||||
"sshleifer/distilbart-cnn-12-6",
|
||||
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||
"dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
]
|
||||
# no model names provided, set the default
|
||||
if not model_names:
|
||||
model_names = model_names_default
|
||||
# now assign model names for each of the methods
|
||||
# either to the provided model name or the default if one of the
|
||||
# task's models is set to None
|
||||
self.model_summary = (
|
||||
model_names[0] if model_names[0] else model_names_default[0]
|
||||
)
|
||||
self.model_sentiment = (
|
||||
model_names[1] if model_names[1] else model_names_default[1]
|
||||
)
|
||||
self.model_ner = model_names[2] if model_names[2] else model_names_default[2]
|
||||
|
||||
def set_keys(self) -> dict:
|
||||
"""Set the default keys for text analysis.
|
||||
@ -99,13 +204,11 @@ class TextDetector(AnalysisMethod):
|
||||
"""Generate a summary of the text using the Transformers pipeline."""
|
||||
# use the transformers pipeline to summarize the text
|
||||
# use the current default model - 03/2023
|
||||
model_name = "sshleifer/distilbart-cnn-12-6"
|
||||
model_revision = "a4f8f3e"
|
||||
max_number_of_characters = 3000
|
||||
pipe = pipeline(
|
||||
"summarization",
|
||||
model=model_name,
|
||||
revision=model_revision,
|
||||
model=self.model_summary,
|
||||
revision=self.revision_summary,
|
||||
min_length=5,
|
||||
max_length=20,
|
||||
)
|
||||
@ -123,12 +226,10 @@ class TextDetector(AnalysisMethod):
|
||||
"""Perform text classification for sentiment using the Transformers pipeline."""
|
||||
# use the transformers pipeline for text classification
|
||||
# use the current default model - 03/2023
|
||||
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
|
||||
model_revision = "af0f99b"
|
||||
pipe = pipeline(
|
||||
"text-classification",
|
||||
model=model_name,
|
||||
revision=model_revision,
|
||||
model=self.model_sentiment,
|
||||
revision=self.revision_sentiment,
|
||||
truncation=True,
|
||||
)
|
||||
result = pipe(self.subdict["text_english"])
|
||||
@ -139,12 +240,10 @@ class TextDetector(AnalysisMethod):
|
||||
"""Perform named entity recognition on the text using the Transformers pipeline."""
|
||||
# use the transformers pipeline for named entity recognition
|
||||
# use the current default model - 03/2023
|
||||
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
||||
model_revision = "f2482bf"
|
||||
pipe = pipeline(
|
||||
"token-classification",
|
||||
model=model_name,
|
||||
revision=model_revision,
|
||||
model=self.model_ner,
|
||||
revision=self.revision_ner,
|
||||
aggregation_strategy="simple",
|
||||
)
|
||||
result = pipe(self.subdict["text_english"])
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user