зеркало из
				https://github.com/ssciwr/AMMICO.git
				synced 2025-10-31 05:56:05 +02:00 
			
		
		
		
	Improve documentation (#89)
* updated documentation in cropposts * updated documentation in display * updated documentation in faces * added comments to objects.py * updated utils.py docs * updated text.py docs * improve doc display * fix doc for display and remove redundant variable * removed documentation from cropposts.py * removed unused imports * get rid of ipywidgets dependency * remove unused imports, improve type hints * improve doc in utils * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Inga Ulusoy <inga.ulusoy@uni-heidelberg.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Этот коммит содержится в:
		
							родитель
							
								
									4628692e95
								
							
						
					
					
						Коммит
						44e5a987b1
					
				| @ -84,6 +84,8 @@ def matching_points( | |||||||
|     sift = cv2.SIFT_create() |     sift = cv2.SIFT_create() | ||||||
|     kp1, des1 = sift.detectAndCompute(img1, None) |     kp1, des1 = sift.detectAndCompute(img1, None) | ||||||
|     kp2, des2 = sift.detectAndCompute(img2, None) |     kp2, des2 = sift.detectAndCompute(img2, None) | ||||||
|  | 
 | ||||||
|  |     # Convert descriptors to float32 | ||||||
|     des1 = np.float32(des1) |     des1 = np.float32(des1) | ||||||
|     des2 = np.float32(des2) |     des2 = np.float32(des2) | ||||||
|     # Initialize and use FLANN |     # Initialize and use FLANN | ||||||
| @ -93,6 +95,7 @@ def matching_points( | |||||||
|     matches = flann.knnMatch(des1, des2, k=2) |     matches = flann.knnMatch(des1, des2, k=2) | ||||||
|     filtered_matches = [] |     filtered_matches = [] | ||||||
|     for m, n in matches: |     for m, n in matches: | ||||||
|  |         # Apply ratio test to filter out ambiguous matches | ||||||
|         if m.distance < 0.7 * n.distance: |         if m.distance < 0.7 * n.distance: | ||||||
|             filtered_matches.append(m) |             filtered_matches.append(m) | ||||||
|     return filtered_matches, kp1, kp2 |     return filtered_matches, kp1, kp2 | ||||||
| @ -141,6 +144,8 @@ def compute_crop_corner( | |||||||
|     kp1, kp2 = kp_from_matches(matches, kp1, kp2) |     kp1, kp2 = kp_from_matches(matches, kp1, kp2) | ||||||
|     ys = kp2[:, 1] |     ys = kp2[:, 1] | ||||||
|     covers = [] |     covers = [] | ||||||
|  | 
 | ||||||
|  |     # Compute the number of keypoints within the region around each y-coordinate | ||||||
|     for y in ys: |     for y in ys: | ||||||
|         ys_c = ys - y |         ys_c = ys - y | ||||||
|         series = pd.Series(ys_c) |         series = pd.Series(ys_c) | ||||||
| @ -151,7 +156,10 @@ def compute_crop_corner( | |||||||
|         return None |         return None | ||||||
|     kp_id = ys[covers.argmax()] |     kp_id = ys[covers.argmax()] | ||||||
|     v = int(kp_id) - v_margin if int(kp_id) > v_margin else int(kp_id) |     v = int(kp_id) - v_margin if int(kp_id) > v_margin else int(kp_id) | ||||||
|  | 
 | ||||||
|     hs = [] |     hs = [] | ||||||
|  | 
 | ||||||
|  |     # Find the minimum x-coordinate within the region around the selected y-coordinate | ||||||
|     for kp in kp2: |     for kp in kp2: | ||||||
|         if 0 <= kp[1] - v <= region: |         if 0 <= kp[1] - v <= region: | ||||||
|             hs.append(kp[0]) |             hs.append(kp[0]) | ||||||
| @ -320,7 +328,6 @@ def crop_media_posts( | |||||||
|     for ref_file in ref_files: |     for ref_file in ref_files: | ||||||
|         ref_view = cv2.imread(ref_file) |         ref_view = cv2.imread(ref_file) | ||||||
|         ref_views.append(ref_view) |         ref_views.append(ref_view) | ||||||
| 
 |  | ||||||
|     # parse through the social media posts to be cropped |     # parse through the social media posts to be cropped | ||||||
|     for crop_file in files: |     for crop_file in files: | ||||||
|         view = cv2.imread(crop_file) |         view = cv2.imread(crop_file) | ||||||
|  | |||||||
| @ -1,34 +1,26 @@ | |||||||
| from IPython.display import display |  | ||||||
| 
 |  | ||||||
| import ammico.faces as faces | import ammico.faces as faces | ||||||
| import ammico.text as text | import ammico.text as text | ||||||
| import ammico.objects as objects | import ammico.objects as objects | ||||||
| from ammico.utils import is_interactive | from ammico.utils import is_interactive | ||||||
| 
 |  | ||||||
| import ammico.summary as summary | import ammico.summary as summary | ||||||
| 
 |  | ||||||
| import dash_renderjson | import dash_renderjson | ||||||
| from dash import html, Input, Output, dcc, State | from dash import html, Input, Output, dcc, State | ||||||
| import jupyter_dash | import jupyter_dash | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class JSONContainer: |  | ||||||
|     """Expose a Python dictionary as a JSON document in JupyterLab |  | ||||||
|     rich display rendering. |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__(self, data=None): |  | ||||||
|         if data is None: |  | ||||||
|             data = {} |  | ||||||
|         self._data = data |  | ||||||
| 
 |  | ||||||
|     def _repr_json_(self): |  | ||||||
|         return self._data |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class AnalysisExplorer: | class AnalysisExplorer: | ||||||
|     def __init__(self, mydict, identify="faces") -> None: |     def __init__(self, mydict: dict, identify: str = "faces") -> None: | ||||||
|  |         """Initialize the AnalysisExplorer class to create an interactive | ||||||
|  |         visualization of the analysis results. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             mydict (dict): A nested dictionary containing image data for all images. | ||||||
|  |             identify (str, optional): The type of analysis to perform (default: "faces"). | ||||||
|  |             Options are "faces" (face and emotion detection), "text-on-image" (text | ||||||
|  |             extraction and analysis), "objects" (object detection), "summary" (image caption | ||||||
|  |             generation). | ||||||
|  |         """ | ||||||
|         self.app = jupyter_dash.JupyterDash(__name__) |         self.app = jupyter_dash.JupyterDash(__name__) | ||||||
|         self.mydict = mydict |         self.mydict = mydict | ||||||
|         self.identify = identify |         self.identify = identify | ||||||
| @ -53,19 +45,18 @@ class AnalysisExplorer: | |||||||
|             "base0F": "#cc6633", |             "base0F": "#cc6633", | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         # setup the layout |         #  Setup the layout | ||||||
|         app_layout = html.Div( |         app_layout = html.Div( | ||||||
|             [ |             [ | ||||||
|                 # top |                 # Top | ||||||
|                 html.Div( |                 html.Div( | ||||||
|                     ["Identify: ", identify, self._top_file_explorer(mydict)], |                     ["Identify: ", identify, self._top_file_explorer(mydict)], | ||||||
|                     id="Div_top", |                     id="Div_top", | ||||||
|                     style={ |                     style={ | ||||||
|                         "width": "30%", |                         "width": "30%", | ||||||
|                         # "display": "inline-block", |  | ||||||
|                     }, |                     }, | ||||||
|                 ), |                 ), | ||||||
|                 # middle |                 # Middle | ||||||
|                 html.Div( |                 html.Div( | ||||||
|                     [self._middle_picture_frame()], |                     [self._middle_picture_frame()], | ||||||
|                     id="Div_middle", |                     id="Div_middle", | ||||||
| @ -75,7 +66,7 @@ class AnalysisExplorer: | |||||||
|                         "verticalAlign": "top", |                         "verticalAlign": "top", | ||||||
|                     }, |                     }, | ||||||
|                 ), |                 ), | ||||||
|                 # right |                 # Right | ||||||
|                 html.Div( |                 html.Div( | ||||||
|                     [self._right_output_json()], |                     [self._right_output_json()], | ||||||
|                     id="Div_right", |                     id="Div_right", | ||||||
| @ -89,7 +80,8 @@ class AnalysisExplorer: | |||||||
|             style={"width": "95%", "display": "inline-block"}, |             style={"width": "95%", "display": "inline-block"}, | ||||||
|         ) |         ) | ||||||
|         self.app.layout = app_layout |         self.app.layout = app_layout | ||||||
|         # add callbacks to app | 
 | ||||||
|  |         # Add callbacks to the app | ||||||
|         self.app.callback( |         self.app.callback( | ||||||
|             Output("img_middle_picture_id", "src"), |             Output("img_middle_picture_id", "src"), | ||||||
|             Input("left_select_id", "value"), |             Input("left_select_id", "value"), | ||||||
| @ -105,8 +97,15 @@ class AnalysisExplorer: | |||||||
|         )(self._right_output_analysis) |         )(self._right_output_analysis) | ||||||
| 
 | 
 | ||||||
|     # I split the different sections into subfunctions for better clarity |     # I split the different sections into subfunctions for better clarity | ||||||
|     def _top_file_explorer(self, mydict): |     def _top_file_explorer(self, mydict: dict) -> html.Div: | ||||||
|         # initilizes the dropdown that selects which file is to be analyzed. |         """Initialize the file explorer dropdown for selecting the file to be analyzed. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             mydict (dict): A dictionary containing image data. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             html.Div: The layout for the file explorer dropdown. | ||||||
|  |         """ | ||||||
|         left_layout = html.Div( |         left_layout = html.Div( | ||||||
|             [ |             [ | ||||||
|                 dcc.Dropdown( |                 dcc.Dropdown( | ||||||
| @ -117,8 +116,12 @@ class AnalysisExplorer: | |||||||
|         ) |         ) | ||||||
|         return left_layout |         return left_layout | ||||||
| 
 | 
 | ||||||
|     def _middle_picture_frame(self): |     def _middle_picture_frame(self) -> html.Div: | ||||||
|         # This just holds the image |         """Initialize the picture frame to display the image. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             html.Div: The layout for the picture frame. | ||||||
|  |         """ | ||||||
|         middle_layout = html.Div( |         middle_layout = html.Div( | ||||||
|             [ |             [ | ||||||
|                 html.Img( |                 html.Img( | ||||||
| @ -131,8 +134,12 @@ class AnalysisExplorer: | |||||||
|         ) |         ) | ||||||
|         return middle_layout |         return middle_layout | ||||||
| 
 | 
 | ||||||
|     def _right_output_json(self): |     def _right_output_json(self) -> html.Div: | ||||||
|         # provides the json viewer for the analysis output. |         """Initialize the JSON viewer for displaying the analysis output. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             html.Div: The layout for the JSON viewer. | ||||||
|  |         """ | ||||||
|         right_layout = html.Div( |         right_layout = html.Div( | ||||||
|             [ |             [ | ||||||
|                 dcc.Loading( |                 dcc.Loading( | ||||||
| @ -156,31 +163,61 @@ class AnalysisExplorer: | |||||||
|         ) |         ) | ||||||
|         return right_layout |         return right_layout | ||||||
| 
 | 
 | ||||||
|     def run_server(self, port=8050): |     def run_server(self, port: int = 8050) -> None: | ||||||
|  |         """Run the Dash server to start the analysis explorer. | ||||||
|  | 
 | ||||||
|  |         This method should only be called in an interactive environment like Jupyter notebooks. | ||||||
|  |         Raises an EnvironmentError if not called in an interactive environment. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             port (int, optional): The port number to run the server on (default: 8050). | ||||||
|  |         """ | ||||||
|         if not is_interactive(): |         if not is_interactive(): | ||||||
|             raise EnvironmentError( |             raise EnvironmentError( | ||||||
|                 "Dash server should only be called in interactive an interactive environment like jupyter notebooks." |                 "Dash server should only be called in an interactive environment like Jupyter notebooks." | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         self.app.run_server(debug=True, mode="inline", port=port) |         self.app.run_server(debug=True, mode="inline", port=port) | ||||||
| 
 | 
 | ||||||
|     # Dash callbacks |     # Dash callbacks | ||||||
|     def update_picture(self, img_path): |     def update_picture(self, img_path: str): | ||||||
|  |         """Callback function to update the displayed image. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             img_path (str): The path of the selected image. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Union[PIL.Image.Image, None]: The image object to be displayed | ||||||
|  |             or None if the image path is None. | ||||||
|  | 
 | ||||||
|  |         Note: | ||||||
|  |         - This function is called when the value of the file explorer dropdown changes. | ||||||
|  |         - Reads the image file and returns the image object. | ||||||
|  |         """ | ||||||
|         if img_path is not None: |         if img_path is not None: | ||||||
|             image = Image.open(img_path) |             image = Image.open(img_path) | ||||||
|             return image |             return image | ||||||
|         else: |         else: | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|     def _right_output_analysis(self, image, all_options, current_value): |     def _right_output_analysis(self, all_options: dict, current_value: str) -> dict: | ||||||
|         # calls the analysis function and returns the output |         """Callback function to perform analysis on the selected image and return the output. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             all_options (dict): The available options in the file explorer dropdown. | ||||||
|  |             current_value (str): The current selected value in the file explorer dropdown. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The analysis output for the selected image. | ||||||
|  |         """ | ||||||
|         identify_dict = { |         identify_dict = { | ||||||
|             "faces": faces.EmotionDetector, |             "faces": faces.EmotionDetector, | ||||||
|             "text-on-image": text.TextDetector, |             "text-on-image": text.TextDetector, | ||||||
|             "objects": objects.ObjectDetector, |             "objects": objects.ObjectDetector, | ||||||
|             "summary": summary.SummaryDetector, |             "summary": summary.SummaryDetector, | ||||||
|         } |         } | ||||||
|         # get image ID from dropdown value, which is the filepath. | 
 | ||||||
|  |         # Get image ID from dropdown value, which is the filepath | ||||||
|         image_id = all_options[current_value] |         image_id = all_options[current_value] | ||||||
| 
 | 
 | ||||||
|         identify_function = identify_dict[self.identify] |         identify_function = identify_dict[self.identify] | ||||||
|  | |||||||
							
								
								
									
										111
									
								
								ammico/faces.py
									
									
									
									
									
								
							
							
						
						
									
										111
									
								
								ammico/faces.py
									
									
									
									
									
								
							| @ -3,17 +3,15 @@ import numpy as np | |||||||
| import os | import os | ||||||
| import shutil | import shutil | ||||||
| import pathlib | import pathlib | ||||||
| import ipywidgets |  | ||||||
| 
 |  | ||||||
| from tensorflow.keras.models import load_model | from tensorflow.keras.models import load_model | ||||||
| from tensorflow.keras.applications.mobilenet_v2 import preprocess_input | from tensorflow.keras.applications.mobilenet_v2 import preprocess_input | ||||||
| from tensorflow.keras.preprocessing.image import img_to_array | from tensorflow.keras.preprocessing.image import img_to_array | ||||||
| from deepface import DeepFace | from deepface import DeepFace | ||||||
| from retinaface import RetinaFace | from retinaface import RetinaFace | ||||||
| 
 |  | ||||||
| from ammico.utils import DownloadResource | from ammico.utils import DownloadResource | ||||||
| import ammico.utils as utils | import ammico.utils as utils | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| DEEPFACE_PATH = ".deepface" | DEEPFACE_PATH = ".deepface" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -85,8 +83,19 @@ retinaface_model = DownloadResource( | |||||||
| 
 | 
 | ||||||
| class EmotionDetector(utils.AnalysisMethod): | class EmotionDetector(utils.AnalysisMethod): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, subdict: dict, emotion_threshold=50.0, race_threshold=50.0 |         self, | ||||||
|  |         subdict: dict, | ||||||
|  |         emotion_threshold: float = 50.0, | ||||||
|  |         race_threshold: float = 50.0, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         Initializes the EmotionDetector object. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             subdict (dict): The dictionary to store the analysis results. | ||||||
|  |             emotion_threshold (float): The threshold for detecting emotions (default: 50.0). | ||||||
|  |             race_threshold (float): The threshold for detecting race (default: 50.0). | ||||||
|  |         """ | ||||||
|         super().__init__(subdict) |         super().__init__(subdict) | ||||||
|         self.subdict.update(self.set_keys()) |         self.subdict.update(self.set_keys()) | ||||||
|         self.emotion_threshold = emotion_threshold |         self.emotion_threshold = emotion_threshold | ||||||
| @ -102,6 +111,12 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     def set_keys(self) -> dict: |     def set_keys(self) -> dict: | ||||||
|  |         """ | ||||||
|  |         Sets the initial parameters for the analysis. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The dictionary with initial parameter values. | ||||||
|  |         """ | ||||||
|         params = { |         params = { | ||||||
|             "face": "No", |             "face": "No", | ||||||
|             "multiple_faces": "No", |             "multiple_faces": "No", | ||||||
| @ -115,27 +130,38 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|         } |         } | ||||||
|         return params |         return params | ||||||
| 
 | 
 | ||||||
|     def analyse_image(self): |     def analyse_image(self) -> dict: | ||||||
|  |         """ | ||||||
|  |         Performs facial expression analysis on the image. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated subdict dictionary with analysis results. | ||||||
|  |         """ | ||||||
|         return self.facial_expression_analysis() |         return self.facial_expression_analysis() | ||||||
| 
 | 
 | ||||||
|     def analyze_single_face(self, face: np.ndarray) -> dict: |     def analyze_single_face(self, face: np.ndarray) -> dict: | ||||||
|         fresult = {} |         """ | ||||||
|  |         Analyzes the features of a single face. | ||||||
| 
 | 
 | ||||||
|  |         Args: | ||||||
|  |             face (np.ndarray): The face image array. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The analysis results for the face. | ||||||
|  |         """ | ||||||
|  |         fresult = {} | ||||||
|         # Determine whether the face wears a mask |         # Determine whether the face wears a mask | ||||||
|         fresult["wears_mask"] = self.wears_mask(face) |         fresult["wears_mask"] = self.wears_mask(face) | ||||||
| 
 |         # Adapt the features we are looking for depending on whether a mask is worn. | ||||||
|         # Adapt the features we are looking for depending on whether a mask is |         # White masks screw race detection, emotion detection is useless. | ||||||
|         # worn. White masks screw race detection, emotion detection is useless. |  | ||||||
|         actions = ["age", "gender"] |         actions = ["age", "gender"] | ||||||
|         if not fresult["wears_mask"]: |         if not fresult["wears_mask"]: | ||||||
|             actions = actions + ["race", "emotion"] |             actions = actions + ["race", "emotion"] | ||||||
| 
 |  | ||||||
|         # Ensure that all data has been fetched by pooch |         # Ensure that all data has been fetched by pooch | ||||||
|         deepface_age_model.get() |         deepface_age_model.get() | ||||||
|         deepface_face_expression_model.get() |         deepface_face_expression_model.get() | ||||||
|         deepface_gender_model.get() |         deepface_gender_model.get() | ||||||
|         deepface_race_model.get() |         deepface_race_model.get() | ||||||
| 
 |  | ||||||
|         # Run the full DeepFace analysis |         # Run the full DeepFace analysis | ||||||
|         fresult.update( |         fresult.update( | ||||||
|             DeepFace.analyze( |             DeepFace.analyze( | ||||||
| @ -145,25 +171,26 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|                 detector_backend="skip", |                 detector_backend="skip", | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
| 
 |  | ||||||
|         # We remove the region, as the data is not correct - after all we are |         # We remove the region, as the data is not correct - after all we are | ||||||
|         # running the analysis on a subimage. |         # running the analysis on a subimage. | ||||||
|         del fresult["region"] |         del fresult["region"] | ||||||
| 
 |  | ||||||
|         return fresult |         return fresult | ||||||
| 
 | 
 | ||||||
|     def facial_expression_analysis(self) -> dict: |     def facial_expression_analysis(self) -> dict: | ||||||
|  |         """ | ||||||
|  |         Performs facial expression analysis on the image. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated subdict dictionary with analysis results. | ||||||
|  |         """ | ||||||
|         # Find (multiple) faces in the image and cut them |         # Find (multiple) faces in the image and cut them | ||||||
|         retinaface_model.get() |         retinaface_model.get() | ||||||
|         faces = RetinaFace.extract_faces(self.subdict["filename"]) |         faces = RetinaFace.extract_faces(self.subdict["filename"]) | ||||||
| 
 |  | ||||||
|         # If no faces are found, we return empty keys |         # If no faces are found, we return empty keys | ||||||
|         if len(faces) == 0: |         if len(faces) == 0: | ||||||
|             return self.subdict |             return self.subdict | ||||||
| 
 |  | ||||||
|         # Sort the faces by size to prioritize prominent faces |         # Sort the faces by size to prioritize prominent faces | ||||||
|         faces = list(reversed(sorted(faces, key=lambda f: f.shape[0] * f.shape[1]))) |         faces = list(reversed(sorted(faces, key=lambda f: f.shape[0] * f.shape[1]))) | ||||||
| 
 |  | ||||||
|         self.subdict["face"] = "Yes" |         self.subdict["face"] = "Yes" | ||||||
|         self.subdict["multiple_faces"] = "Yes" if len(faces) > 1 else "No" |         self.subdict["multiple_faces"] = "Yes" if len(faces) > 1 else "No" | ||||||
|         self.subdict["no_faces"] = len(faces) if len(faces) <= 15 else 99 |         self.subdict["no_faces"] = len(faces) if len(faces) <= 15 else 99 | ||||||
| @ -172,13 +199,19 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|         # We limit ourselves to three faces |         # We limit ourselves to three faces | ||||||
|         for i, face in enumerate(faces[:3]): |         for i, face in enumerate(faces[:3]): | ||||||
|             result[f"person{ i+1 }"] = self.analyze_single_face(face) |             result[f"person{ i+1 }"] = self.analyze_single_face(face) | ||||||
| 
 |  | ||||||
|         self.clean_subdict(result) |         self.clean_subdict(result) | ||||||
| 
 |  | ||||||
|         return self.subdict |         return self.subdict | ||||||
| 
 | 
 | ||||||
|     def clean_subdict(self, result: dict) -> dict: |     def clean_subdict(self, result: dict) -> dict: | ||||||
|         # each person subdict converted into list for keys |         """ | ||||||
|  |         Cleans the subdict dictionary by converting results into appropriate formats. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             result (dict): The analysis results. | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated subdict dictionary. | ||||||
|  |         """ | ||||||
|  |         # Each person subdict converted into list for keys | ||||||
|         self.subdict["wears_mask"] = [] |         self.subdict["wears_mask"] = [] | ||||||
|         self.subdict["age"] = [] |         self.subdict["age"] = [] | ||||||
|         self.subdict["gender"] = [] |         self.subdict["gender"] = [] | ||||||
| @ -191,12 +224,12 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|                 "Yes" if result[person]["wears_mask"] else "No" |                 "Yes" if result[person]["wears_mask"] else "No" | ||||||
|             ) |             ) | ||||||
|             self.subdict["age"].append(result[person]["age"]) |             self.subdict["age"].append(result[person]["age"]) | ||||||
|             # gender is now reported as a list of dictionaries |             # Gender is now reported as a list of dictionaries. | ||||||
|             # each dict represents one face |             # Each dict represents one face. | ||||||
|             # each dict contains probability for Woman and Man |             # Each dict contains probability for Woman and Man. | ||||||
|             # take only the higher prob result for each dict |             # We take only the higher probability result for each dict. | ||||||
|             self.subdict["gender"].append(result[person]["gender"]) |             self.subdict["gender"].append(result[person]["gender"]) | ||||||
|             # race, emotion only detected if person does not wear mask |             # Race and emotion are only detected if a person does not wear a mask | ||||||
|             if result[person]["wears_mask"]: |             if result[person]["wears_mask"]: | ||||||
|                 self.subdict["race"].append(None) |                 self.subdict["race"].append(None) | ||||||
|                 self.subdict["emotion"].append(None) |                 self.subdict["emotion"].append(None) | ||||||
| @ -223,36 +256,28 @@ class EmotionDetector(utils.AnalysisMethod): | |||||||
|                 else: |                 else: | ||||||
|                     self.subdict["emotion"].append(None) |                     self.subdict["emotion"].append(None) | ||||||
|                     self.subdict["emotion (category)"].append(None) |                     self.subdict["emotion (category)"].append(None) | ||||||
| 
 |  | ||||||
|         return self.subdict |         return self.subdict | ||||||
| 
 | 
 | ||||||
|     def wears_mask(self, face: np.ndarray) -> bool: |     def wears_mask(self, face: np.ndarray) -> bool: | ||||||
|         global mask_detection_model |         """ | ||||||
|  |         Determines whether a face wears a mask. | ||||||
| 
 | 
 | ||||||
|         # Preprocess the face to match the assumptions of the face mask |         Args: | ||||||
|         # detection model |             face (np.ndarray): The face image array. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             bool: True if the face wears a mask, False otherwise. | ||||||
|  |         """ | ||||||
|  |         global mask_detection_model | ||||||
|  |         # Preprocess the face to match the assumptions of the face mask detection model | ||||||
|         face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB) |         face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB) | ||||||
|         face = cv2.resize(face, (224, 224)) |         face = cv2.resize(face, (224, 224)) | ||||||
|         face = img_to_array(face) |         face = img_to_array(face) | ||||||
|         face = preprocess_input(face) |         face = preprocess_input(face) | ||||||
|         face = np.expand_dims(face, axis=0) |         face = np.expand_dims(face, axis=0) | ||||||
| 
 |  | ||||||
|         # Lazily load the model |         # Lazily load the model | ||||||
|         mask_detection_model = load_model(face_mask_model.get()) |         mask_detection_model = load_model(face_mask_model.get()) | ||||||
| 
 |         # Run the model | ||||||
|         # Run the model (ignoring output) |  | ||||||
|         with NocatchOutput(): |  | ||||||
|         mask, without_mask = mask_detection_model.predict(face)[0] |         mask, without_mask = mask_detection_model.predict(face)[0] | ||||||
| 
 |  | ||||||
|         # Convert from np.bool_ to bool to later be able to serialize the result |         # Convert from np.bool_ to bool to later be able to serialize the result | ||||||
|         return bool(mask > without_mask) |         return bool(mask > without_mask) | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class NocatchOutput(ipywidgets.Output): |  | ||||||
|     """An output container that suppresses output, but not exceptions |  | ||||||
| 
 |  | ||||||
|     Taken from https://github.com/jupyter-widgets/ipywidgets/issues/3208#issuecomment-1070836153 |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __exit__(self, *args, **kwargs): |  | ||||||
|         super().__exit__(*args, **kwargs) |  | ||||||
|  | |||||||
| @ -5,20 +5,22 @@ from ammico.objects_cvlib import init_default_objects | |||||||
| 
 | 
 | ||||||
| class ObjectDetectorClient(AnalysisMethod): | class ObjectDetectorClient(AnalysisMethod): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         # The detector is default to CVLib |         # The detector is set to CVLib by default | ||||||
|         # Here other libraries can be added |  | ||||||
|         self.detector = ObjectCVLib() |         self.detector = ObjectCVLib() | ||||||
| 
 | 
 | ||||||
|     def set_client_to_cvlib(self): |     def set_client_to_cvlib(self): | ||||||
|  |         """Set the object detection client to use CVLib.""" | ||||||
|         self.detector = ObjectCVLib() |         self.detector = ObjectCVLib() | ||||||
| 
 | 
 | ||||||
|     def analyse_image(self, subdict=None): |     def analyse_image(self, subdict=None): | ||||||
|         """Localize objects in the local image. |         """Localize objects in the given image. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|         subdict: The dictionary for an image expression instance. |             subdict (dict): The dictionary for an image expression instance. | ||||||
|         """ |  | ||||||
| 
 | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated dictionary with object detection results. | ||||||
|  |         """ | ||||||
|         return self.detector.analyse_image(subdict) |         return self.detector.analyse_image(subdict) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -30,12 +32,23 @@ class ObjectDetector(AnalysisMethod): | |||||||
|         self.subdict.update(self.set_keys()) |         self.subdict.update(self.set_keys()) | ||||||
| 
 | 
 | ||||||
|     def set_keys(self): |     def set_keys(self): | ||||||
|  |         """Set the default object keys for analysis. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The dictionary with default object keys. | ||||||
|  |         """ | ||||||
|         return init_default_objects() |         return init_default_objects() | ||||||
| 
 | 
 | ||||||
|     def analyse_image(self): |     def analyse_image(self): | ||||||
|  |         """Perform object detection on the image. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated dictionary with object detection results. | ||||||
|  |         """ | ||||||
|         self.subdict = ObjectDetector.od_client.analyse_image(self.subdict) |         self.subdict = ObjectDetector.od_client.analyse_image(self.subdict) | ||||||
|         return self.subdict |         return self.subdict | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def set_client_to_cvlib(): |     def set_client_to_cvlib(): | ||||||
|  |         """Set the object detection client to use CVLib.""" | ||||||
|         ObjectDetector.od_client.set_client_to_cvlib() |         ObjectDetector.od_client.set_client_to_cvlib() | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| from ammico.utils import AnalysisMethod | from ammico.utils import AnalysisMethod | ||||||
| from torch import device, cuda, no_grad | from torch import cuda, no_grad | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from lavis.models import load_model_and_preprocess | from lavis.models import load_model_and_preprocess | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -49,8 +49,8 @@ def test_AnalysisExplorer(get_path): | |||||||
| 
 | 
 | ||||||
|     assert analysis_explorer_objects.update_picture(None) is None |     assert analysis_explorer_objects.update_picture(None) is None | ||||||
| 
 | 
 | ||||||
|     analysis_explorer_faces._right_output_analysis(None, all_options_dict, path_img_1) |     analysis_explorer_faces._right_output_analysis(all_options_dict, path_img_1) | ||||||
|     analysis_explorer_objects._right_output_analysis(None, all_options_dict, path_img_2) |     analysis_explorer_objects._right_output_analysis(all_options_dict, path_img_2) | ||||||
| 
 | 
 | ||||||
|     with pytest.raises(EnvironmentError): |     with pytest.raises(EnvironmentError): | ||||||
|         analysis_explorer_faces.run_server(port=8050) |         analysis_explorer_faces.run_server(port=8050) | ||||||
|  | |||||||
| @ -2,6 +2,14 @@ import ammico.faces as fc | |||||||
| import json | import json | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_set_keys(): | ||||||
|  |     ed = fc.EmotionDetector({}) | ||||||
|  |     assert ed.subdict["face"] == "No" | ||||||
|  |     assert ed.subdict["multiple_faces"] == "No" | ||||||
|  |     assert ed.subdict["wears_mask"] == ["No"] | ||||||
|  |     assert ed.subdict["emotion"] == [None] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_analyse_faces(get_path): | def test_analyse_faces(get_path): | ||||||
|     mydict = { |     mydict = { | ||||||
|         "filename": get_path + "IMG_2746.png", |         "filename": get_path + "IMG_2746.png", | ||||||
|  | |||||||
| @ -19,6 +19,14 @@ import os | |||||||
| 
 | 
 | ||||||
| class TextDetector(utils.AnalysisMethod): | class TextDetector(utils.AnalysisMethod): | ||||||
|     def __init__(self, subdict: dict, analyse_text: bool = False) -> None: |     def __init__(self, subdict: dict, analyse_text: bool = False) -> None: | ||||||
|  |         """Init text detection class. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             subdict (dict): Dictionary containing file name/path, and possibly previous | ||||||
|  |             analysis results from other modules. | ||||||
|  |             analyse_text (bool, optional): Decide if extracted text will be further subject | ||||||
|  |             to analysis. Defaults to False. | ||||||
|  |         """ | ||||||
|         super().__init__(subdict) |         super().__init__(subdict) | ||||||
|         self.subdict.update(self.set_keys()) |         self.subdict.update(self.set_keys()) | ||||||
|         self.translator = Translator() |         self.translator = Translator() | ||||||
| @ -28,10 +36,16 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|             self._initialize_textblob() |             self._initialize_textblob() | ||||||
| 
 | 
 | ||||||
|     def set_keys(self) -> dict: |     def set_keys(self) -> dict: | ||||||
|  |         """Set the default keys for text analysis. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The dictionary with default text keys. | ||||||
|  |         """ | ||||||
|         params = {"text": None, "text_language": None, "text_english": None} |         params = {"text": None, "text_language": None, "text_english": None} | ||||||
|         return params |         return params | ||||||
| 
 | 
 | ||||||
|     def _initialize_spacy(self): |     def _initialize_spacy(self): | ||||||
|  |         """Initialize the Spacy library for text analysis.""" | ||||||
|         try: |         try: | ||||||
|             self.nlp = spacy.load("en_core_web_md") |             self.nlp = spacy.load("en_core_web_md") | ||||||
|         except Exception: |         except Exception: | ||||||
| @ -40,12 +54,18 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|         self.nlp.add_pipe("spacytextblob") |         self.nlp.add_pipe("spacytextblob") | ||||||
| 
 | 
 | ||||||
|     def _initialize_textblob(self): |     def _initialize_textblob(self): | ||||||
|  |         """Initialize the TextBlob library for text analysis.""" | ||||||
|         try: |         try: | ||||||
|             TextBlob("Here") |             TextBlob("Here") | ||||||
|         except Exception: |         except Exception: | ||||||
|             download_corpora.main() |             download_corpora.main() | ||||||
| 
 | 
 | ||||||
|     def analyse_image(self): |     def analyse_image(self) -> dict: | ||||||
|  |         """Perform text extraction and analysis of the text. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The updated dictionary with text analysis results. | ||||||
|  |         """ | ||||||
|         self.get_text_from_image() |         self.get_text_from_image() | ||||||
|         self.translate_text() |         self.translate_text() | ||||||
|         self.remove_linebreaks() |         self.remove_linebreaks() | ||||||
| @ -60,7 +80,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|         return self.subdict |         return self.subdict | ||||||
| 
 | 
 | ||||||
|     def get_text_from_image(self): |     def get_text_from_image(self): | ||||||
|         """Detects text on the image.""" |         """Detect text on the image using Google Cloud Vision API.""" | ||||||
|         path = self.subdict["filename"] |         path = self.subdict["filename"] | ||||||
|         try: |         try: | ||||||
|             client = vision.ImageAnnotatorClient() |             client = vision.ImageAnnotatorClient() | ||||||
| @ -92,6 +112,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     def translate_text(self): |     def translate_text(self): | ||||||
|  |         """Translate the detected text to English using the Translator object.""" | ||||||
|         translated = self.translator.translate(self.subdict["text"]) |         translated = self.translator.translate(self.subdict["text"]) | ||||||
|         self.subdict["text_language"] = translated.src |         self.subdict["text_language"] = translated.src | ||||||
|         self.subdict["text_english"] = translated.text |         self.subdict["text_english"] = translated.text | ||||||
| @ -105,7 +126,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     def _run_spacy(self): |     def _run_spacy(self): | ||||||
|         """Generate spacy doc object.""" |         """Generate Spacy doc object for further text analysis.""" | ||||||
|         self.doc = self.nlp(self.subdict["text_english"]) |         self.doc = self.nlp(self.subdict["text_english"]) | ||||||
| 
 | 
 | ||||||
|     def clean_text(self): |     def clean_text(self): | ||||||
| @ -118,10 +139,12 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|         self.subdict["text_clean"] = " ".join(templist).rstrip().lstrip() |         self.subdict["text_clean"] = " ".join(templist).rstrip().lstrip() | ||||||
| 
 | 
 | ||||||
|     def correct_spelling(self): |     def correct_spelling(self): | ||||||
|  |         """Correct the spelling of the English text using TextBlob.""" | ||||||
|         self.textblob = TextBlob(self.subdict["text_english"]) |         self.textblob = TextBlob(self.subdict["text_english"]) | ||||||
|         self.subdict["text_english_correct"] = str(self.textblob.correct()) |         self.subdict["text_english_correct"] = str(self.textblob.correct()) | ||||||
| 
 | 
 | ||||||
|     def sentiment_analysis(self): |     def sentiment_analysis(self): | ||||||
|  |         """Perform sentiment analysis on the text using SpacyTextBlob.""" | ||||||
|         # polarity is between [-1.0, 1.0] |         # polarity is between [-1.0, 1.0] | ||||||
|         self.subdict["polarity"] = self.doc._.blob.polarity |         self.subdict["polarity"] = self.doc._.blob.polarity | ||||||
|         # subjectivity is a float within the range [0.0, 1.0] |         # subjectivity is a float within the range [0.0, 1.0] | ||||||
| @ -129,6 +152,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|         self.subdict["subjectivity"] = self.doc._.blob.subjectivity |         self.subdict["subjectivity"] = self.doc._.blob.subjectivity | ||||||
| 
 | 
 | ||||||
|     def text_summary(self): |     def text_summary(self): | ||||||
|  |         """Generate a summary of the text using the Transformers pipeline.""" | ||||||
|         # use the transformers pipeline to summarize the text |         # use the transformers pipeline to summarize the text | ||||||
|         # use the current default model - 03/2023 |         # use the current default model - 03/2023 | ||||||
|         model_name = "sshleifer/distilbart-cnn-12-6" |         model_name = "sshleifer/distilbart-cnn-12-6" | ||||||
| @ -152,6 +176,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|             self.subdict["text_summary"] = None |             self.subdict["text_summary"] = None | ||||||
| 
 | 
 | ||||||
|     def text_sentiment_transformers(self): |     def text_sentiment_transformers(self): | ||||||
|  |         """Perform text classification for sentiment using the Transformers pipeline.""" | ||||||
|         # use the transformers pipeline for text classification |         # use the transformers pipeline for text classification | ||||||
|         # use the current default model - 03/2023 |         # use the current default model - 03/2023 | ||||||
|         model_name = "distilbert-base-uncased-finetuned-sst-2-english" |         model_name = "distilbert-base-uncased-finetuned-sst-2-english" | ||||||
| @ -167,6 +192,7 @@ class TextDetector(utils.AnalysisMethod): | |||||||
|         self.subdict["sentiment_score"] = result[0]["score"] |         self.subdict["sentiment_score"] = result[0]["score"] | ||||||
| 
 | 
 | ||||||
|     def text_ner(self): |     def text_ner(self): | ||||||
|  |         """Perform named entity recognition on the text using the Transformers pipeline.""" | ||||||
|         # use the transformers pipeline for named entity recognition |         # use the transformers pipeline for named entity recognition | ||||||
|         # use the current default model - 03/2023 |         # use the current default model - 03/2023 | ||||||
|         model_name = "dbmdz/bert-large-cased-finetuned-conll03-english" |         model_name = "dbmdz/bert-large-cased-finetuned-conll03-english" | ||||||
| @ -193,6 +219,15 @@ class PostprocessText: | |||||||
|         csv_path: str = None, |         csv_path: str = None, | ||||||
|         analyze_text: str = "text_english", |         analyze_text: str = "text_english", | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         Initializes the PostprocessText class that handles the topic analysis. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             mydict (dict, optional): Dictionary with textual data. Defaults to None. | ||||||
|  |             use_csv (bool, optional): Flag indicating whether to use a CSV file. Defaults to False. | ||||||
|  |             csv_path (str, optional): Path to the CSV file. Required if `use_csv` is True. Defaults to None. | ||||||
|  |             analyze_text (str, optional): Key for the text field to analyze. Defaults to "text_english". | ||||||
|  |         """ | ||||||
|         self.use_csv = use_csv |         self.use_csv = use_csv | ||||||
|         if mydict: |         if mydict: | ||||||
|             print("Reading data from dict.") |             print("Reading data from dict.") | ||||||
| @ -209,8 +244,16 @@ class PostprocessText: | |||||||
|                              `csv_path`." |                              `csv_path`." | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     def analyse_topic(self, return_topics: int = 3): |     def analyse_topic(self, return_topics: int = 3) -> tuple: | ||||||
|         """Topic analysis using BERTopic.""" |         """ | ||||||
|  |         Performs topic analysis using BERTopic. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             return_topics (int, optional): Number of topics to return. Defaults to 3. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             tuple: A tuple containing the topic model, topic dataframe, and most frequent topics. | ||||||
|  |         """ | ||||||
|         # load spacy pipeline |         # load spacy pipeline | ||||||
|         nlp = spacy.load( |         nlp = spacy.load( | ||||||
|             "en_core_web_md", |             "en_core_web_md", | ||||||
| @ -237,7 +280,16 @@ class PostprocessText: | |||||||
|             most_frequent_topics.append(self.topic_model.get_topic(i)) |             most_frequent_topics.append(self.topic_model.get_topic(i)) | ||||||
|         return self.topic_model, topic_df, most_frequent_topics |         return self.topic_model, topic_df, most_frequent_topics | ||||||
| 
 | 
 | ||||||
|     def get_text_dict(self, analyze_text): |     def get_text_dict(self, analyze_text: str) -> list: | ||||||
|  |         """ | ||||||
|  |         Extracts text from the provided dictionary. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             analyze_text (str): Key for the text field to analyze. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             list: A list of text extracted from the dictionary. | ||||||
|  |         """ | ||||||
|         # use dict to put text_english or text_summary in list |         # use dict to put text_english or text_summary in list | ||||||
|         list_text_english = [] |         list_text_english = [] | ||||||
|         for key in self.mydict.keys(): |         for key in self.mydict.keys(): | ||||||
| @ -251,7 +303,16 @@ class PostprocessText: | |||||||
|             list_text_english.append(self.mydict[key][analyze_text]) |             list_text_english.append(self.mydict[key][analyze_text]) | ||||||
|         return list_text_english |         return list_text_english | ||||||
| 
 | 
 | ||||||
|     def get_text_df(self, analyze_text): |     def get_text_df(self, analyze_text: str) -> list: | ||||||
|  |         """ | ||||||
|  |         Extracts text from the provided dataframe. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             analyze_text (str): Column name for the text field to analyze. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             list: A list of text extracted from the dataframe. | ||||||
|  |         """ | ||||||
|         # use csv file to obtain dataframe and put text_english or text_summary in list |         # use csv file to obtain dataframe and put text_english or text_summary in list | ||||||
|         # check that "text_english" or "text_summary" is there |         # check that "text_english" or "text_summary" is there | ||||||
|         if analyze_text not in self.df: |         if analyze_text not in self.df: | ||||||
| @ -262,19 +323,3 @@ class PostprocessText: | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         return self.df[analyze_text].tolist() |         return self.df[analyze_text].tolist() | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     images = utils.find_files( |  | ||||||
|         path="data/test-debug/101-200fullposts", |  | ||||||
|         limit=110, |  | ||||||
|     ) |  | ||||||
|     # images = ["data/test-debug/101-200fullposts/100638_mya.png"] |  | ||||||
|     print(images) |  | ||||||
|     mydict = utils.initialize_dict(images) |  | ||||||
|     os.environ[ |  | ||||||
|         "GOOGLE_APPLICATION_CREDENTIALS" |  | ||||||
|     ] = "data/misinformation-campaign-981aa55a3b13.json" |  | ||||||
|     for key in mydict: |  | ||||||
|         print(key) |  | ||||||
|         mydict[key] = TextDetector(mydict[key], analyse_text=True).analyse_image() |  | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import pooch | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DownloadResource: | class DownloadResource: | ||||||
|     """A remote resource that needs on demand downloading |     """A remote resource that needs on demand downloading. | ||||||
| 
 | 
 | ||||||
|     We use this as a wrapper to the pooch library. The wrapper registers |     We use this as a wrapper to the pooch library. The wrapper registers | ||||||
|     each data file and allows prefetching through the CLI entry point |     each data file and allows prefetching through the CLI entry point | ||||||
| @ -33,7 +33,7 @@ def ammico_prefetch_models(): | |||||||
| class AnalysisMethod: | class AnalysisMethod: | ||||||
|     """Base class to be inherited by all analysis methods.""" |     """Base class to be inherited by all analysis methods.""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, subdict) -> None: |     def __init__(self, subdict: dict) -> None: | ||||||
|         self.subdict = subdict |         self.subdict = subdict | ||||||
|         # define keys that will be set by the analysis |         # define keys that will be set by the analysis | ||||||
| 
 | 
 | ||||||
| @ -44,35 +44,40 @@ class AnalysisMethod: | |||||||
|         raise NotImplementedError() |         raise NotImplementedError() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def find_files(path=None, pattern="*.png", recursive=True, limit=20): | def find_files( | ||||||
|  |     path: str = None, pattern: str = "*.png", recursive: bool = True, limit: int = 20 | ||||||
|  | ) -> list: | ||||||
|     """Find image files on the file system. |     """Find image files on the file system. | ||||||
| 
 | 
 | ||||||
|     :param path: |     Args: | ||||||
|         The base directory where we are looking for the images. Defaults |         path (str, optional): The base directory where we are looking for the images. Defaults | ||||||
|         to None, which uses the XDG data directory if set or the current |         to None, which uses the XDG data directory if set or the current | ||||||
|         working directory otherwise. |         working directory otherwise. | ||||||
|     :param pattern: |         pattern (str, optional): The naming pattern that the filename should match. Defaults to | ||||||
|         The naming pattern that the filename should match. Defaults to |  | ||||||
|         "*.png". Can be used to allow other patterns or to only include |         "*.png". Can be used to allow other patterns or to only include | ||||||
|         specific prefixes or suffixes. |         specific prefixes or suffixes. | ||||||
|     :param recursive: |         recursive (bool, optional): Whether to recurse into subdirectories. Defaults to True. | ||||||
|         Whether to recurse into subdirectories. |         limit (int, optional): The maximum number of images to be found. | ||||||
|     :param limit: |         Defaults to 20. To return all images, set to None. | ||||||
|         The maximum number of images to be found. Defaults to 20. | 
 | ||||||
|         To return all images, set to None. |     Returns: | ||||||
|  |         list: A list with all filenames including the path. | ||||||
|     """ |     """ | ||||||
|     if path is None: |     if path is None: | ||||||
|         path = os.environ.get("XDG_DATA_HOME", ".") |         path = os.environ.get("XDG_DATA_HOME", ".") | ||||||
| 
 |  | ||||||
|     result = list(glob.glob(f"{path}/{pattern}", recursive=recursive)) |     result = list(glob.glob(f"{path}/{pattern}", recursive=recursive)) | ||||||
| 
 |  | ||||||
|     if limit is not None: |     if limit is not None: | ||||||
|         result = result[:limit] |         result = result[:limit] | ||||||
| 
 |  | ||||||
|     return result |     return result | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def initialize_dict(filelist: list) -> dict: | def initialize_dict(filelist: list) -> dict: | ||||||
|  |     """Initialize the nested dictionary for all the found images. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         filelist (list): The list of files to be analyzed, including their paths. | ||||||
|  |     Returns: | ||||||
|  |         dict: The nested dictionary with all image ids and their paths.""" | ||||||
|     mydict = {} |     mydict = {} | ||||||
|     for img_path in filelist: |     for img_path in filelist: | ||||||
|         id_ = os.path.splitext(os.path.basename(img_path))[0] |         id_ = os.path.splitext(os.path.basename(img_path))[0] | ||||||
| @ -81,7 +86,7 @@ def initialize_dict(filelist: list) -> dict: | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def append_data_to_dict(mydict: dict) -> dict: | def append_data_to_dict(mydict: dict) -> dict: | ||||||
|     """Append entries from list of dictionaries to keys in global dict.""" |     """Append entries from nested dictionaries to keys in a global dict.""" | ||||||
| 
 | 
 | ||||||
|     # first initialize empty list for each key that is present |     # first initialize empty list for each key that is present | ||||||
|     outdict = {key: [] for key in list(mydict.values())[0].keys()} |     outdict = {key: [] for key in list(mydict.values())[0].keys()} | ||||||
| @ -98,6 +103,7 @@ def dump_df(mydict: dict) -> DataFrame: | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_interactive(): | def is_interactive(): | ||||||
|  |     """Check if we are running in an interactive environment.""" | ||||||
|     import __main__ as main |     import __main__ as main | ||||||
| 
 | 
 | ||||||
|     return not hasattr(main, "__file__") |     return not hasattr(main, "__file__") | ||||||
|  | |||||||
| @ -32,7 +32,6 @@ dependencies = [ | |||||||
|     "grpcio", |     "grpcio", | ||||||
|     "importlib_metadata", |     "importlib_metadata", | ||||||
|     "ipython", |     "ipython", | ||||||
|     "ipywidgets<8.0.5", |  | ||||||
|     "jupyter_dash", |     "jupyter_dash", | ||||||
|     "matplotlib", |     "matplotlib", | ||||||
|     "numpy<=1.23.4", |     "numpy<=1.23.4", | ||||||
|  | |||||||
		Загрузка…
	
	
			
			x
			
			
		
	
		Ссылка в новой задаче
	
	Block a user
	 GwydionJon
						GwydionJon