Этот коммит содержится в:
Petr Andriushchenko 2023-02-24 10:47:52 +01:00
родитель 28014f563f
Коммит f4e47c105e
2 изменённых файлов: 27 добавлений и 13 удалений

Просмотреть файл

@ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod):
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model_base(self): def load_model_base():
summary_model, summary_vis_processors, _ = load_model_and_preprocess( summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption", name="blip_caption",
model_type="base_coco", model_type="base_coco",
@ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod):
) )
return summary_model, summary_vis_processors return summary_model, summary_vis_processors
def load_model_large(self): def load_model_large():
summary_model, summary_vis_processors, _ = load_model_and_preprocess( summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption", name="blip_caption",
model_type="large_coco", model_type="large_coco",
@ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod):
) )
return summary_model, summary_vis_processors return summary_model, summary_vis_processors
def load_model(model_type):
select_model = {
"base": SummaryDetector.load_model_base,
"large": SummaryDetector.load_model_large,
}
summary_model, summary_vis_processors = select_model[model_type]()
return summary_model, summary_vis_processors
def set_keys(self) -> dict: def set_keys(self) -> dict:
params = { params = {
"const_image_summary": None, "const_image_summary": None,
@ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod):
} }
return params return params
def analyse_image(self, model_type): def analyse_image(self, summary_model, summary_vis_processors):
select_model = {
"base": self.load_model_base,
"large": self.load_model_large,
}
summary_model, summary_vis_processors = select_model[model_type]()
path = self.subdict["filename"] path = self.subdict["filename"]
raw_image = Image.open(path).convert("RGB") raw_image = Image.open(path).convert("RGB")
image = ( image = (
@ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod):
name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
) )
def analyse_questions(self, model_type, list_of_questions): def analyse_questions(self, list_of_questions):
if len(list_of_questions) > 0: if len(list_of_questions) > 0:
path = self.subdict["filename"] path = self.subdict["filename"]
raw_image = Image.open(path).convert("RGB") raw_image = Image.open(path).convert("RGB")
image = ( image = (
summary_VQA_vis_processors["eval"](raw_image) self.summary_VQA_vis_processors["eval"](raw_image)
.unsqueeze(0) .unsqueeze(0)
.to(self.summary_device) .to(self.summary_device)
) )

17
notebooks/image_summary.ipynb сгенерированный
Просмотреть файл

@ -40,7 +40,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"images = mutils.find_files(\n", "images = mutils.find_files(\n",
" path=\"../data/Image_some_text/\",\n", " path=\"../misinformation/test/data/\",\n",
" limit=1000,\n", " limit=1000,\n",
")" ")"
] ]
@ -70,6 +70,15 @@
"## Create captions for images and directly write to csv" "## Create captions for images and directly write to csv"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -77,7 +86,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"for key in mydict:\n", "for key in mydict:\n",
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()" " mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n",
" summary_model, summary_vis_processors\n",
" )"
] ]
}, },
{ {
@ -260,7 +271,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.5" "version": "3.10.8"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {