Mirror of https://github.com/ssciwr/AMMICO.git
Synced 2025-10-30 05:26:05 +02:00

maintain: remove text analysis with transformers and topic analysis

Parent 8a20b7ef43
Commit e4b812a397

@@ -55,15 +55,6 @@ def test_TextDetector(set_testdict, accepted):
         assert not test_obj.analyse_text
         assert not test_obj.skip_extraction
         assert test_obj.subdict["filename"] == set_testdict[item]["filename"]
-        assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
-        assert (
-            test_obj.model_sentiment
-            == "distilbert-base-uncased-finetuned-sst-2-english"
-        )
-        assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-        assert test_obj.revision_summary == "a4f8f3e"
-        assert test_obj.revision_sentiment == "af0f99b"
-        assert test_obj.revision_ner == "f2482bf"
    test_obj = tt.TextDetector(
        {}, analyse_text=True, skip_extraction=True, accept_privacy=accepted
    )
@@ -97,50 +88,6 @@ def test_clean_text(set_testdict, accepted):
     assert test_obj.subdict["text_clean"] == result
 
 
-def test_init_revision_numbers_and_models(accepted):
-    test_obj = tt.TextDetector({}, accept_privacy=accepted)
-    # check the default options
-    assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert test_obj.revision_summary == "a4f8f3e"
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    # provide non-default options
-    model_names = ["facebook/bart-large-cnn", None, None]
-    test_obj = tt.TextDetector({}, model_names=model_names, accept_privacy=accepted)
-    assert test_obj.model_summary == "facebook/bart-large-cnn"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert not test_obj.revision_summary
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    revision_numbers = ["3d22493", None, None]
-    test_obj = tt.TextDetector(
-        {},
-        model_names=model_names,
-        revision_numbers=revision_numbers,
-        accept_privacy=accepted,
-    )
-    assert test_obj.model_summary == "facebook/bart-large-cnn"
-    assert test_obj.model_sentiment == "distilbert-base-uncased-finetuned-sst-2-english"
-    assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english"
-    assert test_obj.revision_summary == "3d22493"
-    assert test_obj.revision_sentiment == "af0f99b"
-    assert test_obj.revision_ner == "f2482bf"
-    # now test the exceptions
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, model_names=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, revision_numbers=1.0, accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, model_names=["something"], accept_privacy=accepted)
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, revision_numbers=["something"], accept_privacy=accepted)
-
-
 def test_check_add_space_after_full_stop(accepted):
     test_obj = tt.TextDetector({}, accept_privacy=accepted)
     test_obj.subdict["text"] = "I like cats. I like dogs."
@@ -153,7 +100,6 @@ def test_check_add_space_after_full_stop(accepted):
     test_obj._check_add_space_after_full_stop()
     assert test_obj.subdict["text"] == "www. icanhascheezburger. com"
 
-
 def test_truncate_text(accepted):
     test_obj = tt.TextDetector({}, accept_privacy=accepted)
     test_obj.subdict["text"] = "I like cats and dogs."
@@ -165,7 +111,6 @@ def test_truncate_text(accepted):
     assert test_obj.subdict["text_truncated"] == 5000 * "m"
     assert test_obj.subdict["text"] == 20000 * "m"
 
-
 @pytest.mark.gcv
 def test_analyse_image(set_testdict, set_environ, accepted):
     for item in set_testdict:
@@ -222,36 +167,6 @@ def test_remove_linebreaks(accepted):
     assert test_obj.subdict["text_english"] == "This is another test."
 
 
-def test_text_summary(get_path, accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    ref_file = get_path + "example_summary.txt"
-    with open(ref_file, "r", encoding="utf8") as file:
-        reference_text = file.read()
-    mydict["text_english"] = reference_text
-    test_obj.text_summary()
-    reference_summary = " I’m sorry, but I don’t want to be an emperor"
-    assert mydict["text_summary"] == reference_summary
-
-
-def test_text_sentiment_transformers(accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    mydict["text_english"] = "I am happy that the CI is working again."
-    test_obj.text_sentiment_transformers()
-    assert mydict["sentiment"] == "POSITIVE"
-    assert mydict["sentiment_score"] == pytest.approx(0.99, 0.02)
-
-
-def test_text_ner(accepted):
-    mydict = {}
-    test_obj = tt.TextDetector(mydict, analyse_text=True, accept_privacy=accepted)
-    mydict["text_english"] = "Bill Gates was born in Seattle."
-    test_obj.text_ner()
-    assert mydict["entity"] == ["Bill Gates", "Seattle"]
-    assert mydict["entity_type"] == ["PER", "LOC"]
-
-
 def test_init_csv_option(get_path):
     test_obj = tt.TextAnalyzer(csv_path=get_path + "test.csv")
     assert test_obj.csv_path == get_path + "test.csv"
@@ -295,39 +210,3 @@ def test_read_csv(get_path):
         test_obj.mydict.items(), ref_dict.items()
     ):
         assert value_test["text"] == value_ref["text"]
-
-
-def test_PostprocessText(set_testdict, get_path):
-    reference_dict = "THE ALGEBRAIC EIGENVALUE PROBLEM"
-    reference_df = "Mathematische Formelsammlung\nfür Ingenieure und Naturwissenschaftler\nMit zahlreichen Abbildungen und Rechenbeispielen\nund einer ausführlichen Integraltafel\n3., verbesserte Auflage"
-    img_numbers = ["IMG_3755", "IMG_3756", "IMG_3757"]
-    for image_ref in img_numbers:
-        ref_file = get_path + "text_" + image_ref + ".txt"
-        with open(ref_file, "r") as file:
-            reference_text = file.read()
-        set_testdict[image_ref]["text_english"] = reference_text
-    obj = tt.PostprocessText(mydict=set_testdict)
-    test_dict = obj.list_text_english[2].replace("\r", "")
-    assert test_dict == reference_dict
-    for key in set_testdict.keys():
-        set_testdict[key].pop("text_english")
-    with pytest.raises(ValueError):
-        tt.PostprocessText(mydict=set_testdict)
-    obj = tt.PostprocessText(use_csv=True, csv_path=get_path + "test_data_out.csv")
-    # make sure test works on windows where end-of-line character is \r\n
-    test_df = obj.list_text_english[0].replace("\r", "")
-    assert test_df == reference_df
-    with pytest.raises(ValueError):
-        tt.PostprocessText(use_csv=True, csv_path=get_path + "test_data_out_nokey.csv")
-    with pytest.raises(ValueError):
-        tt.PostprocessText()
-
-
-def test_analyse_topic(get_path):
-    _, topic_df, most_frequent_topics = tt.PostprocessText(
-        use_csv=True, csv_path=get_path + "topic_analysis_test.csv"
-    ).analyse_topic()
-    # since this is not deterministic we cannot be sure we get the same result twice
-    assert len(topic_df) == 2
-    assert topic_df["Name"].iloc[0] == "0_the_feat_of_is"
-    assert most_frequent_topics[0][0][0] == "the"

ammico/text.py (296 changed lines)
@@ -8,8 +8,6 @@ import re
 from ammico.utils import AnalysisMethod
 import grpc
 import pandas as pd
-from bertopic import BERTopic
-from transformers import pipeline
 
 PRIVACY_STATEMENT = """The Text Detector uses Google Cloud Vision
 and Google Translate. Detailed information about how information
@@ -71,8 +69,6 @@ class TextDetector(AnalysisMethod):
         subdict: dict,
         analyse_text: bool = False,
         skip_extraction: bool = False,
-        model_names: list = None,
-        revision_numbers: list = None,
         accept_privacy: str = "PRIVACY_AMMICO",
     ) -> None:
         """Init text detection class.
@@ -84,19 +80,6 @@ class TextDetector(AnalysisMethod):
                 to analysis. Defaults to False.
             skip_extraction (bool, optional): Decide if text will be extracted from images or
                 is already provided via a csv. Defaults to False.
-            model_names (list, optional): Provide model names for summary, sentiment and ner
-                analysis. Defaults to None, in which case the default model from transformers
-                are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary),
-                "distilbert-base-uncased-finetuned-sst-2-english" (sentiment),
-                "dbmdz/bert-large-cased-finetuned-conll03-english".
-                To select other models, provide a list with three entries, the first for
-                summary, second for sentiment, third for NER, with the desired model names.
-                Set one of these to None to still use the default model.
-            revision_numbers (list, optional): Model revision (commit) numbers on the
-                Hugging Face hub. Provide this to make sure you are using the same model.
-                Defaults to None, except if the default models are used; then it defaults to
-                "a4f8f3e" (summary, distilbart), "af0f99b" (sentiment, distilbert),
-                "f2482bf" (NER, bert).
             accept_privacy (str, optional): Environment variable to accept the privacy
                 statement for the Google Cloud processing of the data. Defaults to
                 "PRIVACY_AMMICO".
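
Note: before this commit the two removed parameters were exercised by the (also removed) test_init_revision_numbers_and_models. A minimal sketch of the pre-removal call, with the values taken from that test; the empty dict and the default privacy variable are just placeholders:

    import ammico.text as tt

    # pre-removal API (sketch): override only the summary model and pin its revision;
    # None entries fall back to the former default models listed in the removed docstring
    detector = tt.TextDetector(
        {},
        model_names=["facebook/bart-large-cnn", None, None],
        revision_numbers=["3d22493", None, None],
        accept_privacy="PRIVACY_AMMICO",
    )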
@@ -124,90 +107,6 @@ class TextDetector(AnalysisMethod):
             print("Reading text directly from provided dictionary.")
         if self.analyse_text:
             self._initialize_spacy()
-        if model_names:
-            self._check_valid_models(model_names)
-        if revision_numbers:
-            self._check_revision_numbers(revision_numbers)
-        # initialize revision numbers and models
-        self._init_revision_numbers(model_names, revision_numbers)
-        self._init_model(model_names)
-
-    def _check_valid_models(self, model_names):
-        # check that model_names and revision_numbers are valid lists or None
-        # check that model names are a list
-        if not isinstance(model_names, list):
-            raise ValueError("Model names need to be provided as a list!")
-        # check that enough models are provided, one for each method
-        if len(model_names) != 3:
-            raise ValueError(
-                "Not enough or too many model names provided - three are required, one each for summary, sentiment, ner"
-            )
-
-    def _check_revision_numbers(self, revision_numbers):
-        # check that revision numbers are list
-        if not isinstance(revision_numbers, list):
-            raise ValueError("Revision numbers need to be provided as a list!")
-        # check that three revision numbers are provided, one for each method
-        if len(revision_numbers) != 3:
-            raise ValueError(
-                "Not enough or too many revision numbers provided - three are required, one each for summary, sentiment, ner"
-            )
-
-    def _init_revision_numbers(self, model_names, revision_numbers):
-        """Helper method to set the revision (version) number for each model."""
-        revision_numbers_default = ["a4f8f3e", "af0f99b", "f2482bf"]
-        if model_names:
-            # if model_names is provided, set revision numbers for each of the methods
-            # either as the provided revision number or None or as the default revision number,
-            # if one of the methods uses the default model
-            self._init_revision_numbers_per_model(
-                model_names, revision_numbers, revision_numbers_default
-            )
-        else:
-            # model_names was not provided, revision numbers are the default revision numbers or None
-            self.revision_summary = revision_numbers_default[0]
-            self.revision_sentiment = revision_numbers_default[1]
-            self.revision_ner = revision_numbers_default[2]
-
-    def _init_revision_numbers_per_model(
-        self, model_names, revision_numbers, revision_numbers_default
-    ):
-        task_list = []
-        if not revision_numbers:
-            # no revision numbers for non-default models provided
-            revision_numbers = [None, None, None]
-        for model, revision, revision_default in zip(
-            model_names, revision_numbers, revision_numbers_default
-        ):
-            # a model was specified for this task, set specified revision number or None
-            # or: model for this task was set to None, so we take default version number for default model
-            task_list.append(revision if model else revision_default)
-        self.revision_summary = task_list[0]
-        self.revision_sentiment = task_list[1]
-        self.revision_ner = task_list[2]
-
-    def _init_model(self, model_names):
-        """Helper method to set the model name for each analysis method."""
-        # assign models for each of the text analysis methods
-        # and check that they are valid
-        model_names_default = [
-            "sshleifer/distilbart-cnn-12-6",
-            "distilbert-base-uncased-finetuned-sst-2-english",
-            "dbmdz/bert-large-cased-finetuned-conll03-english",
-        ]
-        # no model names provided, set the default
-        if not model_names:
-            model_names = model_names_default
-        # now assign model names for each of the methods
-        # either to the provided model name or the default if one of the
-        # task's models is set to None
-        self.model_summary = (
-            model_names[0] if model_names[0] else model_names_default[0]
-        )
-        self.model_sentiment = (
-            model_names[1] if model_names[1] else model_names_default[1]
-        )
-        self.model_ner = model_names[2] if model_names[2] else model_names_default[2]
-
     def set_keys(self) -> dict:
         """Set the default keys for text analysis.
@@ -285,10 +184,6 @@ class TextDetector(AnalysisMethod):
         self.remove_linebreaks()
         if self.analyse_text and self.subdict["text_english"]:
             self._run_spacy()
-            self.clean_text()
-            self.text_summary()
-            self.text_sentiment_transformers()
-            self.text_ner()
         return self.subdict
 
     def get_text_from_image(self):
@@ -362,73 +257,6 @@ class TextDetector(AnalysisMethod):
         """Generate Spacy doc object for further text analysis."""
         self.doc = self.nlp(self.subdict["text_english"])
 
-    def clean_text(self):
-        """Clean the text from unrecognized words and any numbers."""
-        templist = []
-        for token in self.doc:
-            (
-                templist.append(token.text)
-                if token.pos_ != "NUM" and token.has_vector
-                else None
-            )
-        self.subdict["text_clean"] = " ".join(templist).rstrip().lstrip()
-
-    def text_summary(self):
-        """Generate a summary of the text using the Transformers pipeline."""
-        # use the transformers pipeline to summarize the text
-        # use the current default model - 03/2023
-        max_number_of_characters = 3000
-        pipe = pipeline(
-            "summarization",
-            model=self.model_summary,
-            revision=self.revision_summary,
-            min_length=5,
-            max_length=20,
-            framework="pt",
-        )
-        try:
-            summary = pipe(self.subdict["text_english"][0:max_number_of_characters])
-            self.subdict["text_summary"] = summary[0]["summary_text"]
-        except IndexError:
-            print(
-                "Cannot provide summary for this object - please check that the text has been translated correctly."
-            )
-            print("Image: {}".format(self.subdict["filename"]))
-            self.subdict["text_summary"] = None
-
-    def text_sentiment_transformers(self):
-        """Perform text classification for sentiment using the Transformers pipeline."""
-        # use the transformers pipeline for text classification
-        # use the current default model - 03/2023
-        pipe = pipeline(
-            "text-classification",
-            model=self.model_sentiment,
-            revision=self.revision_sentiment,
-            truncation=True,
-            framework="pt",
-        )
-        result = pipe(self.subdict["text_english"])
-        self.subdict["sentiment"] = result[0]["label"]
-        self.subdict["sentiment_score"] = round(result[0]["score"], 2)
-
-    def text_ner(self):
-        """Perform named entity recognition on the text using the Transformers pipeline."""
-        # use the transformers pipeline for named entity recognition
-        # use the current default model - 03/2023
-        pipe = pipeline(
-            "token-classification",
-            model=self.model_ner,
-            revision=self.revision_ner,
-            aggregation_strategy="simple",
-            framework="pt",
-        )
-        result = pipe(self.subdict["text_english"])
-        self.subdict["entity"] = []
-        self.subdict["entity_type"] = []
-        for entity in result:
-            self.subdict["entity"].append(entity["word"])
-            self.subdict["entity_type"].append(entity["entity_group"])
-
 
 class TextAnalyzer:
     """Used to get text from a csv and then run the TextDetector on it."""
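
Note: the removed methods were thin wrappers around the Hugging Face pipeline API. A minimal sketch of how the same summary, sentiment and NER calls can be reproduced outside ammico, using the model names and revisions from the removed code; the input string is only a placeholder for the translated text the detector would normally provide:

    from transformers import pipeline

    text_english = "I am happy that the CI is working again."  # placeholder translated text

    # summary, as in the removed text_summary()
    summary = pipeline(
        "summarization",
        model="sshleifer/distilbart-cnn-12-6",
        revision="a4f8f3e",
        min_length=5,
        max_length=20,
        framework="pt",
    )(text_english[:3000])[0]["summary_text"]

    # sentiment, as in the removed text_sentiment_transformers()
    sentiment = pipeline(
        "text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
        revision="af0f99b",
        truncation=True,
        framework="pt",
    )(text_english)[0]  # {"label": ..., "score": ...}

    # named entities, as in the removed text_ner()
    entities = pipeline(
        "token-classification",
        model="dbmdz/bert-large-cased-finetuned-conll03-english",
        revision="f2482bf",
        aggregation_strategy="simple",
        framework="pt",
    )(text_english)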
@@ -492,127 +320,3 @@ class TextAnalyzer:
             "filename": self.csv_path,
             "text": text,
         }
-
-
-class PostprocessText:
-    def __init__(
-        self,
-        mydict: dict = None,
-        use_csv: bool = False,
-        csv_path: str = None,
-        analyze_text: str = "text_english",
-    ) -> None:
-        """
-        Initializes the PostprocessText class that handles the topic analysis.
-
-        Args:
-            mydict (dict, optional): Dictionary with textual data. Defaults to None.
-            use_csv (bool, optional): Flag indicating whether to use a CSV file. Defaults to False.
-            csv_path (str, optional): Path to the CSV file. Required if `use_csv` is True. Defaults to None.
-            analyze_text (str, optional): Key for the text field to analyze. Defaults to "text_english".
-        """
-        self.use_csv = use_csv
-        if mydict:
-            print("Reading data from dict.")
-            self.mydict = mydict
-            self.list_text_english = self.get_text_dict(analyze_text)
-        elif self.use_csv:
-            print("Reading data from df.")
-            self.df = pd.read_csv(csv_path, encoding="utf8")
-            self.list_text_english = self.get_text_df(analyze_text)
-        else:
-            raise ValueError(
-                "Please provide either dictionary with textual data or \
-                a csv file by setting `use_csv` to True and providing a \
-                `csv_path`."
-            )
-        # initialize spacy
-        self._initialize_spacy()
-
-    def _initialize_spacy(self):
-        try:
-            self.nlp = spacy.load(
-                "en_core_web_md",
-                exclude=["tagger", "parser", "ner", "attribute_ruler", "lemmatizer"],
-            )
-        except Exception:
-            spacy.cli.download("en_core_web_md")
-            self.nlp = spacy.load(
-                "en_core_web_md",
-                exclude=["tagger", "parser", "ner", "attribute_ruler", "lemmatizer"],
-            )
-
-    def analyse_topic(self, return_topics: int = 3) -> tuple:
-        """
-        Performs topic analysis using BERTopic.
-
-        Args:
-            return_topics (int, optional): Number of topics to return. Defaults to 3.
-
-        Returns:
-            tuple: A tuple containing the topic model, topic dataframe, and most frequent topics.
-        """
-        try:
-            # unfortunately catching exceptions does not work here - need to figure out why
-            self.topic_model = BERTopic(embedding_model=self.nlp)
-        except TypeError:
-            print("BERTopic excited with an error - maybe your dataset is too small?")
-        self.topics, self.probs = self.topic_model.fit_transform(self.list_text_english)
-        # return the topic list
-        topic_df = self.topic_model.get_topic_info()
-        # return the most frequent return_topics
-        most_frequent_topics = []
-        if len(topic_df) < return_topics:
-            print("You requested more topics than are identified in your dataset -")
-            print(
-                "Returning only {} topics as these are all that have been found.".format(
-                    len(topic_df)
-                )
-            )
-        for i in range(min(return_topics, len(topic_df))):
-            most_frequent_topics.append(self.topic_model.get_topic(i))
-        return self.topic_model, topic_df, most_frequent_topics
-
-    def get_text_dict(self, analyze_text: str) -> list:
-        """
-        Extracts text from the provided dictionary.
-
-        Args:
-            analyze_text (str): Key for the text field to analyze.
-
-        Returns:
-            list: A list of text extracted from the dictionary.
-        """
-        # use dict to put text_english or text_summary in list
-        list_text_english = []
-        for key in self.mydict.keys():
-            if analyze_text not in self.mydict[key]:
-                raise ValueError(
-                    "Please check your provided dictionary - \
-                    no {} text data found.".format(
-                        analyze_text
-                    )
-                )
-            list_text_english.append(self.mydict[key][analyze_text])
-        return list_text_english
-
-    def get_text_df(self, analyze_text: str) -> list:
-        """
-        Extracts text from the provided dataframe.
-
-        Args:
-            analyze_text (str): Column name for the text field to analyze.
-
-        Returns:
-            list: A list of text extracted from the dataframe.
-        """
-        # use csv file to obtain dataframe and put text_english or text_summary in list
-        # check that "text_english" or "text_summary" is there
-        if analyze_text not in self.df:
-            raise ValueError(
-                "Please check your provided dataframe - \
-                no {} text data found.".format(
-                    analyze_text
-                )
-            )
-        return self.df[analyze_text].tolist()
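
Note: the removed topic analysis was a thin wrapper around BERTopic with a spaCy embedding model. A rough stand-alone sketch following the logic of the removed analyse_topic(); the document list is a placeholder for the former list_text_english:

    import spacy
    from bertopic import BERTopic

    docs = ["first example document", "another example document"]  # placeholder documents

    # same spaCy model and exclusions the removed _initialize_spacy() used
    nlp = spacy.load(
        "en_core_web_md",
        exclude=["tagger", "parser", "ner", "attribute_ruler", "lemmatizer"],
    )
    topic_model = BERTopic(embedding_model=nlp)
    topics, probs = topic_model.fit_transform(docs)
    topic_df = topic_model.get_topic_info()
    # most frequent topics, up to three as in the removed default
    most_frequent_topics = [topic_model.get_topic(i) for i in range(min(3, len(topic_df)))]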
@@ -20,7 +20,6 @@ classifiers = [
 ]
 
 dependencies = [
-    "bertopic<=0.14.1",
    "dash>=2.11.0",
    "datasets",
    "deepface<=0.0.93",
@@ -50,7 +49,6 @@ dependencies = [
    "spacy<=3.7.5",
    "tensorflow>=2.13.0",
    "torch<2.6.0",
-    "transformers",
    "google-cloud-vision",
    "dash_bootstrap_components",
    "colorgram.py",
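
Note: with "bertopic<=0.14.1" and "transformers" dropped from the dependency list, the sketches above require installing those packages separately, for example with pip install "bertopic<=0.14.1" transformers; the version pin is simply the one removed here, not a recommendation.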