зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 13:36:04 +02:00
fixed new model in summary.py
Этот коммит содержится в:
родитель
28014f563f
Коммит
f4e47c105e
@ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod):
|
||||
|
||||
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def load_model_base(self):
|
||||
def load_model_base():
|
||||
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
||||
name="blip_caption",
|
||||
model_type="base_coco",
|
||||
@ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod):
|
||||
)
|
||||
return summary_model, summary_vis_processors
|
||||
|
||||
def load_model_large(self):
|
||||
def load_model_large():
|
||||
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
||||
name="blip_caption",
|
||||
model_type="large_coco",
|
||||
@ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod):
|
||||
)
|
||||
return summary_model, summary_vis_processors
|
||||
|
||||
def load_model(model_type):
|
||||
select_model = {
|
||||
"base": SummaryDetector.load_model_base,
|
||||
"large": SummaryDetector.load_model_large,
|
||||
}
|
||||
summary_model, summary_vis_processors = select_model[model_type]()
|
||||
return summary_model, summary_vis_processors
|
||||
|
||||
def set_keys(self) -> dict:
|
||||
params = {
|
||||
"const_image_summary": None,
|
||||
@ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod):
|
||||
}
|
||||
return params
|
||||
|
||||
def analyse_image(self, model_type):
|
||||
def analyse_image(self, summary_model, summary_vis_processors):
|
||||
|
||||
select_model = {
|
||||
"base": self.load_model_base,
|
||||
"large": self.load_model_large,
|
||||
}
|
||||
summary_model, summary_vis_processors = select_model[model_type]()
|
||||
path = self.subdict["filename"]
|
||||
raw_image = Image.open(path).convert("RGB")
|
||||
image = (
|
||||
@ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod):
|
||||
name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
|
||||
)
|
||||
|
||||
def analyse_questions(self, model_type, list_of_questions):
|
||||
def analyse_questions(self, list_of_questions):
|
||||
|
||||
if len(list_of_questions) > 0:
|
||||
path = self.subdict["filename"]
|
||||
raw_image = Image.open(path).convert("RGB")
|
||||
image = (
|
||||
summary_VQA_vis_processors["eval"](raw_image)
|
||||
self.summary_VQA_vis_processors["eval"](raw_image)
|
||||
.unsqueeze(0)
|
||||
.to(self.summary_device)
|
||||
)
|
||||
|
||||
17
notebooks/image_summary.ipynb
сгенерированный
17
notebooks/image_summary.ipynb
сгенерированный
@ -40,7 +40,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"images = mutils.find_files(\n",
|
||||
" path=\"../data/Image_some_text/\",\n",
|
||||
" path=\"../misinformation/test/data/\",\n",
|
||||
" limit=1000,\n",
|
||||
")"
|
||||
]
|
||||
@ -70,6 +70,15 @@
|
||||
"## Create captions for images and directly write to csv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -77,7 +86,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for key in mydict:\n",
|
||||
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()"
|
||||
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n",
|
||||
" summary_model, summary_vis_processors\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -260,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.5"
|
||||
"version": "3.10.8"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user