fixed varible name, uncomment test, exluded it from CI, fixed error in multimodal_search

Этот коммит содержится в:
Petr Andriushchenko 2023-03-31 13:35:04 +02:00
родитель 0ae872e750
Коммит f1aeeabd18
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6
4 изменённых файлов: 141 добавлений и 131 удалений

2
.github/workflows/ci.yml поставляемый
Просмотреть файл

@ -32,7 +32,7 @@ jobs:
- name: Run pytest
run: |
cd misinformation
python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
python -m pytest -m "not gcv mem_cons" -svv --cov=. --cov-report=xml
- name: Upload coverage
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
uses: codecov/codecov-action@v3

Просмотреть файл

@ -13,8 +13,6 @@ from IPython.display import display
from lavis.models import load_model_and_preprocess, load_model, BlipBase
from lavis.processors import load_processor
# from memory_profiler import profile
class MultimodalSearch(AnalysisMethod):
def __init__(self, subdict: dict) -> None:
@ -382,7 +380,6 @@ class MultimodalSearch(AnalysisMethod):
def compute_gradcam_batch(
self,
itm_model_type,
model,
visual_input,
text_input,
@ -456,12 +453,12 @@ class MultimodalSearch(AnalysisMethod):
att_map -= att_map.min()
att_map /= att_map.max()
cmap = plt.get_cmap("jet")
att_mapV = cmap(att_map)
att_mapV = np.delete(att_mapV, 3, 2)
att_mapv = cmap(att_map)
att_mapv = np.delete(att_mapv, 3, 2)
if overlap:
att_map = (
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapV
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
)
return att_map
@ -498,7 +495,6 @@ class MultimodalSearch(AnalysisMethod):
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
# @profile
def image_text_match_reordering(
self,
search_query,
@ -518,6 +514,7 @@ class MultimodalSearch(AnalysisMethod):
"blip_large": MultimodalSearch.upload_model_blip_large,
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
}
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
text_processor = load_processor("blip_caption")
tokenizer = BlipBase.init_tokenizer()
@ -557,7 +554,6 @@ class MultimodalSearch(AnalysisMethod):
if need_grad_cam:
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
self,
itm_model_type,
itm_model,
images,
queries_batch,
@ -618,20 +614,20 @@ class MultimodalSearch(AnalysisMethod):
image_gradcam_with_itm[
list(search_query[index_text_query].values())[0]
] = localimage_gradcam_with_itm
del (
itm_model,
vis_processor_itm,
text_processor,
raw_images,
images,
tokenizer,
queries_batch,
queries_tok_batch,
itm_score,
)
if need_grad_cam:
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
torch.cuda.empty_cache()
del (
itm_model,
vis_processor_itm,
text_processor,
raw_images,
images,
tokenizer,
queries_batch,
queries_tok_batch,
itm_score,
)
if need_grad_cam:
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
torch.cuda.empty_cache()
return itm_scores2, image_gradcam_with_itm
def show_results(self, query, itm=False, image_gradcam_with_itm=False):

Просмотреть файл

@ -447,109 +447,111 @@ def test_parsing_images(
cuda.empty_cache()
# def test_itm():
# test_my_dict = {
# "IMG_2746": {
# "filename": "../misinformation/test/data/IMG_2746.png",
# "rank A bus": 1,
# "A bus": 0.15640679001808167,
# "rank ../misinformation/test/data/IMG_3758.png": 1,
# "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
# },
# "IMG_2809": {
# "filename": "../misinformation/test/data/IMG_2809.png",
# "rank A bus": 0,
# "A bus": 0.1970970332622528,
# "rank ../misinformation/test/data/IMG_3758.png": 0,
# "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
# },
# }
# search_query3 = [
# {"text_input": "A bus"},
# {"image": "../misinformation/test/data/IMG_3758.png"},
# ]
# image_keys = ["IMG_2746", "IMG_2809"]
# sorted_list = [[1, 0], [1, 0]]
# for itm_model in ["blip_base", "blip_large"]:
# (
# itm_scores,
# image_gradcam_with_itm,
# ) = ms.MultimodalSearch.image_text_match_reordering(
# test_my_dict,
# search_query3,
# itm_model,
# image_keys,
# sorted_list,
# batch_size=1,
# need_grad_cam=True,
# )
# for i, itm in zip(
# range(len(dict_itm_scores_for_blib[itm_model])),
# dict_itm_scores_for_blib[itm_model],
# ):
# assert (
# math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
# is True
# )
# for i, grad_cam in zip(
# range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
# dict_image_gradcam_with_itm_for_blip[itm_model],
# ):
# assert (
# math.isclose(
# image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
# grad_cam,
# rel_tol=10 * related_error,
# )
# is True
# )
# del itm_scores, image_gradcam_with_itm
# cuda.empty_cache()
@pytest.mark.long
def test_itm():
test_my_dict = {
"IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png",
"rank A bus": 1,
"A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
},
"IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png",
"rank A bus": 0,
"A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
},
}
search_query3 = [
{"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
for itm_model in ["blip_base", "blip_large"]:
(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
test_my_dict,
search_query3,
itm_model,
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=True,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blib[itm_model])),
dict_itm_scores_for_blib[itm_model],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
for i, grad_cam in zip(
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
dict_image_gradcam_with_itm_for_blip[itm_model],
):
assert (
math.isclose(
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
grad_cam,
rel_tol=10 * related_error,
)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()
# def test_itm_blip2_coco():
# test_my_dict = {
# "IMG_2746": {
# "filename": "../misinformation/test/data/IMG_2746.png",
# "rank A bus": 1,
# "A bus": 0.15640679001808167,
# "rank ../misinformation/test/data/IMG_3758.png": 1,
# "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
# },
# "IMG_2809": {
# "filename": "../misinformation/test/data/IMG_2809.png",
# "rank A bus": 0,
# "A bus": 0.1970970332622528,
# "rank ../misinformation/test/data/IMG_3758.png": 0,
# "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
# },
# }
# search_query3 = [
# {"text_input": "A bus"},
# {"image": "../misinformation/test/data/IMG_3758.png"},
# ]
# image_keys = ["IMG_2746", "IMG_2809"]
# sorted_list = [[1, 0], [1, 0]]
@pytest.mark.long
def test_itm_blip2_coco():
test_my_dict = {
"IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png",
"rank A bus": 1,
"A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
},
"IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png",
"rank A bus": 0,
"A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
},
}
search_query3 = [
{"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
# (
# itm_scores,
# image_gradcam_with_itm,
# ) = ms.MultimodalSearch.image_text_match_reordering(
# test_my_dict,
# search_query3,
# "blip2_coco",
# image_keys,
# sorted_list,
# batch_size=1,
# need_grad_cam=False,
# )
# for i, itm in zip(
# range(len(dict_itm_scores_for_blib["blip2_coco"])),
# dict_itm_scores_for_blib["blip2_coco"],
# ):
# assert (
# math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
# is True
# )
# del itm_scores, image_gradcam_with_itm
# cuda.empty_cache()
(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
test_my_dict,
search_query3,
"blip2_coco",
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=False,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blib["blip2_coco"])),
dict_itm_scores_for_blib["blip2_coco"],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()

22
notebooks/multimodal_search.ipynb сгенерированный
Просмотреть файл

@ -47,7 +47,7 @@
"outputs": [],
"source": [
"images = misinformation.utils.find_files(\n",
" path=\"../data/Image_some_text/\",\n",
" path=\"../data/images/\",\n",
" limit=10,\n",
")"
]
@ -64,6 +64,18 @@
"mydict = misinformation.utils.initialize_dict(images)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c66aec87-ede7-4985-912e-3ca29245ebf2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mydict"
]
},
{
"cell_type": "markdown",
"id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
@ -143,7 +155,7 @@
"# ) = ms.MultimodalSearch.parsing_images(\n",
"# mydict,\n",
"# model_type,\n",
"# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
"# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
"# )"
]
},
@ -251,7 +263,7 @@
"source": [
"ms.MultimodalSearch.show_results(\n",
" mydict,\n",
" search_query3[2],\n",
" search_query3[0],\n",
")"
]
},
@ -405,7 +417,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -419,7 +431,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.0"
}
},
"nbformat": 4,