From 44e5a987b1fc2b916996cfa9a6b8c68c0bc85cd7 Mon Sep 17 00:00:00 2001
From: GwydionJon <35529247+GwydionJon@users.noreply.github.com>
Date: Wed, 14 Jun 2023 22:17:20 +0200
Subject: [PATCH] Improve documentation (#89)

* updated documentation in cropposts

* updated documentation in display

* updated documentation in faces

* added comments to objects.py

* updated utils.py docs

* updated text.py docs

* improve doc display

* fix doc for display and remove redundant variable

* removed documentation from cropposts.py

* removed unused imports

* get rid of ipywidgets dependency

* remove unused imports, improve type hints

* improve doc in utils

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Inga Ulusoy
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ammico/cropposts.py         |   9 ++-
 ammico/display.py           | 111 +++++++++++++++++++++++------------
 ammico/faces.py             | 113 ++++++++++++++++++++++--------------
 ammico/objects.py           |  23 ++++++--
 ammico/summary.py           |   2 +-
 ammico/test/test_display.py |   4 +-
 ammico/test/test_faces.py   |   8 +++
 ammico/text.py              |  91 +++++++++++++++++++++--------
 ammico/utils.py             |  38 +++++++-----
 pyproject.toml              |   1 -
 10 files changed, 270 insertions(+), 130 deletions(-)

diff --git a/ammico/cropposts.py b/ammico/cropposts.py
index e739a67..e85f2a0 100644
--- a/ammico/cropposts.py
+++ b/ammico/cropposts.py
@@ -84,6 +84,8 @@ def matching_points(
     sift = cv2.SIFT_create()
     kp1, des1 = sift.detectAndCompute(img1, None)
     kp2, des2 = sift.detectAndCompute(img2, None)
+
+    # Convert descriptors to float32
     des1 = np.float32(des1)
     des2 = np.float32(des2)
     # Initialize and use FLANN
@@ -93,6 +95,7 @@
     matches = flann.knnMatch(des1, des2, k=2)
     filtered_matches = []
     for m, n in matches:
+        # Apply ratio test to filter out ambiguous matches
         if m.distance < 0.7 * n.distance:
             filtered_matches.append(m)
     return filtered_matches, kp1, kp2
@@ -141,6 +144,8 @@ def compute_crop_corner(
     kp1, kp2 = kp_from_matches(matches, kp1, kp2)
     ys = kp2[:, 1]
     covers = []
+
+    # Compute the number of keypoints within the region around each y-coordinate
    for y in ys:
        ys_c = ys - y
        series = pd.Series(ys_c)
@@ -151,7 +156,10 @@
        return None
    kp_id = ys[covers.argmax()]
    v = int(kp_id) - v_margin if int(kp_id) > v_margin else int(kp_id)
+
    hs = []
+
+    # Find the minimum x-coordinate within the region around the selected y-coordinate
    for kp in kp2:
        if 0 <= kp[1] - v <= region:
            hs.append(kp[0])
@@ -320,7 +328,6 @@ def crop_media_posts(
     for ref_file in ref_files:
         ref_view = cv2.imread(ref_file)
         ref_views.append(ref_view)
-
     # parse through the social media posts to be cropped
     for crop_file in files:
         view = cv2.imread(crop_file)
diff --git a/ammico/display.py b/ammico/display.py
index a3d0e33..d18d549 100644
--- a/ammico/display.py
+++ b/ammico/display.py
@@ -1,34 +1,26 @@
-from IPython.display import display
-
 import ammico.faces as faces
 import ammico.text as text
 import ammico.objects as objects
 from ammico.utils import is_interactive
-
 import ammico.summary as summary
-
 import dash_renderjson
 from dash import html, Input, Output, dcc, State
 import jupyter_dash
 from PIL import Image


-class JSONContainer:
-    """Expose a Python dictionary as a JSON document in JupyterLab
-    rich display rendering.
-    """
-
-    def __init__(self, data=None):
-        if data is None:
-            data = {}
-        self._data = data
-
-    def _repr_json_(self):
-        return self._data
-
-
 class AnalysisExplorer:
-    def __init__(self, mydict, identify="faces") -> None:
+    def __init__(self, mydict: dict, identify: str = "faces") -> None:
+        """Initialize the AnalysisExplorer class to create an interactive
+        visualization of the analysis results.
+
+        Args:
+            mydict (dict): A nested dictionary containing image data for all images.
+            identify (str, optional): The type of analysis to perform (default: "faces").
+                Options are "faces" (face and emotion detection), "text-on-image" (text
+                extraction and analysis), "objects" (object detection), "summary" (image caption
+                generation).
+        """
         self.app = jupyter_dash.JupyterDash(__name__)
         self.mydict = mydict
         self.identify = identify
@@ -53,19 +45,18 @@ class AnalysisExplorer:
             "base0F": "#cc6633",
         }

-        # setup the layout
+        # Setup the layout
         app_layout = html.Div(
             [
-                # top
+                # Top
                 html.Div(
                     ["Identify: ", identify, self._top_file_explorer(mydict)],
                     id="Div_top",
                     style={
                         "width": "30%",
-                        # "display": "inline-block",
                     },
                 ),
-                # middle
+                # Middle
                 html.Div(
                     [self._middle_picture_frame()],
                     id="Div_middle",
                     style={
                         "width": "30%",
                         "display": "inline-block",
                         "verticalAlign": "top",
                     },
                 ),
-                # right
+                # Right
                 html.Div(
                     [self._right_output_json()],
                     id="Div_right",
                     style={
                         "width": "30%",
                         "display": "inline-block",
                         "verticalAlign": "top",
                     },
                 ),
             ],
             style={"width": "95%", "display": "inline-block"},
         )
         self.app.layout = app_layout
-        # add callbacks to app
+
+        # Add callbacks to the app
         self.app.callback(
             Output("img_middle_picture_id", "src"),
             Input("left_select_id", "value"),
         )(self.update_picture)
@@ -105,8 +97,15 @@ class AnalysisExplorer:
         )(self._right_output_analysis)

     # I split the different sections into subfunctions for better clarity
-    def _top_file_explorer(self, mydict):
-        # initilizes the dropdown that selects which file is to be analyzed.
+    def _top_file_explorer(self, mydict: dict) -> html.Div:
+        """Initialize the file explorer dropdown for selecting the file to be analyzed.
+
+        Args:
+            mydict (dict): A dictionary containing image data.
+
+        Returns:
+            html.Div: The layout for the file explorer dropdown.
+        """
         left_layout = html.Div(
             [
                 dcc.Dropdown(
                     options={value: key for key, value in mydict.items()},
                     id="left_select_id",
                 )
             ]
         )
         return left_layout

-    def _middle_picture_frame(self):
-        # This just holds the image
+    def _middle_picture_frame(self) -> html.Div:
+        """Initialize the picture frame to display the image.
+
+        Returns:
+            html.Div: The layout for the picture frame.
+        """
         middle_layout = html.Div(
             [
                 html.Img(
                     id="img_middle_picture_id",
                     style={"width": "80%"},
                 )
             ]
         )
         return middle_layout

-    def _right_output_json(self):
-        # provides the json viewer for the analysis output.
+    def _right_output_json(self) -> html.Div:
+        """Initialize the JSON viewer for displaying the analysis output.
+
+        Returns:
+            html.Div: The layout for the JSON viewer.
+        """
         right_layout = html.Div(
             [
                 dcc.Loading(
                     id="loading-2",
                     children=[
                         html.Div(
                             [
                                 dash_renderjson.DashRenderjson(
                                     id="right_json_viewer",
                                     data={},
                                     max_depth=-1,
                                     theme=self.theme,
                                     invert_theme=True,
                                 )
                             ]
                         )
                     ],
                     type="circle",
                 )
             ]
         )
         return right_layout

-    def run_server(self, port=8050):
+    def run_server(self, port: int = 8050) -> None:
+        """Run the Dash server to start the analysis explorer.
+
+        This method should only be called in an interactive environment like Jupyter notebooks.
+        Raises an EnvironmentError if not called in an interactive environment.
+
+        Args:
+            port (int, optional): The port number to run the server on (default: 8050).
+        """
         if not is_interactive():
             raise EnvironmentError(
-                "Dash server should only be called in interactive an interactive environment like jupyter notebooks."
+                "Dash server should only be called in an interactive environment like Jupyter notebooks."
             )
         self.app.run_server(debug=True, mode="inline", port=port)

     # Dash callbacks
-    def update_picture(self, img_path):
+    def update_picture(self, img_path: str):
+        """Callback function to update the displayed image.
+
+        Args:
+            img_path (str): The path of the selected image.
+
+        Returns:
+            Union[PIL.Image.Image, None]: The image object to be displayed,
+                or None if the image path is None.
+
+        Note:
+            - This function is called when the value of the file explorer dropdown changes.
+            - Reads the image file and returns the image object.
+        """
         if img_path is not None:
             image = Image.open(img_path)
             return image
         else:
             return None

-    def _right_output_analysis(self, image, all_options, current_value):
-        # calls the analysis function and returns the output
+    def _right_output_analysis(self, all_options: dict, current_value: str) -> dict:
+        """Callback function to perform analysis on the selected image and return the output.
+
+        Args:
+            all_options (dict): The available options in the file explorer dropdown.
+            current_value (str): The current selected value in the file explorer dropdown.
+
+        Returns:
+            dict: The analysis output for the selected image.
+        """
         identify_dict = {
             "faces": faces.EmotionDetector,
             "text-on-image": text.TextDetector,
             "objects": objects.ObjectDetector,
             "summary": summary.SummaryDetector,
         }
-        # get image ID from dropdown value, which is the filepath.
+
+        # Get image ID from dropdown value, which is the filepath
         image_id = all_options[current_value]
         identify_function = identify_dict[self.identify]
diff --git a/ammico/faces.py b/ammico/faces.py
index 2c43a07..3db0474 100644
--- a/ammico/faces.py
+++ b/ammico/faces.py
@@ -3,17 +3,15 @@ import numpy as np
 import os
 import shutil
 import pathlib
-import ipywidgets
-
 from tensorflow.keras.models import load_model
 from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
 from tensorflow.keras.preprocessing.image import img_to_array
 from deepface import DeepFace
 from retinaface import RetinaFace
-
 from ammico.utils import DownloadResource
 import ammico.utils as utils

+
 DEEPFACE_PATH = ".deepface"
@@ -85,8 +83,19 @@ retinaface_model = DownloadResource(
 class EmotionDetector(utils.AnalysisMethod):
     def __init__(
-        self, subdict: dict, emotion_threshold=50.0, race_threshold=50.0
+        self,
+        subdict: dict,
+        emotion_threshold: float = 50.0,
+        race_threshold: float = 50.0,
     ) -> None:
+        """
+        Initializes the EmotionDetector object.
+
+        Args:
+            subdict (dict): The dictionary to store the analysis results.
+            emotion_threshold (float): The threshold for detecting emotions (default: 50.0).
+            race_threshold (float): The threshold for detecting race (default: 50.0).
+        """
         super().__init__(subdict)
         self.subdict.update(self.set_keys())
         self.emotion_threshold = emotion_threshold
@@ -102,6 +111,12 @@ class EmotionDetector(utils.AnalysisMethod):
         }

     def set_keys(self) -> dict:
+        """
+        Sets the initial parameters for the analysis.
+
+        Returns:
+            dict: The dictionary with initial parameter values.
+        """
         params = {
             "face": "No",
             "multiple_faces": "No",
@@ -115,27 +130,38 @@ class EmotionDetector(utils.AnalysisMethod):
         }
         return params

-    def analyse_image(self):
+    def analyse_image(self) -> dict:
+        """
+        Performs facial expression analysis on the image.
+
+        Returns:
+            dict: The updated subdict dictionary with analysis results.
+        """
         return self.facial_expression_analysis()

     def analyze_single_face(self, face: np.ndarray) -> dict:
-        fresult = {}
+        """
+        Analyzes the features of a single face.
+
+        Args:
+            face (np.ndarray): The face image array.
+
+        Returns:
+            dict: The analysis results for the face.
+        """
+        fresult = {}
         # Determine whether the face wears a mask
         fresult["wears_mask"] = self.wears_mask(face)
-
-        # Adapt the features we are looking for depending on whether a mask is
-        # worn. White masks screw race detection, emotion detection is useless.
+        # Adapt the features we are looking for depending on whether a mask is worn.
+        # White masks screw race detection, emotion detection is useless.
         actions = ["age", "gender"]
         if not fresult["wears_mask"]:
             actions = actions + ["race", "emotion"]
-
         # Ensure that all data has been fetched by pooch
         deepface_age_model.get()
         deepface_face_expression_model.get()
         deepface_gender_model.get()
         deepface_race_model.get()
-
         # Run the full DeepFace analysis
         fresult.update(
             DeepFace.analyze(
                 img_path=face,
                 actions=actions,
                 prog_bar=False,
                 detector_backend="skip",
             )
         )
-
         # We remove the region, as the data is not correct - after all we are
         # running the analysis on a subimage.
         del fresult["region"]
-
         return fresult

     def facial_expression_analysis(self) -> dict:
+        """
+        Performs facial expression analysis on the image.
+
+        Returns:
+            dict: The updated subdict dictionary with analysis results.
+        """
         # Find (multiple) faces in the image and cut them
         retinaface_model.get()
         faces = RetinaFace.extract_faces(self.subdict["filename"])
-
         # If no faces are found, we return empty keys
         if len(faces) == 0:
             return self.subdict
-
         # Sort the faces by size to prioritize prominent faces
         faces = list(reversed(sorted(faces, key=lambda f: f.shape[0] * f.shape[1])))
-
         self.subdict["face"] = "Yes"
         self.subdict["multiple_faces"] = "Yes" if len(faces) > 1 else "No"
         self.subdict["no_faces"] = len(faces) if len(faces) <= 15 else 99
@@ -172,13 +199,19 @@ class EmotionDetector(utils.AnalysisMethod):
         result = {}
         # We limit ourselves to three faces
         for i, face in enumerate(faces[:3]):
             result[f"person{ i+1 }"] = self.analyze_single_face(face)
-
         self.clean_subdict(result)
-
         return self.subdict

     def clean_subdict(self, result: dict) -> dict:
-        # each person subdict converted into list for keys
+        """
+        Cleans the subdict dictionary by converting results into appropriate formats.
+
+        Args:
+            result (dict): The analysis results.
+
+        Returns:
+            dict: The updated subdict dictionary.
+        """
+        # Each person subdict converted into list for keys
         self.subdict["wears_mask"] = []
         self.subdict["age"] = []
         self.subdict["gender"] = []
@@ -191,12 +224,12 @@ class EmotionDetector(utils.AnalysisMethod):
                 "Yes" if result[person]["wears_mask"] else "No"
             )
             self.subdict["age"].append(result[person]["age"])
-            # gender is now reported as a list of dictionaries
-            # each dict represents one face
-            # each dict contains probability for Woman and Man
-            # take only the higher prob result for each dict
+            # Gender is now reported as a list of dictionaries.
+            # Each dict represents one face.
+            # Each dict contains probability for Woman and Man.
+            # We take only the higher probability result for each dict.
             self.subdict["gender"].append(result[person]["gender"])
-            # race, emotion only detected if person does not wear mask
+            # Race and emotion are only detected if a person does not wear a mask
             if result[person]["wears_mask"]:
                 self.subdict["race"].append(None)
                 self.subdict["emotion"].append(None)
@@ -223,36 +256,28 @@ class EmotionDetector(utils.AnalysisMethod):
             else:
                 self.subdict["emotion"].append(None)
                 self.subdict["emotion (category)"].append(None)
-
         return self.subdict

     def wears_mask(self, face: np.ndarray) -> bool:
-        global mask_detection_model
+        """
+        Determines whether a face wears a mask.

-        # Preprocess the face to match the assumptions of the face mask
-        # detection model
+        Args:
+            face (np.ndarray): The face image array.
+
+        Returns:
+            bool: True if the face wears a mask, False otherwise.
+        """
+        global mask_detection_model
+        # Preprocess the face to match the assumptions of the face mask detection model
         face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
         face = cv2.resize(face, (224, 224))
         face = img_to_array(face)
         face = preprocess_input(face)
         face = np.expand_dims(face, axis=0)
-
         # Lazily load the model
         mask_detection_model = load_model(face_mask_model.get())
-
-        # Run the model (ignoring output)
-        with NocatchOutput():
-            mask, without_mask = mask_detection_model.predict(face)[0]
-
+        # Run the model
+        mask, without_mask = mask_detection_model.predict(face)[0]
         # Convert from np.bool_ to bool to later be able to serialize the result
         return bool(mask > without_mask)
-
-
-class NocatchOutput(ipywidgets.Output):
-    """An output container that suppresses output, but not exceptions
-
-    Taken from https://github.com/jupyter-widgets/ipywidgets/issues/3208#issuecomment-1070836153
-    """
-
-    def __exit__(self, *args, **kwargs):
-        super().__exit__(*args, **kwargs)
diff --git a/ammico/objects.py b/ammico/objects.py
index f60e6ee..9bbd64d 100644
--- a/ammico/objects.py
+++ b/ammico/objects.py
@@ -5,20 +5,22 @@ from ammico.objects_cvlib import init_default_objects

 class ObjectDetectorClient(AnalysisMethod):
     def __init__(self):
-        # The detector is default to CVLib
-        # Here other libraries can be added
+        # The detector is set to CVLib by default
         self.detector = ObjectCVLib()

     def set_client_to_cvlib(self):
+        """Set the object detection client to use CVLib."""
         self.detector = ObjectCVLib()

     def analyse_image(self, subdict=None):
-        """Localize objects in the local image.
+        """Localize objects in the given image.

         Args:
-            subdict: The dictionary for an image expression instance.
-        """
+            subdict (dict): The dictionary for an image expression instance.

+        Returns:
+            dict: The updated dictionary with object detection results.
+        """
         return self.detector.analyse_image(subdict)


@@ -30,12 +32,23 @@ class ObjectDetector(AnalysisMethod):
         self.subdict.update(self.set_keys())

     def set_keys(self):
+        """Set the default object keys for analysis.
+
+        Returns:
+            dict: The dictionary with default object keys.
+        """
         return init_default_objects()

     def analyse_image(self):
+        """Perform object detection on the image.
+
+        Returns:
+            dict: The updated dictionary with object detection results.
+        """
         self.subdict = ObjectDetector.od_client.analyse_image(self.subdict)
         return self.subdict

     @staticmethod
     def set_client_to_cvlib():
+        """Set the object detection client to use CVLib."""
         ObjectDetector.od_client.set_client_to_cvlib()
diff --git a/ammico/summary.py b/ammico/summary.py
index 750b6ed..50fc2ee 100644
--- a/ammico/summary.py
+++ b/ammico/summary.py
@@ -1,5 +1,5 @@
 from ammico.utils import AnalysisMethod
-from torch import device, cuda, no_grad
+from torch import cuda, no_grad
 from PIL import Image
 from lavis.models import load_model_and_preprocess

diff --git a/ammico/test/test_display.py b/ammico/test/test_display.py
index 4d26d6d..4d3c1f1 100644
--- a/ammico/test/test_display.py
+++ b/ammico/test/test_display.py
@@ -49,8 +49,8 @@ def test_AnalysisExplorer(get_path):

     assert analysis_explorer_objects.update_picture(None) is None

-    analysis_explorer_faces._right_output_analysis(None, all_options_dict, path_img_1)
-    analysis_explorer_objects._right_output_analysis(None, all_options_dict, path_img_2)
+    analysis_explorer_faces._right_output_analysis(all_options_dict, path_img_1)
+    analysis_explorer_objects._right_output_analysis(all_options_dict, path_img_2)

     with pytest.raises(EnvironmentError):
         analysis_explorer_faces.run_server(port=8050)
diff --git a/ammico/test/test_faces.py b/ammico/test/test_faces.py
index bb5926d..66ddad3 100644
--- a/ammico/test/test_faces.py
+++ b/ammico/test/test_faces.py
@@ -2,6 +2,14 @@ import ammico.faces as fc
 import json


+def test_set_keys():
+    ed = fc.EmotionDetector({})
+    assert ed.subdict["face"] == "No"
+    assert ed.subdict["multiple_faces"] == "No"
+    assert ed.subdict["wears_mask"] == ["No"]
+    assert ed.subdict["emotion"] == [None]
+
+
 def test_analyse_faces(get_path):
     mydict = {
         "filename": get_path + "IMG_2746.png",
diff --git a/ammico/text.py b/ammico/text.py
index bb599fa..cf43c17 100644
--- a/ammico/text.py
+++ b/ammico/text.py
@@ -19,6 +19,14 @@ import os

 class TextDetector(utils.AnalysisMethod):
     def __init__(self, subdict: dict, analyse_text: bool = False) -> None:
+        """Init text detection class.
+
+        Args:
+            subdict (dict): Dictionary containing file name/path, and possibly previous
+                analysis results from other modules.
+            analyse_text (bool, optional): Decide if extracted text will be subjected to
+                further analysis. Defaults to False.
+        """
         super().__init__(subdict)
         self.subdict.update(self.set_keys())
         self.translator = Translator()
@@ -28,10 +36,16 @@ class TextDetector(utils.AnalysisMethod):
         self._initialize_textblob()

     def set_keys(self) -> dict:
+        """Set the default keys for text analysis.
+
+        Returns:
+            dict: The dictionary with default text keys.
+        """
         params = {"text": None, "text_language": None, "text_english": None}
         return params

     def _initialize_spacy(self):
+        """Initialize the Spacy library for text analysis."""
         try:
             self.nlp = spacy.load("en_core_web_md")
         except Exception:
@@ -40,12 +54,18 @@ class TextDetector(utils.AnalysisMethod):
         self.nlp.add_pipe("spacytextblob")

     def _initialize_textblob(self):
+        """Initialize the TextBlob library for text analysis."""
         try:
             TextBlob("Here")
         except Exception:
             download_corpora.main()

-    def analyse_image(self):
+    def analyse_image(self) -> dict:
+        """Perform text extraction and analysis of the text.
+
+        Returns:
+            dict: The updated dictionary with text analysis results.
+        """
         self.get_text_from_image()
         self.translate_text()
         self.remove_linebreaks()
@@ -60,7 +80,7 @@ class TextDetector(utils.AnalysisMethod):
         return self.subdict

     def get_text_from_image(self):
-        """Detects text on the image."""
+        """Detect text on the image using Google Cloud Vision API."""
         path = self.subdict["filename"]
         try:
             client = vision.ImageAnnotatorClient()
@@ -92,6 +112,7 @@ class TextDetector(utils.AnalysisMethod):
         )

     def translate_text(self):
+        """Translate the detected text to English using the Translator object."""
         translated = self.translator.translate(self.subdict["text"])
         self.subdict["text_language"] = translated.src
         self.subdict["text_english"] = translated.text
@@ -105,7 +126,7 @@ class TextDetector(utils.AnalysisMethod):
         )

     def _run_spacy(self):
-        """Generate spacy doc object."""
+        """Generate Spacy doc object for further text analysis."""
         self.doc = self.nlp(self.subdict["text_english"])

     def clean_text(self):
@@ -118,10 +139,12 @@ class TextDetector(utils.AnalysisMethod):
         self.subdict["text_clean"] = " ".join(templist).rstrip().lstrip()

     def correct_spelling(self):
+        """Correct the spelling of the English text using TextBlob."""
         self.textblob = TextBlob(self.subdict["text_english"])
         self.subdict["text_english_correct"] = str(self.textblob.correct())

     def sentiment_analysis(self):
+        """Perform sentiment analysis on the text using SpacyTextBlob."""
         # polarity is between [-1.0, 1.0]
         self.subdict["polarity"] = self.doc._.blob.polarity
         # subjectivity is a float within the range [0.0, 1.0]
@@ -129,6 +152,7 @@ class TextDetector(utils.AnalysisMethod):
         self.subdict["subjectivity"] = self.doc._.blob.subjectivity

     def text_summary(self):
+        """Generate a summary of the text using the Transformers pipeline."""
         # use the transformers pipeline to summarize the text
         # use the current default model - 03/2023
         model_name = "sshleifer/distilbart-cnn-12-6"
@@ -152,6 +176,7 @@ class TextDetector(utils.AnalysisMethod):
         self.subdict["text_summary"] = None

     def text_sentiment_transformers(self):
+        """Perform text classification for sentiment using the Transformers pipeline."""
         # use the transformers pipeline for text classification
         # use the current default model - 03/2023
         model_name = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -167,6 +192,7 @@ class TextDetector(utils.AnalysisMethod):
         self.subdict["sentiment_score"] = result[0]["score"]

     def text_ner(self):
+        """Perform named entity recognition on the text using the Transformers pipeline."""
         # use the transformers pipeline for named entity recognition
         # use the current default model - 03/2023
         model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
@@ -193,6 +219,15 @@ class PostprocessText:
         csv_path: str = None,
         analyze_text: str = "text_english",
     ) -> None:
+        """
+        Initializes the PostprocessText class that handles the topic analysis.
+
+        Args:
+            mydict (dict, optional): Dictionary with textual data. Defaults to None.
+            use_csv (bool, optional): Flag indicating whether to use a CSV file. Defaults to False.
+            csv_path (str, optional): Path to the CSV file. Required if `use_csv` is True. Defaults to None.
+            analyze_text (str, optional): Key for the text field to analyze. Defaults to "text_english".
+        """
         self.use_csv = use_csv
         if mydict:
             print("Reading data from dict.")
@@ -209,8 +244,16 @@ class PostprocessText:
             `csv_path`."
             )

-    def analyse_topic(self, return_topics: int = 3):
-        """Topic analysis using BERTopic."""
+    def analyse_topic(self, return_topics: int = 3) -> tuple:
+        """
+        Performs topic analysis using BERTopic.
+
+        Args:
+            return_topics (int, optional): Number of topics to return. Defaults to 3.
+
+        Returns:
+            tuple: A tuple containing the topic model, topic dataframe, and most frequent topics.
+        """
         # load spacy pipeline
         nlp = spacy.load(
             "en_core_web_md",
@@ -237,7 +280,16 @@ class PostprocessText:
             most_frequent_topics.append(self.topic_model.get_topic(i))
         return self.topic_model, topic_df, most_frequent_topics

-    def get_text_dict(self, analyze_text):
+    def get_text_dict(self, analyze_text: str) -> list:
+        """
+        Extracts text from the provided dictionary.
+
+        Args:
+            analyze_text (str): Key for the text field to analyze.
+
+        Returns:
+            list: A list of text extracted from the dictionary.
+        """
         # use dict to put text_english or text_summary in list
         list_text_english = []
         for key in self.mydict.keys():
@@ -251,7 +303,16 @@ class PostprocessText:
                 list_text_english.append(self.mydict[key][analyze_text])
         return list_text_english

-    def get_text_df(self, analyze_text):
+    def get_text_df(self, analyze_text: str) -> list:
+        """
+        Extracts text from the provided dataframe.
+
+        Args:
+            analyze_text (str): Column name for the text field to analyze.
+
+        Returns:
+            list: A list of text extracted from the dataframe.
+        """
         # use csv file to obtain dataframe and put text_english or text_summary in list
         # check that "text_english" or "text_summary" is there
         if analyze_text not in self.df:
@@ -262,19 +323,3 @@ class PostprocessText:
             )
         )
         return self.df[analyze_text].tolist()
-
-
-if __name__ == "__main__":
-    images = utils.find_files(
-        path="data/test-debug/101-200fullposts",
-        limit=110,
-    )
-    # images = ["data/test-debug/101-200fullposts/100638_mya.png"]
-    print(images)
-    mydict = utils.initialize_dict(images)
-    os.environ[
-        "GOOGLE_APPLICATION_CREDENTIALS"
-    ] = "data/misinformation-campaign-981aa55a3b13.json"
-    for key in mydict:
-        print(key)
-        mydict[key] = TextDetector(mydict[key], analyse_text=True).analyse_image()
diff --git a/ammico/utils.py b/ammico/utils.py
index d3a01f0..c88254c 100644
--- a/ammico/utils.py
+++ b/ammico/utils.py
@@ -5,7 +5,7 @@ import pooch


 class DownloadResource:
-    """A remote resource that needs on demand downloading
+    """A remote resource that needs on demand downloading.

     We use this as a wrapper to the pooch library. The wrapper registers each data file
     and allows prefetching through the CLI entry point
@@ -33,7 +33,7 @@ def ammico_prefetch_models():
 class AnalysisMethod:
     """Base class to be inherited by all analysis methods."""

-    def __init__(self, subdict) -> None:
+    def __init__(self, subdict: dict) -> None:
         self.subdict = subdict
         # define keys that will be set by the analysis

@@ -44,35 +44,40 @@ class AnalysisMethod:
         raise NotImplementedError()


-def find_files(path=None, pattern="*.png", recursive=True, limit=20):
+def find_files(
+    path: str = None, pattern: str = "*.png", recursive: bool = True, limit: int = 20
+) -> list:
     """Find image files on the file system.

-    :param path:
-        The base directory where we are looking for the images. Defaults
+    Args:
+        path (str, optional): The base directory where we are looking for the images. Defaults
         to None, which uses the XDG data directory if set or the current
         working directory otherwise.
-    :param pattern:
-        The naming pattern that the filename should match. Defaults to
+        pattern (str, optional): The naming pattern that the filename should match. Defaults to
         "*.png". Can be used to allow other patterns or to only include
         specific prefixes or suffixes.
-    :param recursive:
-        Whether to recurse into subdirectories.
-    :param limit:
-        The maximum number of images to be found. Defaults to 20.
-        To return all images, set to None.
+        recursive (bool, optional): Whether to recurse into subdirectories. Defaults to True.
+        limit (int, optional): The maximum number of images to be found.
+            Defaults to 20. To return all images, set to None.
+
+    Returns:
+        list: A list with all filenames including the path.
     """
     if path is None:
         path = os.environ.get("XDG_DATA_HOME", ".")
-
     result = list(glob.glob(f"{path}/{pattern}", recursive=recursive))
-
     if limit is not None:
         result = result[:limit]
-
     return result


 def initialize_dict(filelist: list) -> dict:
+    """Initialize the nested dictionary for all the found images.
+
+    Args:
+        filelist (list): The list of files to be analyzed, including their paths.
+
+    Returns:
+        dict: The nested dictionary with all image ids and their paths."""
     mydict = {}
     for img_path in filelist:
         id_ = os.path.splitext(os.path.basename(img_path))[0]
@@ -81,7 +86,7 @@ def initialize_dict(filelist: list) -> dict:


 def append_data_to_dict(mydict: dict) -> dict:
-    """Append entries from list of dictionaries to keys in global dict."""
+    """Append entries from nested dictionaries to keys in a global dict."""
     # first initialize empty list for each key that is present
     outdict = {key: [] for key in list(mydict.values())[0].keys()}

@@ -98,6 +103,7 @@ def dump_df(mydict: dict) -> DataFrame:


 def is_interactive():
+    """Check if we are running in an interactive environment."""
     import __main__ as main

     return not hasattr(main, "__file__")
diff --git a/pyproject.toml b/pyproject.toml
index d8a08ab..20441ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,6 @@ dependencies = [
     "grpcio",
     "importlib_metadata",
     "ipython",
-    "ipywidgets<8.0.5",
    "jupyter_dash",
    "matplotlib",
    "numpy<=1.23.4",
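
Reviewer addendum (not part of the patch): a minimal usage sketch checking that the signatures documented above compose end to end. The image directory, pattern, and port are placeholder assumptions; only functions and parameters visible in this diff are used.

    import ammico.utils as utils
    from ammico.faces import EmotionDetector
    from ammico.display import AnalysisExplorer

    # Find images and build the nested results dictionary keyed by image id
    images = utils.find_files(path="data/images", pattern="*.png", limit=20)
    mydict = utils.initialize_dict(images)

    # Run face/emotion detection per image; analyse_image() returns the updated subdict
    for key in mydict:
        mydict[key] = EmotionDetector(mydict[key]).analyse_image()

    # Inspect results interactively; run_server() raises EnvironmentError outside notebooks
    AnalysisExplorer(mydict, identify="faces").run_server(port=8050)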
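A second sketch along the same lines for the text pipeline in text.py and utils.py. It assumes Google Cloud Vision credentials are already configured (as in the removed __main__ block), that `mydict` holds per-image dictionaries as above, and that `dump_df` is applied to the appended global dict as the docstrings suggest.

    import ammico.utils as utils
    from ammico.text import TextDetector, PostprocessText

    # Extract and analyse text per image (translation, sentiment, summary, NER)
    for key in mydict:
        mydict[key] = TextDetector(mydict[key], analyse_text=True).analyse_image()

    # Topic analysis over the English text fields via BERTopic
    pp = PostprocessText(mydict=mydict, analyze_text="text_english")
    topic_model, topic_df, most_frequent_topics = pp.analyse_topic(return_topics=3)

    # Flatten the nested results and convert to a dataframe (assumed composition)
    df = utils.dump_df(utils.append_data_to_dict(mydict))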