зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
fixed variable name, uncommented test, excluded it from CI, fixed error in multimodal_search
Этот коммит содержится в:
родитель
0ae872e750
Коммит
f1aeeabd18
2
.github/workflows/ci.yml
поставляемый
2
.github/workflows/ci.yml
поставляемый
@ -32,7 +32,7 @@ jobs:
|
||||
- name: Run pytest
|
||||
run: |
|
||||
cd misinformation
|
||||
python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
|
||||
python -m pytest -m "not gcv mem_cons" -svv --cov=. --cov-report=xml
|
||||
- name: Upload coverage
|
||||
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
|
||||
uses: codecov/codecov-action@v3
|
||||
|
||||
@ -13,8 +13,6 @@ from IPython.display import display
|
||||
from lavis.models import load_model_and_preprocess, load_model, BlipBase
|
||||
from lavis.processors import load_processor
|
||||
|
||||
# from memory_profiler import profile
|
||||
|
||||
|
||||
class MultimodalSearch(AnalysisMethod):
|
||||
def __init__(self, subdict: dict) -> None:
|
||||
@ -382,7 +380,6 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
def compute_gradcam_batch(
|
||||
self,
|
||||
itm_model_type,
|
||||
model,
|
||||
visual_input,
|
||||
text_input,
|
||||
@ -456,12 +453,12 @@ class MultimodalSearch(AnalysisMethod):
|
||||
att_map -= att_map.min()
|
||||
att_map /= att_map.max()
|
||||
cmap = plt.get_cmap("jet")
|
||||
att_mapV = cmap(att_map)
|
||||
att_mapV = np.delete(att_mapV, 3, 2)
|
||||
att_mapv = cmap(att_map)
|
||||
att_mapv = np.delete(att_mapv, 3, 2)
|
||||
if overlap:
|
||||
att_map = (
|
||||
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
|
||||
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapV
|
||||
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
|
||||
)
|
||||
return att_map
|
||||
|
||||
@ -498,7 +495,6 @@ class MultimodalSearch(AnalysisMethod):
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||
return itm_model, vis_processor
|
||||
|
||||
# @profile
|
||||
def image_text_match_reordering(
|
||||
self,
|
||||
search_query,
|
||||
@ -518,6 +514,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
"blip_large": MultimodalSearch.upload_model_blip_large,
|
||||
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
|
||||
}
|
||||
|
||||
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
|
||||
text_processor = load_processor("blip_caption")
|
||||
tokenizer = BlipBase.init_tokenizer()
|
||||
@ -557,7 +554,6 @@ class MultimodalSearch(AnalysisMethod):
|
||||
if need_grad_cam:
|
||||
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
|
||||
self,
|
||||
itm_model_type,
|
||||
itm_model,
|
||||
images,
|
||||
queries_batch,
|
||||
@ -618,20 +614,20 @@ class MultimodalSearch(AnalysisMethod):
|
||||
image_gradcam_with_itm[
|
||||
list(search_query[index_text_query].values())[0]
|
||||
] = localimage_gradcam_with_itm
|
||||
del (
|
||||
itm_model,
|
||||
vis_processor_itm,
|
||||
text_processor,
|
||||
raw_images,
|
||||
images,
|
||||
tokenizer,
|
||||
queries_batch,
|
||||
queries_tok_batch,
|
||||
itm_score,
|
||||
)
|
||||
if need_grad_cam:
|
||||
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
|
||||
torch.cuda.empty_cache()
|
||||
del (
|
||||
itm_model,
|
||||
vis_processor_itm,
|
||||
text_processor,
|
||||
raw_images,
|
||||
images,
|
||||
tokenizer,
|
||||
queries_batch,
|
||||
queries_tok_batch,
|
||||
itm_score,
|
||||
)
|
||||
if need_grad_cam:
|
||||
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
|
||||
torch.cuda.empty_cache()
|
||||
return itm_scores2, image_gradcam_with_itm
|
||||
|
||||
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
|
||||
|
||||
@ -447,109 +447,111 @@ def test_parsing_images(
|
||||
cuda.empty_cache()
|
||||
|
||||
|
||||
# def test_itm():
|
||||
# test_my_dict = {
|
||||
# "IMG_2746": {
|
||||
# "filename": "../misinformation/test/data/IMG_2746.png",
|
||||
# "rank A bus": 1,
|
||||
# "A bus": 0.15640679001808167,
|
||||
# "rank ../misinformation/test/data/IMG_3758.png": 1,
|
||||
# "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
|
||||
# },
|
||||
# "IMG_2809": {
|
||||
# "filename": "../misinformation/test/data/IMG_2809.png",
|
||||
# "rank A bus": 0,
|
||||
# "A bus": 0.1970970332622528,
|
||||
# "rank ../misinformation/test/data/IMG_3758.png": 0,
|
||||
# "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
|
||||
# },
|
||||
# }
|
||||
# search_query3 = [
|
||||
# {"text_input": "A bus"},
|
||||
# {"image": "../misinformation/test/data/IMG_3758.png"},
|
||||
# ]
|
||||
# image_keys = ["IMG_2746", "IMG_2809"]
|
||||
# sorted_list = [[1, 0], [1, 0]]
|
||||
# for itm_model in ["blip_base", "blip_large"]:
|
||||
# (
|
||||
# itm_scores,
|
||||
# image_gradcam_with_itm,
|
||||
# ) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
# test_my_dict,
|
||||
# search_query3,
|
||||
# itm_model,
|
||||
# image_keys,
|
||||
# sorted_list,
|
||||
# batch_size=1,
|
||||
# need_grad_cam=True,
|
||||
# )
|
||||
# for i, itm in zip(
|
||||
# range(len(dict_itm_scores_for_blib[itm_model])),
|
||||
# dict_itm_scores_for_blib[itm_model],
|
||||
# ):
|
||||
# assert (
|
||||
# math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
# is True
|
||||
# )
|
||||
# for i, grad_cam in zip(
|
||||
# range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
|
||||
# dict_image_gradcam_with_itm_for_blip[itm_model],
|
||||
# ):
|
||||
# assert (
|
||||
# math.isclose(
|
||||
# image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
|
||||
# grad_cam,
|
||||
# rel_tol=10 * related_error,
|
||||
# )
|
||||
# is True
|
||||
# )
|
||||
# del itm_scores, image_gradcam_with_itm
|
||||
# cuda.empty_cache()
|
||||
@pytest.mark.long
|
||||
def test_itm():
|
||||
test_my_dict = {
|
||||
"IMG_2746": {
|
||||
"filename": "../misinformation/test/data/IMG_2746.png",
|
||||
"rank A bus": 1,
|
||||
"A bus": 0.15640679001808167,
|
||||
"rank ../misinformation/test/data/IMG_3758.png": 1,
|
||||
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
|
||||
},
|
||||
"IMG_2809": {
|
||||
"filename": "../misinformation/test/data/IMG_2809.png",
|
||||
"rank A bus": 0,
|
||||
"A bus": 0.1970970332622528,
|
||||
"rank ../misinformation/test/data/IMG_3758.png": 0,
|
||||
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
|
||||
},
|
||||
}
|
||||
search_query3 = [
|
||||
{"text_input": "A bus"},
|
||||
{"image": "../misinformation/test/data/IMG_3758.png"},
|
||||
]
|
||||
image_keys = ["IMG_2746", "IMG_2809"]
|
||||
sorted_list = [[1, 0], [1, 0]]
|
||||
for itm_model in ["blip_base", "blip_large"]:
|
||||
(
|
||||
itm_scores,
|
||||
image_gradcam_with_itm,
|
||||
) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
test_my_dict,
|
||||
search_query3,
|
||||
itm_model,
|
||||
image_keys,
|
||||
sorted_list,
|
||||
batch_size=1,
|
||||
need_grad_cam=True,
|
||||
)
|
||||
for i, itm in zip(
|
||||
range(len(dict_itm_scores_for_blib[itm_model])),
|
||||
dict_itm_scores_for_blib[itm_model],
|
||||
):
|
||||
assert (
|
||||
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
is True
|
||||
)
|
||||
for i, grad_cam in zip(
|
||||
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
|
||||
dict_image_gradcam_with_itm_for_blip[itm_model],
|
||||
):
|
||||
assert (
|
||||
math.isclose(
|
||||
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
|
||||
grad_cam,
|
||||
rel_tol=10 * related_error,
|
||||
)
|
||||
is True
|
||||
)
|
||||
del itm_scores, image_gradcam_with_itm
|
||||
cuda.empty_cache()
|
||||
|
||||
|
||||
# def test_itm_blip2_coco():
|
||||
# test_my_dict = {
|
||||
# "IMG_2746": {
|
||||
# "filename": "../misinformation/test/data/IMG_2746.png",
|
||||
# "rank A bus": 1,
|
||||
# "A bus": 0.15640679001808167,
|
||||
# "rank ../misinformation/test/data/IMG_3758.png": 1,
|
||||
# "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
|
||||
# },
|
||||
# "IMG_2809": {
|
||||
# "filename": "../misinformation/test/data/IMG_2809.png",
|
||||
# "rank A bus": 0,
|
||||
# "A bus": 0.1970970332622528,
|
||||
# "rank ../misinformation/test/data/IMG_3758.png": 0,
|
||||
# "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
|
||||
# },
|
||||
# }
|
||||
# search_query3 = [
|
||||
# {"text_input": "A bus"},
|
||||
# {"image": "../misinformation/test/data/IMG_3758.png"},
|
||||
# ]
|
||||
# image_keys = ["IMG_2746", "IMG_2809"]
|
||||
# sorted_list = [[1, 0], [1, 0]]
|
||||
@pytest.mark.long
|
||||
def test_itm_blip2_coco():
|
||||
test_my_dict = {
|
||||
"IMG_2746": {
|
||||
"filename": "../misinformation/test/data/IMG_2746.png",
|
||||
"rank A bus": 1,
|
||||
"A bus": 0.15640679001808167,
|
||||
"rank ../misinformation/test/data/IMG_3758.png": 1,
|
||||
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
|
||||
},
|
||||
"IMG_2809": {
|
||||
"filename": "../misinformation/test/data/IMG_2809.png",
|
||||
"rank A bus": 0,
|
||||
"A bus": 0.1970970332622528,
|
||||
"rank ../misinformation/test/data/IMG_3758.png": 0,
|
||||
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
|
||||
},
|
||||
}
|
||||
search_query3 = [
|
||||
{"text_input": "A bus"},
|
||||
{"image": "../misinformation/test/data/IMG_3758.png"},
|
||||
]
|
||||
image_keys = ["IMG_2746", "IMG_2809"]
|
||||
sorted_list = [[1, 0], [1, 0]]
|
||||
|
||||
# (
|
||||
# itm_scores,
|
||||
# image_gradcam_with_itm,
|
||||
# ) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
# test_my_dict,
|
||||
# search_query3,
|
||||
# "blip2_coco",
|
||||
# image_keys,
|
||||
# sorted_list,
|
||||
# batch_size=1,
|
||||
# need_grad_cam=False,
|
||||
# )
|
||||
# for i, itm in zip(
|
||||
# range(len(dict_itm_scores_for_blib["blip2_coco"])),
|
||||
# dict_itm_scores_for_blib["blip2_coco"],
|
||||
# ):
|
||||
# assert (
|
||||
# math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
# is True
|
||||
# )
|
||||
# del itm_scores, image_gradcam_with_itm
|
||||
# cuda.empty_cache()
|
||||
(
|
||||
itm_scores,
|
||||
image_gradcam_with_itm,
|
||||
) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
test_my_dict,
|
||||
search_query3,
|
||||
"blip2_coco",
|
||||
image_keys,
|
||||
sorted_list,
|
||||
batch_size=1,
|
||||
need_grad_cam=False,
|
||||
)
|
||||
for i, itm in zip(
|
||||
range(len(dict_itm_scores_for_blib["blip2_coco"])),
|
||||
dict_itm_scores_for_blib["blip2_coco"],
|
||||
):
|
||||
assert (
|
||||
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
is True
|
||||
)
|
||||
del itm_scores, image_gradcam_with_itm
|
||||
cuda.empty_cache()
|
||||
|
||||
22
notebooks/multimodal_search.ipynb
сгенерированный
22
notebooks/multimodal_search.ipynb
сгенерированный
@ -47,7 +47,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"images = misinformation.utils.find_files(\n",
|
||||
" path=\"../data/Image_some_text/\",\n",
|
||||
" path=\"../data/images/\",\n",
|
||||
" limit=10,\n",
|
||||
")"
|
||||
]
|
||||
@ -64,6 +64,18 @@
|
||||
"mydict = misinformation.utils.initialize_dict(images)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c66aec87-ede7-4985-912e-3ca29245ebf2",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mydict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
|
||||
@ -143,7 +155,7 @@
|
||||
"# ) = ms.MultimodalSearch.parsing_images(\n",
|
||||
"# mydict,\n",
|
||||
"# model_type,\n",
|
||||
"# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
|
||||
"# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
@ -251,7 +263,7 @@
|
||||
"source": [
|
||||
"ms.MultimodalSearch.show_results(\n",
|
||||
" mydict,\n",
|
||||
" search_query3[2],\n",
|
||||
" search_query3[0],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -405,7 +417,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -419,7 +431,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.9.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user