commented new heavy tests in test_multimodal_search.py

Этот коммит содержится в:
Petr Andriushchenko 2023-03-30 14:19:07 +02:00
родитель f7081d6878
Коммит ecc0d814bb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6

Просмотреть файл

@ -447,109 +447,109 @@ def test_parsing_images(
cuda.empty_cache() cuda.empty_cache()
def test_itm(): # def test_itm():
test_my_dict = { # test_my_dict = {
"IMG_2746": { # "IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png", # "filename": "../misinformation/test/data/IMG_2746.png",
"rank A bus": 1, # "rank A bus": 1,
"A bus": 0.15640679001808167, # "A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1, # "rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979, # "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
}, # },
"IMG_2809": { # "IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png", # "filename": "../misinformation/test/data/IMG_2809.png",
"rank A bus": 0, # "rank A bus": 0,
"A bus": 0.1970970332622528, # "A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0, # "rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851, # "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
}, # },
} # }
search_query3 = [ # search_query3 = [
{"text_input": "A bus"}, # {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, # {"image": "../misinformation/test/data/IMG_3758.png"},
] # ]
image_keys = ["IMG_2746", "IMG_2809"] # image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]] # sorted_list = [[1, 0], [1, 0]]
for itm_model in ["blip_base", "blip_large"]: # for itm_model in ["blip_base", "blip_large"]:
( # (
itm_scores, # itm_scores,
image_gradcam_with_itm, # image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering( # ) = ms.MultimodalSearch.image_text_match_reordering(
test_my_dict, # test_my_dict,
search_query3, # search_query3,
itm_model, # itm_model,
image_keys, # image_keys,
sorted_list, # sorted_list,
batch_size=1, # batch_size=1,
need_grad_cam=True, # need_grad_cam=True,
) # )
for i, itm in zip( # for i, itm in zip(
range(len(dict_itm_scores_for_blib[itm_model])), # range(len(dict_itm_scores_for_blib[itm_model])),
dict_itm_scores_for_blib[itm_model], # dict_itm_scores_for_blib[itm_model],
): # ):
assert ( # assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) # math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True # is True
) # )
for i, grad_cam in zip( # for i, grad_cam in zip(
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])), # range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
dict_image_gradcam_with_itm_for_blip[itm_model], # dict_image_gradcam_with_itm_for_blip[itm_model],
): # ):
assert ( # assert (
math.isclose( # math.isclose(
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i], # image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
grad_cam, # grad_cam,
rel_tol=10 * related_error, # rel_tol=10 * related_error,
) # )
is True # is True
) # )
del itm_scores, image_gradcam_with_itm # del itm_scores, image_gradcam_with_itm
cuda.empty_cache() # cuda.empty_cache()
def test_itm_blip2_coco(): # def test_itm_blip2_coco():
test_my_dict = { # test_my_dict = {
"IMG_2746": { # "IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png", # "filename": "../misinformation/test/data/IMG_2746.png",
"rank A bus": 1, # "rank A bus": 1,
"A bus": 0.15640679001808167, # "A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1, # "rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979, # "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
}, # },
"IMG_2809": { # "IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png", # "filename": "../misinformation/test/data/IMG_2809.png",
"rank A bus": 0, # "rank A bus": 0,
"A bus": 0.1970970332622528, # "A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0, # "rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851, # "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
}, # },
} # }
search_query3 = [ # search_query3 = [
{"text_input": "A bus"}, # {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, # {"image": "../misinformation/test/data/IMG_3758.png"},
] # ]
image_keys = ["IMG_2746", "IMG_2809"] # image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]] # sorted_list = [[1, 0], [1, 0]]
( # (
itm_scores, # itm_scores,
image_gradcam_with_itm, # image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering( # ) = ms.MultimodalSearch.image_text_match_reordering(
test_my_dict, # test_my_dict,
search_query3, # search_query3,
"blip2_coco", # "blip2_coco",
image_keys, # image_keys,
sorted_list, # sorted_list,
batch_size=1, # batch_size=1,
need_grad_cam=False, # need_grad_cam=False,
) # )
for i, itm in zip( # for i, itm in zip(
range(len(dict_itm_scores_for_blib["blip2_coco"])), # range(len(dict_itm_scores_for_blib["blip2_coco"])),
dict_itm_scores_for_blib["blip2_coco"], # dict_itm_scores_for_blib["blip2_coco"],
): # ):
assert ( # assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) # math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True # is True
) # )
del itm_scores, image_gradcam_with_itm # del itm_scores, image_gradcam_with_itm
cuda.empty_cache() # cuda.empty_cache()