change absolute paths to get_path fixture in test_multimodal_search, added get_testdict fixture

Этот коммит содержится в:
Petr Andriushchenko 2023-04-18 17:00:31 +02:00
родитель 35f63f0b74
Коммит acf20a4f21
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6
2 изменённых файлов: 37 добавлений и 31 удалений

Просмотреть файл

@ -16,3 +16,12 @@ def set_environ(request):
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
)
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
@pytest.fixture
def get_testdict(get_path):
testdict = {
"IMG_2746": {"filename": get_path + "IMG_2746.png"},
"IMG_2809": {"filename": get_path + "IMG_2809.png"},
}
return testdict

Просмотреть файл

@ -5,22 +5,17 @@ import numpy
from torch import device, cuda
import misinformation.multimodal_search as ms
testdict = {
"IMG_2746": {"filename": "./test/data/IMG_2746.png"},
"IMG_2809": {"filename": "./test/data/IMG_2809.png"},
}
related_error = 1e-2
gpu_is_not_available = not cuda.is_available()
cuda.empty_cache()
def test_read_img():
def test_read_img(get_testdict):
my_dict = {}
test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"])
test_img = ms.MultimodalSearch.read_img(
my_dict, get_testdict["IMG_2746"]["filename"]
)
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
@ -354,6 +349,8 @@ def test_parsing_images(
pre_extracted_feature_text,
pre_simularity,
pre_sorted,
get_path,
get_testdict,
tmp_path,
):
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
@ -365,7 +362,7 @@ def test_parsing_images(
_,
features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
testdict, pre_model, path_to_saved_tensors=tmp_path
get_testdict, pre_model, path_to_saved_tensors=tmp_path
)
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
@ -374,7 +371,7 @@ def test_parsing_images(
is True
)
test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB")
test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
test_querry = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
)
@ -390,10 +387,10 @@ def test_parsing_images(
search_query = [
{"text_input": test_querry},
{"image": testdict["IMG_2746"]["filename"]},
{"image": get_testdict["IMG_2746"]["filename"]},
]
multi_features_stacked = ms.MultimodalSearch.querys_processing(
testdict, search_query, model, txt_processor, vis_processor, pre_model
get_testdict, search_query, model, txt_processor, vis_processor, pre_model
)
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
@ -410,11 +407,11 @@ def test_parsing_images(
search_query2 = [
{"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"},
{"image": get_path + "IMG_3758.png"},
]
similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
testdict,
get_testdict,
model,
vis_processor,
txt_processor,
@ -448,26 +445,26 @@ def test_parsing_images(
@pytest.mark.long
def test_itm():
def test_itm(get_path):
test_my_dict = {
"IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png",
"filename": get_path + "IMG_2746.png",
"rank A bus": 1,
"A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
"rank " + get_path + "IMG_3758.png": 1,
get_path + "IMG_3758.png": 0.7533495426177979,
},
"IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png",
"filename": get_path + "IMG_2809.png",
"rank A bus": 0,
"A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
"rank " + get_path + "IMG_3758.png": 0,
get_path + "IMG_3758.png": 0.8907483816146851,
},
}
search_query3 = [
{"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
@ -509,26 +506,26 @@ def test_itm():
@pytest.mark.long
def test_itm_blip2_coco():
def test_itm_blip2_coco(get_path):
test_my_dict = {
"IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png",
"filename": get_path + "IMG_2746.png",
"rank A bus": 1,
"A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
"rank " + get_path + "IMG_3758.png": 1,
get_path + "IMG_3758.png": 0.7533495426177979,
},
"IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png",
"filename": get_path + "IMG_2809.png",
"rank A bus": 0,
"A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
"rank " + get_path + "IMG_3758.png": 0,
get_path + "IMG_3758.png": 0.8907483816146851,
},
}
search_query3 = [
{"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]