diff --git a/ammico/display.py b/ammico/display.py index 3034df4..5248c2a 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -5,9 +5,9 @@ import ammico.colors as colors from ammico.utils import is_interactive import ammico.summary as summary import dash_renderjson -from dash import html, Input, Output, dcc, State -import jupyter_dash +from dash import html, Input, Output, dcc, State, Dash from PIL import Image +import dash_bootstrap_components as dbc COLOR_SCHEMES = [ @@ -37,7 +37,7 @@ class AnalysisExplorer: mydict (dict): A nested dictionary containing image data for all images. """ - self.app = jupyter_dash.JupyterDash(__name__) + self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) self.mydict = mydict self.theme = { "scheme": "monokai", @@ -63,36 +63,24 @@ class AnalysisExplorer: # Setup the layout app_layout = html.Div( [ - # Top - html.Div( - [self._top_file_explorer(mydict)], + # Top row, only file explorer + dbc.Row( + [dbc.Col(self._top_file_explorer(mydict))], id="Div_top", style={ "width": "30%", }, ), - # Middle - html.Div( - [self._middle_picture_frame()], - id="Div_middle", - style={ - "width": "50%", - "display": "inline-block", - "verticalAlign": "top", - }, - ), - # Right - html.Div( - [self._right_output_json()], - id="Div_right", - style={ - "width": "45%", - "display": "inline-block", - "verticalAlign": "top", - }, + # second row, middle picture and right output + dbc.Row( + [ + # first column: picture + dbc.Col(self._middle_picture_frame()), + dbc.Col(self._right_output_json()), + ] ), ], - style={"width": "95%", "display": "inline-block"}, + # style={"width": "95%", "display": "inline-block"}, ) self.app.layout = app_layout @@ -170,106 +158,100 @@ class AnalysisExplorer: def _create_setting_layout(self): settings_layout = html.Div( [ + # text summary start html.Div( id="settings_TextDetector", style={"display": "none"}, children=[ - dcc.Checklist( - ["Analyse text"], - ["Analyse text"], - id="setting_Text_analyse_text", - ), - html.Div( + dbc.Row( + dcc.Checklist( + ["Analyse text"], + ["Analyse text"], + id="setting_Text_analyse_text", + style={"margin-bottom": "10px"}, + ), + ), # row 1 + # text row 2 + dbc.Row( [ - html.Div( - "Select models for text_summary, text_sentiment, text_NER or leave blank for default:", - style={ - "height": "30px", - "margin-top": "5px", - }, + dbc.Col( + [ + html.P( + "Select models for text_summary, text_sentiment, text_NER or leave blank for default:", + # style={"width": "45%"}, + ), + ] + ), # + dbc.Col( + [ + html.P( + "Select model revision number for text_summary, text_sentiment, text_NER or leave blank for default:" + ), + ] ), - dcc.Input( - type="text", - id="setting_Text_model_names", - style={"height": "auto", "margin-bottom": "auto"}, - ), - ], - style={ - "width": "33%", - "display": "inline-block", - "margin-top": "10px", - }, - ), - html.Div( + ] + ), # row 2 + # input row 3 + dbc.Row( [ - html.Div( - "Select model revision number for text_summary, text_sentiment, text_NER or leave blank for default:", - style={ - "height": "30px", - "margin-top": "5px", - }, + dbc.Col( + dcc.Input( + type="text", + id="setting_Text_model_names", + style={"width": "100%"}, + ), ), - dcc.Input( - type="text", - id="setting_Text_revision_numbers", - style={"height": "auto", "margin-bottom": "auto"}, + dbc.Col( + dcc.Input( + type="text", + id="setting_Text_revision_numbers", + style={"width": "100%"}, + ), ), - ], - style={ - "width": "33%", - "display": "inline-block", - "margin-top": "10px", - }, - ), + ] + ), # row 3 ], - ), + ), # text summary end + # start emotion detector html.Div( id="settings_EmotionDetector", style={"display": "none"}, children=[ - html.Div( + dbc.Row( [ - html.Div( - "Emotion threshold", - style={"height": "30px", "margin-top": "5px"}, + dbc.Col( + [ + html.P("Emotion threshold"), + dcc.Input( + value=50, + type="number", + max=100, + min=0, + id="setting_Emotion_emotion_threshold", + style={"width": "100%"}, + ), + ], + align="start", ), - dcc.Input( - value=50, - type="number", - max=100, - min=0, - id="setting_Emotion_emotion_threshold", - style={"height": "auto", "margin-bottom": "auto"}, + dbc.Col( + [ + html.P("Race threshold"), + dcc.Input( + type="number", + value=50, + max=100, + min=0, + id="setting_Emotion_race_threshold", + style={"width": "100%"}, + ), + ], + align="start", ), ], - style={"width": "49%", "display": "inline-block"}, - ), - html.Div( - [ - html.Div( - "Race threshold", - style={ - "height": "30px", - "margin-top": "5px", - }, - ), - dcc.Input( - type="number", - value=50, - max=100, - min=0, - id="setting_Emotion_race_threshold", - style={"height": "auto", "margin-bottom": "auto"}, - ), - ], - style={ - "width": "49%", - "display": "inline-block", - "margin-top": "10px", - }, + style={"width": "100%"}, ), ], - ), + ), # end emotion detector html.Div( id="settings_ColorDetector", style={"display": "none"}, @@ -294,57 +276,45 @@ class AnalysisExplorer: id="settings_Summary_Detector", style={"display": "none"}, children=[ - html.Div( + dbc.Col( [ - dcc.Dropdown( - options=SUMMARY_ANALYSIS_TYPE, - value="summary_and_questions", - id="setting_Summary_analysis_type", - ) + dbc.Row([html.P("Analysis type:")]), + dbc.Row([html.P("Model type:")]), + dbc.Row([html.P("Analysis question:")]), ], - style={ - "width": "33%", - "display": "inline-block", - }, ), - html.Div( + dbc.Col( [ - dcc.Dropdown( - options=SUMMARY_MODEL, - value="base", - id="setting_Summary_model", - ) - ], - style={ - "width": "33%", - "display": "inline-block", - "margin-top": "10px", - }, - ), - html.Div( - [ - html.Div( - "Please enter a question", - style={ - "height": "50px", - "margin-top": "5px", - }, + dbc.Row( + dcc.Dropdown( + options=SUMMARY_ANALYSIS_TYPE, + value="summary_and_questions", + id="setting_Summary_analysis_type", + ) ), - dcc.Input( - type="text", - id="setting_Summary_list_of_questions", - style={"height": "auto", "margin-bottom": "auto"}, + dbc.Row( + dcc.Dropdown( + options=SUMMARY_MODEL, + value="base", + id="setting_Summary_model", + ) ), - ], - style={ - "width": "33%", - "display": "inline-block", - "margin-top": "10px", - }, + dbc.Row( + dcc.Input( + type="text", + id="setting_Summary_list_of_questions", + style={ + "height": "auto", + "margin-left": "11px", + }, + ), + ), + ] ), ], ), ], + style={"width": "100%", "display": "inline-block"}, ) return settings_layout @@ -356,38 +326,59 @@ class AnalysisExplorer: """ right_layout = html.Div( [ - dcc.Loading( - id="loading-2", - children=[ - html.Div( - [ - dcc.Dropdown( - options=[ - "TextDetector", - "ObjectDetector", - "EmotionDetector", - "SummaryDetector", - "ColorDetector", - ], - value="TextDetector", - id="Dropdown_select_Detector", - ), - html.Div( - children=[self._create_setting_layout()], - id="div_detector_args", - ), - html.Button("Run Detector", id="button_run"), - dash_renderjson.DashRenderjson( - id="right_json_viewer", - data={}, - max_depth=-1, - theme=self.theme, - invert_theme=True, - ), - ] - ) + dbc.Col( + [ + dbc.Row( + dcc.Dropdown( + options=[ + "TextDetector", + "ObjectDetector", + "EmotionDetector", + "SummaryDetector", + "ColorDetector", + ], + value="TextDetector", + id="Dropdown_select_Detector", + style={"width": "60%"}, + ), + justify="start", + ), + dbc.Row( + children=[self._create_setting_layout()], + id="div_detector_args", + justify="start", + ), + dbc.Row( + html.Button( + "Run Detector", + id="button_run", + style={ + "margin-top": "15px", + "margin-bottom": "15px", + "margin-left": "11px", + "width": "30%", + }, + ), + justify="start", + ), + dbc.Row( + dcc.Loading( + id="loading-2", + children=[ + dash_renderjson.DashRenderjson( + id="right_json_viewer", + data={}, + max_depth=-1, + theme=self.theme, + invert_theme=True, + ), + ], + type="circle", + ), + justify="start", + ), ], - type="circle", + align="start", ) ] ) @@ -396,18 +387,12 @@ class AnalysisExplorer: def run_server(self, port: int = 8050) -> None: """Run the Dash server to start the analysis explorer. - This method should only be called in an interactive environment like Jupyter notebooks. - Raises an EnvironmentError if not called in an interactive environment. Args: port (int, optional): The port number to run the server on (default: 8050). """ - if not is_interactive(): - raise EnvironmentError( - "Dash server should only be called in an interactive environment like Jupyter notebooks." - ) - self.app.run_server(debug=True, mode="inline", port=port) + self.app.run_server(debug=True, port=port) # Dash callbacks def update_picture(self, img_path: str): @@ -492,7 +477,9 @@ class AnalysisExplorer: # copy image so prvious runs don't leave their default values in the dict image_copy = self.mydict[image_id].copy() + # detector value is the string name of the chosen detector identify_function = identify_dict[detector_value] + if detector_value == "TextDetector": analyse_text = ( True if settings_text_analyse_text == ["Analyse text"] else False @@ -508,6 +495,7 @@ class AnalysisExplorer: else None, ) elif detector_value == "EmotionDetector": + print("test") detector_class = identify_function( image_copy, race_threshold=setting_emotion_race_threshold, diff --git a/ammico/faces.py b/ammico/faces.py index 7ab7b34..c9e1c43 100644 --- a/ammico/faces.py +++ b/ammico/faces.py @@ -184,6 +184,7 @@ class EmotionDetector(AnalysisMethod): """ # Find (multiple) faces in the image and cut them retinaface_model.get() + faces = RetinaFace.extract_faces(self.subdict["filename"]) # If no faces are found, we return empty keys if len(faces) == 0: diff --git a/ammico/test/test_display.py b/ammico/test/test_display.py index c91d0ed..fcaec50 100644 --- a/ammico/test/test_display.py +++ b/ammico/test/test_display.py @@ -105,23 +105,3 @@ def test_right_output_analysis_summary(get_AE, get_options): "base", "How many people are in the picture?", ) - - -def test_right_output_analysis_colors(get_AE, get_options): - get_AE._right_output_analysis( - 2, - get_options[3], - get_options[0], - "ColorDetector", - True, - None, - None, - 50, - 50, - "CIE 1976", - "summary_and_questions", - "base", - "How many people are in the picture?", - ) - with pytest.raises(EnvironmentError): - get_AE.run_server(port=8050) diff --git a/pyproject.toml b/pyproject.toml index be88b5d..357d1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,11 @@ classifiers = [ "Operating System :: OS Independent", "License :: OSI Approved :: MIT License", ] + dependencies = [ "bertopic<=0.14.1", "cvlib", - "dash", + "dash>=2.11.0", "dash_renderjson", "deepface<=0.0.75", "googletrans==3.1.0a0", @@ -53,8 +54,8 @@ dependencies = [ "setuptools", "opencv-contrib-python", "dash", - "jupyter_dash", "dash_renderjson", + "dash_bootstrap_components", "colorgram.py", "webcolors", "colour-science",