зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
use simpler image for testing emotion detector (#190)
* use simpler image for testing
* include age in faces test again
* fix typo
* try with newer tensorflow version
* remove testing for age again
* try with tensorflow newer versions only for breaking change in transformers
* force transformers to use pytorch
Этот коммит содержится в:
родитель
922a64f991
Коммит
38498e3e10
@ -203,7 +203,7 @@ class EmotionDetector(AnalysisMethod):
|
|||||||
result = {"number_faces": len(faces) if len(faces) <= 3 else 3}
|
result = {"number_faces": len(faces) if len(faces) <= 3 else 3}
|
||||||
# We limit ourselves to three faces
|
# We limit ourselves to three faces
|
||||||
for i, face in enumerate(faces[:3]):
|
for i, face in enumerate(faces[:3]):
|
||||||
result[f"person{ i+1 }"] = self.analyze_single_face(face)
|
result[f"person{i+1}"] = self.analyze_single_face(face)
|
||||||
self.clean_subdict(result)
|
self.clean_subdict(result)
|
||||||
return self.subdict
|
return self.subdict
|
||||||
|
|
||||||
|
|||||||
@ -332,7 +332,8 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
|
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
|
||||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
|
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
|
||||||
r = requests.get(url, allow_redirects=False)
|
r = requests.get(url, allow_redirects=False)
|
||||||
open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb").write(r.content)
|
with open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb") as f:
|
||||||
|
f.write(r.content)
|
||||||
|
|
||||||
image_keys = sorted(self.subdict.keys())
|
image_keys = sorted(self.subdict.keys())
|
||||||
image_names = [self.subdict[k]["filename"] for k in image_keys]
|
image_names = [self.subdict[k]["filename"] for k in image_keys]
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
{
|
{
|
||||||
"face": "Yes",
|
"face": "Yes",
|
||||||
"multiple_faces": "Yes",
|
"multiple_faces": "No",
|
||||||
"no_faces": 11,
|
"no_faces": 1,
|
||||||
"wears_mask": ["No", "No", "Yes"],
|
"wears_mask": ["No"],
|
||||||
"gender": ["Man", "Man", "Man"],
|
"gender": ["Woman"],
|
||||||
"race": ["white", "white", null],
|
"race": ["asian"],
|
||||||
"emotion": ["sad", "fear", null],
|
"emotion": ["happy"],
|
||||||
"emotion (category)": ["Negative", "Negative", null]
|
"emotion (category)": ["Positive"]
|
||||||
}
|
}
|
||||||
Двоичные данные
ammico/test/data/pexels-pixabay-415829.jpg
Обычный файл
Двоичные данные
ammico/test/data/pexels-pixabay-415829.jpg
Обычный файл
Двоичный файл не отображается.
|
После Ширина: | Высота: | Размер: 1.2 MiB |
@ -21,7 +21,7 @@ def test_set_keys():
|
|||||||
|
|
||||||
def test_analyse_faces(get_path):
|
def test_analyse_faces(get_path):
|
||||||
mydict = {
|
mydict = {
|
||||||
"filename": get_path + "IMG_2746.png",
|
"filename": get_path + "pexels-pixabay-415829.jpg",
|
||||||
}
|
}
|
||||||
mydict.update(fc.EmotionDetector(mydict).analyse_image())
|
mydict.update(fc.EmotionDetector(mydict).analyse_image())
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ def test_analyse_faces(get_path):
|
|||||||
out_dict = json.load(file)
|
out_dict = json.load(file)
|
||||||
# delete the filename key
|
# delete the filename key
|
||||||
mydict.pop("filename", None)
|
mydict.pop("filename", None)
|
||||||
# delete the age key, as this is conflicting - gives different results sometimes
|
# do not test for age, as this is not a reliable metric
|
||||||
mydict.pop("age", None)
|
mydict.pop("age", None)
|
||||||
for key in mydict.keys():
|
for key in mydict.keys():
|
||||||
assert mydict[key] == out_dict[key]
|
assert mydict[key] == out_dict[key]
|
||||||
|
|||||||
@ -238,6 +238,7 @@ class TextDetector(AnalysisMethod):
|
|||||||
revision=self.revision_summary,
|
revision=self.revision_summary,
|
||||||
min_length=5,
|
min_length=5,
|
||||||
max_length=20,
|
max_length=20,
|
||||||
|
framework="pt",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
summary = pipe(self.subdict["text_english"][0:max_number_of_characters])
|
summary = pipe(self.subdict["text_english"][0:max_number_of_characters])
|
||||||
@ -258,6 +259,7 @@ class TextDetector(AnalysisMethod):
|
|||||||
model=self.model_sentiment,
|
model=self.model_sentiment,
|
||||||
revision=self.revision_sentiment,
|
revision=self.revision_sentiment,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
framework="pt",
|
||||||
)
|
)
|
||||||
result = pipe(self.subdict["text_english"])
|
result = pipe(self.subdict["text_english"])
|
||||||
self.subdict["sentiment"] = result[0]["label"]
|
self.subdict["sentiment"] = result[0]["label"]
|
||||||
@ -272,6 +274,7 @@ class TextDetector(AnalysisMethod):
|
|||||||
model=self.model_ner,
|
model=self.model_ner,
|
||||||
revision=self.revision_ner,
|
revision=self.revision_ner,
|
||||||
aggregation_strategy="simple",
|
aggregation_strategy="simple",
|
||||||
|
framework="pt",
|
||||||
)
|
)
|
||||||
result = pipe(self.subdict["text_english"])
|
result = pipe(self.subdict["text_english"])
|
||||||
self.subdict["entity"] = []
|
self.subdict["entity"] = []
|
||||||
@ -281,6 +284,58 @@ class TextDetector(AnalysisMethod):
|
|||||||
self.subdict["entity_type"].append(entity["entity_group"])
|
self.subdict["entity_type"].append(entity["entity_group"])
|
||||||
|
|
||||||
|
|
||||||
|
class TextAnalyzer:
    """Used to get text from a csv and then run the TextDetector on it."""

    def __init__(self, csv_path: str, column_key: str = None) -> None:
        """Init the TextAnalyzer class.

        Args:
            csv_path (str): Path to the CSV file containing the text entries.
            column_key (str): Key for the column containing the text entries.
                Defaults to None.

        Raises:
            ValueError: If ``csv_path`` is not a string or does not end
                with ".csv".
            FileNotFoundError: If the file at ``csv_path`` does not exist.
        """
        self.csv_path = csv_path
        self.column_key = column_key
        # Validate eagerly so that a bad path fails at construction time
        # rather than on the first call to read_csv().
        self._check_valid_csv_path()
        self._check_file_exists()

    def _check_valid_csv_path(self):
        """Check that the path is a string with a ".csv" suffix.

        Raises:
            ValueError: If the path is not a string or lacks the suffix.
        """
        if not isinstance(self.csv_path, str):
            raise ValueError("The provided path to the CSV file is not a string.")
        if not self.csv_path.endswith(".csv"):
            raise ValueError("The provided file is not a CSV file.")

    def _check_file_exists(self):
        """Check that the CSV file can be opened for reading.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        try:
            # Opening (and immediately closing) the file is the existence
            # probe; its content is only read later in read_csv().
            with open(self.csv_path, "r"):
                pass
        except FileNotFoundError as err:
            # Keep the original exception as the cause for easier debugging.
            raise FileNotFoundError("The provided CSV file does not exist.") from err

    def read_csv(self) -> dict:
        """Read the CSV file and return the dictionary with the text entries.

        The entries are also stored on the instance as ``self.mylist`` (the
        raw column values) and ``self.mydict`` (the returned dictionary).
        If no column key was provided, the column "text" is used.

        Returns:
            dict: The dictionary with the text entries. Each key is the CSV
                path followed by "row-<index>"; each value is a dict holding
                the "filename" (the CSV path) and the "text" of that row.

        Raises:
            ValueError: If the column key is not a column of the CSV file.
        """
        df = pd.read_csv(self.csv_path, encoding="utf8")
        if not self.column_key:
            self.column_key = "text"

        if self.column_key not in df:
            raise ValueError(
                "The provided column key is not in the CSV file. Please check."
            )
        self.mylist = df[self.column_key].to_list()
        self.mydict = {
            self.csv_path + "row-" + str(i): {
                "filename": self.csv_path,
                "text": text,
            }
            for i, text in enumerate(self.mylist)
        }
        # The signature and docstring promise a dict; without this return,
        # callers received None.
        return self.mydict
|
||||||
|
|
||||||
|
|
||||||
class PostprocessText:
|
class PostprocessText:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -46,7 +46,7 @@ dependencies = [
|
|||||||
"ammico-lavis",
|
"ammico-lavis",
|
||||||
"setuptools",
|
"setuptools",
|
||||||
"spacy",
|
"spacy",
|
||||||
"tensorflow<=2.12.3",
|
"tensorflow>=2.13.0",
|
||||||
"torch<2.1.0",
|
"torch<2.1.0",
|
||||||
"transformers",
|
"transformers",
|
||||||
"google-cloud-vision",
|
"google-cloud-vision",
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user