зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 21:46:04 +02:00
fixed new model in summary.py
Этот коммит содержится в:
родитель
28014f563f
Коммит
f4e47c105e
@ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod):
|
|||||||
|
|
||||||
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def load_model_base(self):
|
def load_model_base():
|
||||||
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
||||||
name="blip_caption",
|
name="blip_caption",
|
||||||
model_type="base_coco",
|
model_type="base_coco",
|
||||||
@ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod):
|
|||||||
)
|
)
|
||||||
return summary_model, summary_vis_processors
|
return summary_model, summary_vis_processors
|
||||||
|
|
||||||
def load_model_large(self):
|
def load_model_large():
|
||||||
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
|
||||||
name="blip_caption",
|
name="blip_caption",
|
||||||
model_type="large_coco",
|
model_type="large_coco",
|
||||||
@ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod):
|
|||||||
)
|
)
|
||||||
return summary_model, summary_vis_processors
|
return summary_model, summary_vis_processors
|
||||||
|
|
||||||
|
def load_model(model_type):
|
||||||
|
select_model = {
|
||||||
|
"base": SummaryDetector.load_model_base,
|
||||||
|
"large": SummaryDetector.load_model_large,
|
||||||
|
}
|
||||||
|
summary_model, summary_vis_processors = select_model[model_type]()
|
||||||
|
return summary_model, summary_vis_processors
|
||||||
|
|
||||||
def set_keys(self) -> dict:
|
def set_keys(self) -> dict:
|
||||||
params = {
|
params = {
|
||||||
"const_image_summary": None,
|
"const_image_summary": None,
|
||||||
@ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod):
|
|||||||
}
|
}
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def analyse_image(self, model_type):
|
def analyse_image(self, summary_model, summary_vis_processors):
|
||||||
|
|
||||||
select_model = {
|
|
||||||
"base": self.load_model_base,
|
|
||||||
"large": self.load_model_large,
|
|
||||||
}
|
|
||||||
summary_model, summary_vis_processors = select_model[model_type]()
|
|
||||||
path = self.subdict["filename"]
|
path = self.subdict["filename"]
|
||||||
raw_image = Image.open(path).convert("RGB")
|
raw_image = Image.open(path).convert("RGB")
|
||||||
image = (
|
image = (
|
||||||
@ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod):
|
|||||||
name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
|
name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
|
||||||
)
|
)
|
||||||
|
|
||||||
def analyse_questions(self, model_type, list_of_questions):
|
def analyse_questions(self, list_of_questions):
|
||||||
|
|
||||||
if len(list_of_questions) > 0:
|
if len(list_of_questions) > 0:
|
||||||
path = self.subdict["filename"]
|
path = self.subdict["filename"]
|
||||||
raw_image = Image.open(path).convert("RGB")
|
raw_image = Image.open(path).convert("RGB")
|
||||||
image = (
|
image = (
|
||||||
summary_VQA_vis_processors["eval"](raw_image)
|
self.summary_VQA_vis_processors["eval"](raw_image)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.to(self.summary_device)
|
.to(self.summary_device)
|
||||||
)
|
)
|
||||||
|
|||||||
17
notebooks/image_summary.ipynb
сгенерированный
17
notebooks/image_summary.ipynb
сгенерированный
@ -40,7 +40,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"images = mutils.find_files(\n",
|
"images = mutils.find_files(\n",
|
||||||
" path=\"../data/Image_some_text/\",\n",
|
" path=\"../misinformation/test/data/\",\n",
|
||||||
" limit=1000,\n",
|
" limit=1000,\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
@ -70,6 +70,15 @@
|
|||||||
"## Create captions for images and directly write to csv"
|
"## Create captions for images and directly write to csv"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@ -77,7 +86,9 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"for key in mydict:\n",
|
"for key in mydict:\n",
|
||||||
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()"
|
" mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n",
|
||||||
|
" summary_model, summary_vis_processors\n",
|
||||||
|
" )"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -260,7 +271,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.5"
|
"version": "3.10.8"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user