Этот коммит содержится в:
Petr Andriushchenko 2023-02-24 10:47:52 +01:00
родитель 28014f563f
Коммит f4e47c105e
2 изменённых файлов: 27 добавлений и 13 удалений

Просмотреть файл

@ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod):
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model_base(self):
def load_model_base():
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="base_coco",
@ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod):
)
return summary_model, summary_vis_processors
def load_model_large(self):
def load_model_large():
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="large_coco",
@ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod):
)
return summary_model, summary_vis_processors
def load_model(model_type):
select_model = {
"base": SummaryDetector.load_model_base,
"large": SummaryDetector.load_model_large,
}
summary_model, summary_vis_processors = select_model[model_type]()
return summary_model, summary_vis_processors
def set_keys(self) -> dict:
params = {
"const_image_summary": None,
@ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod):
}
return params
def analyse_image(self, model_type):
def analyse_image(self, summary_model, summary_vis_processors):
select_model = {
"base": self.load_model_base,
"large": self.load_model_large,
}
summary_model, summary_vis_processors = select_model[model_type]()
path = self.subdict["filename"]
raw_image = Image.open(path).convert("RGB")
image = (
@ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod):
name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
)
def analyse_questions(self, model_type, list_of_questions):
def analyse_questions(self, list_of_questions):
if len(list_of_questions) > 0:
path = self.subdict["filename"]
raw_image = Image.open(path).convert("RGB")
image = (
summary_VQA_vis_processors["eval"](raw_image)
self.summary_VQA_vis_processors["eval"](raw_image)
.unsqueeze(0)
.to(self.summary_device)
)

17
notebooks/image_summary.ipynb сгенерированный
Просмотреть файл

@ -40,7 +40,7 @@
"outputs": [],
"source": [
"images = mutils.find_files(\n",
" path=\"../data/Image_some_text/\",\n",
" path=\"../misinformation/test/data/\",\n",
" limit=1000,\n",
")"
]
@ -70,6 +70,15 @@
"## Create captions for images and directly write to csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -77,7 +86,9 @@
"outputs": [],
"source": [
"for key in mydict:\n",
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()"
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n",
" summary_model, summary_vis_processors\n",
" )"
]
},
{
@ -260,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.10.8"
},
"vscode": {
"interpreter": {