diff --git a/misinformation/summary.py b/misinformation/summary.py index a5fbc7a..60fd9cb 100644 --- a/misinformation/summary.py +++ b/misinformation/summary.py @@ -7,13 +7,14 @@ from lavis.models import load_model_and_preprocess class SummaryDetector(AnalysisMethod): def __init__(self, subdict: dict) -> None: super().__init__(subdict) - self.subdict.update(self.set_keys()) - self.image_summary = { - "const_image_summary": None, - "3_non-deterministic summary": None, - } summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + summary_model, summary_vis_processors, _ = load_model_and_preprocess( + name="blip_caption", + model_type="base_coco", + is_eval=True, + device=summary_device, + ) def load_model_base(): summary_model, summary_vis_processors, _ = load_model_and_preprocess( @@ -41,14 +42,11 @@ class SummaryDetector(AnalysisMethod): summary_model, summary_vis_processors = select_model[model_type]() return summary_model, summary_vis_processors - def set_keys(self) -> dict: - params = { - "const_image_summary": None, - "3_non-deterministic summary": None, - } - return params + def analyse_image(self, summary_model=None, summary_vis_processors=None): - def analyse_image(self, summary_model, summary_vis_processors): + if summary_model is None and summary_vis_processors is None: + summary_model = SummaryDetector.summary_model + summary_vis_processors = SummaryDetector.summary_vis_processors path = self.subdict["filename"] raw_image = Image.open(path).convert("RGB") @@ -58,14 +56,12 @@ class SummaryDetector(AnalysisMethod): .to(self.summary_device) ) with torch.no_grad(): - self.image_summary["const_image_summary"] = summary_model.generate( + self.subdict["const_image_summary"] = summary_model.generate( {"image": image} )[0] - self.image_summary["3_non-deterministic summary"] = summary_model.generate( + self.subdict["3_non-deterministic summary"] = summary_model.generate( {"image": image}, use_nucleus_sampling=True, num_captions=3 ) - for key in self.image_summary: - self.subdict[key] = self.image_summary[key] return self.subdict ( @@ -99,10 +95,8 @@ class SummaryDetector(AnalysisMethod): ) for q, a in zip(question_batch, answers_batch): - self.image_summary[q] = a + self.subdict[q] = a - for key in self.image_summary: - self.subdict[key] = self.image_summary[key] else: print("Please, enter list of questions") return self.subdict