From 4ac760e6909a800edc7217faea635479b248b074 Mon Sep 17 00:00:00 2001 From: Inga Ulusoy Date: Wed, 5 Jun 2024 09:28:28 +0200 Subject: [PATCH] Add text analyzer to skip text extraction from image (#199) * read in text from csv * add tests for csv reading * run textanalyzer in demo notebook * add text analyser in doc and demo * improve init TextDetector testing * more init tests * add csv encoding keyword * add utf16-csv file * skip csv reading on windows --- ammico/__init__.py | 3 +- ammico/data/ref/test.csv | 8 ++ ammico/notebooks/DemoNotebook_ammico.ipynb | 88 ++++++++++++++++ ammico/test/data/test-utf16.csv | Bin 0 -> 348 bytes ammico/test/data/test.csv | 8 ++ ammico/test/data/test_read_csv_ref.json | 32 ++++++ ammico/test/test_text.py | 71 +++++++++++-- ammico/text.py | 38 +++++-- .../notebooks/DemoNotebook_ammico.ipynb | 94 +++++++++++++++++- 9 files changed, 328 insertions(+), 14 deletions(-) create mode 100644 ammico/data/ref/test.csv create mode 100644 ammico/test/data/test-utf16.csv create mode 100644 ammico/test/data/test.csv create mode 100644 ammico/test/data/test_read_csv_ref.json diff --git a/ammico/__init__.py b/ammico/__init__.py index 7f1465f..3a78033 100644 --- a/ammico/__init__.py +++ b/ammico/__init__.py @@ -8,7 +8,7 @@ from ammico.display import AnalysisExplorer from ammico.faces import EmotionDetector from ammico.multimodal_search import MultimodalSearch from ammico.summary import SummaryDetector -from ammico.text import TextDetector, PostprocessText +from ammico.text import TextDetector, TextAnalyzer, PostprocessText from ammico.utils import find_files, get_dataframe # Export the version defined in project metadata @@ -23,6 +23,7 @@ __all__ = [ "MultimodalSearch", "SummaryDetector", "TextDetector", + "TextAnalyzer", "PostprocessText", "find_files", "get_dataframe", diff --git a/ammico/data/ref/test.csv b/ammico/data/ref/test.csv new file mode 100644 index 0000000..f73b9da --- /dev/null +++ b/ammico/data/ref/test.csv @@ -0,0 +1,8 @@ +text, date +this is a test, 05/31/24 +bu bir denemedir, 05/31/24 +dies ist ein Test, 05/31/24 +c'est un test, 05/31/24 +esto es una prueba, 05/31/24 +detta är ett test, 05/31/24 + diff --git a/ammico/notebooks/DemoNotebook_ammico.ipynb b/ammico/notebooks/DemoNotebook_ammico.ipynb index f3767c3..12bd6ea 100644 --- a/ammico/notebooks/DemoNotebook_ammico.ipynb +++ b/ammico/notebooks/DemoNotebook_ammico.ipynb @@ -366,6 +366,94 @@ "image_df.to_csv(\"/content/drive/MyDrive/misinformation-data/data_out.csv\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read in a csv file containing text and translating/analysing the text\n", + "\n", + "Instead of extracting text from an image, or to re-process text that was already extracted, it is also possible to provide a `csv` file containing text in its rows.\n", + "Provide the path and name of the csv file with the keyword `csv_path`. The keyword `column_key` tells the Analyzer which column key in the csv file holds the text that should be analyzed. This defaults to \"text\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ta = ammico.TextAnalyzer(csv_path=\"../data/ref/test.csv\", column_key=\"text\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# read the csv file\n", + "ta.read_csv()\n", + "# set up the dict containing all text entries\n", + "text_dict = ta.mydict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# set the dump file\n", + "# dump file name\n", + "dump_file = \"dump_file.csv\"\n", + "# dump every N images \n", + "dump_every = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# analyze the csv file\n", + "for num, key in tqdm(enumerate(text_dict.keys()), total=len(text_dict)): # loop through all text entries\n", + " ammico.TextDetector(text_dict[key], analyse_text=True, skip_extraction=True).analyse_image() # analyse text with TextDetector and update dict\n", + " if num % dump_every == 0 | num == len(text_dict) - 1: # save results every dump_every to dump_file\n", + " image_df = ammico.get_dataframe(text_dict)\n", + " image_df.to_csv(dump_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save the results to a csv file\n", + "text_df = ammico.get_dataframe(text_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# inspect\n", + "text_df.head(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# write to csv\n", + "text_df.to_csv(\"data_out.csv\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/ammico/test/data/test-utf16.csv b/ammico/test/data/test-utf16.csv new file mode 100644 index 0000000000000000000000000000000000000000..d88dec96c543008f7c8c00c4705f5e18886c70f5 GIT binary patch literal 348 zcmaivI}XA?3`Ad@Q>3Fo;(rq^kR_I&ARw}lI56iRjJ@e1pjfXzGvo32&KOTJOKMsg zWk#iV=0dMtiH$O!^b>N;ffYM8id*(#BnDxU@Q+Y3I^nN+ZS-7!&hnY7mOWX&snh8{ vM!)!)^G+ None: @@ -25,6 +26,8 @@ class TextDetector(AnalysisMethod): analysis results from other modules. analyse_text (bool, optional): Decide if extracted text will be further subject to analysis. Defaults to False. + skip_extraction (bool, optional): Decide if text will be extracted from images or + is already provided via a csv. Defaults to False. model_names (list, optional): Provide model names for summary, sentiment and ner analysis. Defaults to None, in which case the default model from transformers are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary), @@ -40,11 +43,21 @@ class TextDetector(AnalysisMethod): "f2482bf" (NER, bert). """ super().__init__(subdict) - self.subdict.update(self.set_keys()) + # disable this for now + # maybe it would be better to initialize the keys differently + # the reason is that they are inconsistent depending on the selected + # options, and also this may not be really necessary and rather restrictive + # self.subdict.update(self.set_keys()) self.translator = Translator() if not isinstance(analyse_text, bool): raise ValueError("analyse_text needs to be set to true or false") self.analyse_text = analyse_text + self.skip_extraction = skip_extraction + if not isinstance(skip_extraction, bool): + raise ValueError("skip_extraction needs to be set to true or false") + if self.skip_extraction: + print("Skipping text extraction from image.") + print("Reading text directly from provided dictionary.") if self.analyse_text: self._initialize_spacy() if model_names: @@ -155,7 +168,8 @@ class TextDetector(AnalysisMethod): Returns: dict: The updated dictionary with text analysis results. """ - self.get_text_from_image() + if not self.skip_extraction: + self.get_text_from_image() self.translate_text() self.remove_linebreaks() if self.analyse_text: @@ -287,18 +301,32 @@ class TextDetector(AnalysisMethod): class TextAnalyzer: """Used to get text from a csv and then run the TextDetector on it.""" - def __init__(self, csv_path: str, column_key: str = None) -> None: + def __init__( + self, csv_path: str, column_key: str = None, csv_encoding: str = "utf-8" + ) -> None: """Init the TextTranslator class. Args: csv_path (str): Path to the CSV file containing the text entries. column_key (str): Key for the column containing the text entries. Defaults to None. + csv_encoding (str): Encoding of the CSV file. Defaults to "utf-8". """ self.csv_path = csv_path self.column_key = column_key + self.csv_encoding = csv_encoding self._check_valid_csv_path() self._check_file_exists() + if not self.column_key: + print("No column key provided - using 'text' as default.") + self.column_key = "text" + if not self.csv_encoding: + print("No encoding provided - using 'utf-8' as default.") + self.csv_encoding = "utf-8" + if not isinstance(self.column_key, str): + raise ValueError("The provided column key is not a string.") + if not isinstance(self.csv_encoding, str): + raise ValueError("The provided encoding is not a string.") def _check_valid_csv_path(self): if not isinstance(self.csv_path, str): @@ -319,9 +347,7 @@ class TextAnalyzer: Returns: dict: The dictionary with the text entries. """ - df = pd.read_csv(self.csv_path, encoding="utf8") - if not self.column_key: - self.column_key = "text" + df = pd.read_csv(self.csv_path, encoding=self.csv_encoding) if self.column_key not in df: raise ValueError( diff --git a/docs/source/notebooks/DemoNotebook_ammico.ipynb b/docs/source/notebooks/DemoNotebook_ammico.ipynb index 9a0b06d..292a93d 100644 --- a/docs/source/notebooks/DemoNotebook_ammico.ipynb +++ b/docs/source/notebooks/DemoNotebook_ammico.ipynb @@ -94,7 +94,10 @@ "import os\n", "import ammico\n", "# for displaying a progress bar\n", - "from tqdm import tqdm" + "from tqdm import tqdm\n", + "# to get the reference data for text_dict\n", + "import importlib_resources\n", + "pkg = importlib_resources.files(\"ammico\")" ] }, { @@ -363,6 +366,95 @@ "image_df.to_csv(\"/content/drive/MyDrive/misinformation-data/data_out.csv\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read in a csv file containing text and translating/analysing the text\n", + "\n", + "Instead of extracting text from an image, or to re-process text that was already extracted, it is also possible to provide a `csv` file containing text in its rows.\n", + "Provide the path and name of the csv file with the keyword `csv_path`. The keyword `column_key` tells the Analyzer which column key in the csv file holds the text that should be analyzed. This defaults to \"text\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "csv_path = pkg / \"data\" / \"ref\" / \"test.csv\"\n", + "ta = ammico.TextAnalyzer(csv_path=str(csv_path), column_key=\"text\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# read the csv file\n", + "ta.read_csv()\n", + "# set up the dict containing all text entries\n", + "text_dict = ta.mydict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# set the dump file\n", + "# dump file name\n", + "dump_file = \"dump_file.csv\"\n", + "# dump every N images \n", + "dump_every = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# analyze the csv file\n", + "for num, key in tqdm(enumerate(text_dict.keys()), total=len(text_dict)): # loop through all text entries\n", + " ammico.TextDetector(text_dict[key], analyse_text=True, skip_extraction=True).analyse_image() # analyse text with TextDetector and update dict\n", + " if num % dump_every == 0 | num == len(text_dict) - 1: # save results every dump_every to dump_file\n", + " image_df = ammico.get_dataframe(text_dict)\n", + " image_df.to_csv(dump_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save the results to a csv file\n", + "text_df = ammico.get_dataframe(text_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# inspect\n", + "text_df.head(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# write to csv\n", + "text_df.to_csv(\"data_out.csv\")" + ] + }, { "cell_type": "markdown", "metadata": {},