зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 05:26:05 +02:00
Model name passing for summary, sentiment, ner (#125)
* pass model names in class init
* tests for model name and revision number passing
* add exception tests
* simplified selection logic
---------
Co-authored-by: Petr Andriushchenko <pitandmind@gmail.com>
Этот коммит содержится в:
родитель
2a87bc57fd
Коммит
9f025094a3
@ -1,5 +1,4 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import spacy
|
|
||||||
import ammico.text as tt
|
import ammico.text as tt
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +30,49 @@ def test_TextDetector(set_testdict):
|
|||||||
assert not test_obj.analyse_text
|
assert not test_obj.analyse_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_revision_numbers_and_models():
    """Check model/revision selection of ``TextDetector`` and its argument validation."""
    default_sentiment_model = "distilbert-base-uncased-finetuned-sst-2-english"
    default_ner_model = "dbmdz/bert-large-cased-finetuned-conll03-english"
    custom_summary_model = "facebook/bart-large-cnn"

    # no optional arguments: default models, pinned to their default revisions
    detector = tt.TextDetector({})
    assert detector.model_summary == "sshleifer/distilbart-cnn-12-6"
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert detector.revision_summary == "a4f8f3e"
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # custom summary model without a revision: its revision is unset, while
    # the two remaining tasks keep default model and default revision
    model_names = [custom_summary_model, None, None]
    detector = tt.TextDetector({}, model_names=model_names)
    assert detector.model_summary == custom_summary_model
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert not detector.revision_summary
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # custom summary model together with an explicit revision for it
    revision_numbers = ["3d22493", None, None]
    detector = tt.TextDetector(
        {},
        model_names=model_names,
        revision_numbers=revision_numbers,
    )
    assert detector.model_summary == custom_summary_model
    assert detector.model_sentiment == default_sentiment_model
    assert detector.model_ner == default_ner_model
    assert detector.revision_summary == "3d22493"
    assert detector.revision_sentiment == "af0f99b"
    assert detector.revision_ner == "f2482bf"

    # invalid argument types and wrongly sized lists must all be rejected
    invalid_kwargs = (
        {"analyse_text": 1.0},
        {"model_names": 1.0},
        {"revision_numbers": 1.0},
        {"model_names": ["something"]},
        {"revision_numbers": ["something"]},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            tt.TextDetector({}, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.gcv
|
@pytest.mark.gcv
|
||||||
def test_analyse_image(set_testdict, set_environ):
|
def test_analyse_image(set_testdict, set_environ):
|
||||||
for item in set_testdict:
|
for item in set_testdict:
|
||||||
|
|||||||
129
ammico/text.py
129
ammico/text.py
@ -11,19 +11,124 @@ from transformers import pipeline
|
|||||||
|
|
||||||
|
|
||||||
class TextDetector(AnalysisMethod):
|
class TextDetector(AnalysisMethod):
|
||||||
def __init__(
    self,
    subdict: dict,
    analyse_text: bool = False,
    model_names: list = None,
    revision_numbers: list = None,
) -> None:
    """Init text detection class.

    Args:
        subdict (dict): Dictionary containing file name/path, and possibly previous
            analysis results from other modules.
        analyse_text (bool, optional): Decide if extracted text will be further subject
            to analysis. Defaults to False.
        model_names (list, optional): Model names for summary, sentiment and NER
            analysis, in that order. Defaults to None, in which case the default
            models from transformers are used (as of 03/2023):
            "sshleifer/distilbart-cnn-12-6" (summary),
            "distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
            "dbmdz/bert-large-cased-finetuned-conll03-english" (NER).
            To select other models, provide a list with exactly three entries, the
            first for summary, second for sentiment, third for NER, with the desired
            model names. Set an entry to None to keep the default model for that task.
        revision_numbers (list, optional): Model revision (commit) numbers on the
            Hugging Face hub, as a list with exactly three entries in the same task
            order. Provide this to make sure you are using the same model.
            Defaults to None, except if the default models are used; then it
            defaults to "a4f8f3e" (summary, distilbart), "af0f99b" (sentiment,
            distilbert), "f2482bf" (NER, bert). A revision entry only takes effect
            for a task whose model name is also provided; tasks running the default
            model always use the default revision.

    Raises:
        ValueError: If analyse_text is not a bool, or if model_names or
            revision_numbers is provided (i.e. is not None) but is not a list
            with exactly three entries.
    """
    super().__init__(subdict)
    self.subdict.update(self.set_keys())
    self.translator = Translator()
    if not isinstance(analyse_text, bool):
        raise ValueError("analyse_text needs to be set to true or false")
    self.analyse_text = analyse_text
    # Compare against None instead of relying on truthiness: otherwise falsy
    # but invalid arguments (0, "", {}, []) would skip validation and be
    # silently treated as "use the defaults".
    if model_names is not None:
        self._check_valid_models(model_names)
    if revision_numbers is not None:
        self._check_revision_numbers(revision_numbers)
    # initialize revision numbers and models
    self._init_revision_numbers(model_names, revision_numbers)
    self._init_model(model_names)
|
||||||
|
|
||||||
|
def _check_valid_models(self, model_names):
|
||||||
|
# check that model_names and revision_numbers are valid lists or None
|
||||||
|
# check that model names are a list
|
||||||
|
if not isinstance(model_names, list):
|
||||||
|
raise ValueError("Model names need to be provided as a list!")
|
||||||
|
# check that enough models are provided, one for each method
|
||||||
|
if len(model_names) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_revision_numbers(self, revision_numbers):
|
||||||
|
# check that revision numbers are list
|
||||||
|
if not isinstance(revision_numbers, list):
|
||||||
|
raise ValueError("Revision numbers need to be provided as a list!")
|
||||||
|
# check that three revision numbers are provided, one for each method
|
||||||
|
if len(revision_numbers) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_revision_numbers(self, model_names, revision_numbers):
|
||||||
|
"""Helper method to set the revision (version) number for each model."""
|
||||||
|
revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
|
||||||
|
if model_names:
|
||||||
|
# if model_names is provided, set revision numbers for each of the methods
|
||||||
|
# either as the provided revision number or None or as the default revision number,
|
||||||
|
# if one of the methods uses the default model
|
||||||
|
self._init_revision_numbers_per_model(
|
||||||
|
model_names, revision_numbers, revision_numbers_default
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# model_names was not provided, revision numbers are the default revision numbers or None
|
||||||
|
self.revision_summary = revision_numbers_default[0]
|
||||||
|
self.revision_sentiment = revision_numbers_default[1]
|
||||||
|
self.revision_ner = revision_numbers_default[2]
|
||||||
|
|
||||||
|
def _init_revision_numbers_per_model(
|
||||||
|
self, model_names, revision_numbers, revision_numbers_default
|
||||||
|
):
|
||||||
|
task_list = []
|
||||||
|
if not revision_numbers:
|
||||||
|
# no revision numbers for non-default models provided
|
||||||
|
revision_numbers = [None, None, None]
|
||||||
|
for model, revision, revision_default in zip(
|
||||||
|
model_names, revision_numbers, revision_numbers_default
|
||||||
|
):
|
||||||
|
# a model was specified for this task, set specified revision number or None
|
||||||
|
# or: model for this task was set to None, so we take default version number for default model
|
||||||
|
task_list.append(revision if model else revision_default)
|
||||||
|
self.revision_summary = task_list[0]
|
||||||
|
self.revision_sentiment = task_list[1]
|
||||||
|
self.revision_ner = task_list[2]
|
||||||
|
|
||||||
|
def _init_model(self, model_names):
|
||||||
|
"""Helper method to set the model name for each analysis method."""
|
||||||
|
# assign models for each of the text analysis methods
|
||||||
|
# and check that they are valid
|
||||||
|
model_names_default = [
|
||||||
|
"sshleifer/distilbart-cnn-12-6",
|
||||||
|
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||||
|
"dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||||
|
]
|
||||||
|
# no model names provided, set the default
|
||||||
|
if not model_names:
|
||||||
|
model_names = model_names_default
|
||||||
|
# now assign model names for each of the methods
|
||||||
|
# either to the provided model name or the default if one of the
|
||||||
|
# task's models is set to None
|
||||||
|
self.model_summary = (
|
||||||
|
model_names[0] if model_names[0] else model_names_default[0]
|
||||||
|
)
|
||||||
|
self.model_sentiment = (
|
||||||
|
model_names[1] if model_names[1] else model_names_default[1]
|
||||||
|
)
|
||||||
|
self.model_ner = model_names[2] if model_names[2] else model_names_default[2]
|
||||||
|
|
||||||
def set_keys(self) -> dict:
|
def set_keys(self) -> dict:
|
||||||
"""Set the default keys for text analysis.
|
"""Set the default keys for text analysis.
|
||||||
@ -99,13 +204,11 @@ class TextDetector(AnalysisMethod):
|
|||||||
"""Generate a summary of the text using the Transformers pipeline."""
|
"""Generate a summary of the text using the Transformers pipeline."""
|
||||||
# use the transformers pipeline to summarize the text
|
# use the transformers pipeline to summarize the text
|
||||||
# use the current default model - 03/2023
|
# use the current default model - 03/2023
|
||||||
model_name = "sshleifer/distilbart-cnn-12-6"
|
|
||||||
model_revision = "a4f8f3e"
|
|
||||||
max_number_of_characters = 3000
|
max_number_of_characters = 3000
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
"summarization",
|
"summarization",
|
||||||
model=model_name,
|
model=self.model_summary,
|
||||||
revision=model_revision,
|
revision=self.revision_summary,
|
||||||
min_length=5,
|
min_length=5,
|
||||||
max_length=20,
|
max_length=20,
|
||||||
)
|
)
|
||||||
@ -123,12 +226,10 @@ class TextDetector(AnalysisMethod):
|
|||||||
"""Perform text classification for sentiment using the Transformers pipeline."""
|
"""Perform text classification for sentiment using the Transformers pipeline."""
|
||||||
# use the transformers pipeline for text classification
|
# use the transformers pipeline for text classification
|
||||||
# use the current default model - 03/2023
|
# use the current default model - 03/2023
|
||||||
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
|
|
||||||
model_revision = "af0f99b"
|
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
"text-classification",
|
"text-classification",
|
||||||
model=model_name,
|
model=self.model_sentiment,
|
||||||
revision=model_revision,
|
revision=self.revision_sentiment,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
)
|
)
|
||||||
result = pipe(self.subdict["text_english"])
|
result = pipe(self.subdict["text_english"])
|
||||||
@ -139,12 +240,10 @@ class TextDetector(AnalysisMethod):
|
|||||||
"""Perform named entity recognition on the text using the Transformers pipeline."""
|
"""Perform named entity recognition on the text using the Transformers pipeline."""
|
||||||
# use the transformers pipeline for named entity recognition
|
# use the transformers pipeline for named entity recognition
|
||||||
# use the current default model - 03/2023
|
# use the current default model - 03/2023
|
||||||
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
|
||||||
model_revision = "f2482bf"
|
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
"token-classification",
|
"token-classification",
|
||||||
model=model_name,
|
model=self.model_ner,
|
||||||
revision=model_revision,
|
revision=self.revision_ner,
|
||||||
aggregation_strategy="simple",
|
aggregation_strategy="simple",
|
||||||
)
|
)
|
||||||
result = pipe(self.subdict["text_english"])
|
result = pipe(self.subdict["text_english"])
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user