# Patch summary: moves the BLIP caption model loaders out of the
# SummaryDetector class and into module-level functions in utils.py.
#
# misinformation/summary.py
#   Removes load_model_base, load_model_large and load_model from the body of
#   class SummaryDetector(AnalysisMethod). The removed versions passed
#   SummaryDetector.summary_device as the device.
#
# misinformation/utils.py
#   Adds `import torch` and `from lavis.models import load_model_and_preprocess`.
#   Adds three module-level functions:
#     load_model_base()  - calls load_model_and_preprocess(name="blip_caption",
#                          model_type="base_coco", is_eval=True, device=...) and
#                          returns (summary_model, summary_vis_processors); the
#                          third returned value is discarded.
#     load_model_large() - same call with model_type="large_coco".
#     load_model(model_type) - looks model_type up in a dict with the keys
#                          "base" and "large" and calls the selected loader.
#                          Any other key raises KeyError from the dict lookup.
#   Each loader now chooses its own device: "cuda" when
#   torch.cuda.is_available() is true, otherwise "cpu".
#
# notebooks/image_summary.ipynb
#   Adds a markdown cell naming the two model choices and changes the code cell
#   to call mutils.load_model("base"), with the "large" call left commented out.
#
# NOTE(review): the utils.py hunk header and context lines suggest the new
#   functions are appended after the `if __name__ == "__main__":` section -
#   confirm against the full file and consider defining them above it.
# NOTE(review): load_model_base and load_model_large differ only in the
#   model_type string; a shared helper would remove the duplication.
# NOTE(review): the `mutils` alias is not defined in the visible hunks -
#   presumably imported in an earlier notebook cell; verify.
# NOTE(review): line breaks in this copy of the patch appear to have been
#   collapsed into spaces; they would need restoring before it can be applied.
diff --git a/misinformation/summary.py b/misinformation/summary.py index 60fd9cb..cca71af 100644 --- a/misinformation/summary.py +++ b/misinformation/summary.py @@ -16,32 +16,6 @@ class SummaryDetector(AnalysisMethod): device=summary_device, ) - def load_model_base(): - summary_model, summary_vis_processors, _ = load_model_and_preprocess( - name="blip_caption", - model_type="base_coco", - is_eval=True, - device=SummaryDetector.summary_device, - ) - return summary_model, summary_vis_processors - - def load_model_large(): - summary_model, summary_vis_processors, _ = load_model_and_preprocess( - name="blip_caption", - model_type="large_coco", - is_eval=True, - device=SummaryDetector.summary_device, - ) - return summary_model, summary_vis_processors - - def load_model(model_type): - select_model = { - "base": SummaryDetector.load_model_base, - "large": SummaryDetector.load_model_large, - } - summary_model, summary_vis_processors = select_model[model_type]() - return summary_model, summary_vis_processors - def analyse_image(self, summary_model=None, summary_vis_processors=None): if summary_model is None and summary_vis_processors is None: diff --git a/misinformation/utils.py b/misinformation/utils.py index 36c7690..4bb792c 100644 --- a/misinformation/utils.py +++ b/misinformation/utils.py @@ -2,6 +2,8 @@ import glob import os from pandas import DataFrame import pooch +import torch +from lavis.models import load_model_and_preprocess class DownloadResource: @@ -106,3 +108,34 @@ if __name__ == "__main__": outdict = append_data_to_dict(mydict) df = dump_df(outdict) print(df.head(10)) + + +def load_model_base(): + summary_device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + summary_model, summary_vis_processors, _ = load_model_and_preprocess( + name="blip_caption", + model_type="large_coco", + is_eval=True, + device=summary_device, + ) + return summary_model, summary_vis_processors + + +def load_model(model_type): + select_model = { + "base": load_model_base, + "large": load_model_large, + } + summary_model, summary_vis_processors = select_model[model_type]() + return summary_model, summary_vis_processors diff --git a/notebooks/image_summary.ipynb b/notebooks/image_summary.ipynb index 7b5b69b..051138e 100644 --- a/notebooks/image_summary.ipynb +++ b/notebooks/image_summary.ipynb @@ -70,13 +70,21 @@ "## Create captions for images and directly write to csv" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here you can choose between two models: \"base\" or \"large\"" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")" + "summary_model, summary_vis_processors = mutils.load_model(\"base\")\n", + "# summary_model, summary_vis_processors = mutils.load_model(\"large\")" ] }, {