change absolute paths to get_path fixture in test_multimodal_search, added get_testdict fixture

Этот коммит содержится в:
Petr Andriushchenko 2023-04-18 17:00:31 +02:00
родитель 35f63f0b74
Коммит acf20a4f21
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6
2 изменённых файлов: 37 добавлений и 31 удалений

Просмотреть файл

@ -16,3 +16,12 @@ def set_environ(request):
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json" mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
) )
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")) print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
@pytest.fixture
def get_testdict(get_path):
testdict = {
"IMG_2746": {"filename": get_path + "IMG_2746.png"},
"IMG_2809": {"filename": get_path + "IMG_2809.png"},
}
return testdict

Просмотреть файл

@ -5,22 +5,17 @@ import numpy
from torch import device, cuda from torch import device, cuda
import misinformation.multimodal_search as ms import misinformation.multimodal_search as ms
testdict = {
"IMG_2746": {"filename": "./test/data/IMG_2746.png"},
"IMG_2809": {"filename": "./test/data/IMG_2809.png"},
}
related_error = 1e-2 related_error = 1e-2
gpu_is_not_available = not cuda.is_available() gpu_is_not_available = not cuda.is_available()
cuda.empty_cache() cuda.empty_cache()
def test_read_img(): def test_read_img(get_testdict):
my_dict = {} my_dict = {}
test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"]) test_img = ms.MultimodalSearch.read_img(
my_dict, get_testdict["IMG_2746"]["filename"]
)
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63] assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
@ -354,6 +349,8 @@ def test_parsing_images(
pre_extracted_feature_text, pre_extracted_feature_text,
pre_simularity, pre_simularity,
pre_sorted, pre_sorted,
get_path,
get_testdict,
tmp_path, tmp_path,
): ):
ms.MultimodalSearch.multimodal_device = pre_multimodal_device ms.MultimodalSearch.multimodal_device = pre_multimodal_device
@ -365,7 +362,7 @@ def test_parsing_images(
_, _,
features_image_stacked, features_image_stacked,
) = ms.MultimodalSearch.parsing_images( ) = ms.MultimodalSearch.parsing_images(
testdict, pre_model, path_to_saved_tensors=tmp_path get_testdict, pre_model, path_to_saved_tensors=tmp_path
) )
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()): for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
@ -374,7 +371,7 @@ def test_parsing_images(
is True is True
) )
test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB") test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
test_querry = ( test_querry = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets." "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
) )
@ -390,10 +387,10 @@ def test_parsing_images(
search_query = [ search_query = [
{"text_input": test_querry}, {"text_input": test_querry},
{"image": testdict["IMG_2746"]["filename"]}, {"image": get_testdict["IMG_2746"]["filename"]},
] ]
multi_features_stacked = ms.MultimodalSearch.querys_processing( multi_features_stacked = ms.MultimodalSearch.querys_processing(
testdict, search_query, model, txt_processor, vis_processor, pre_model get_testdict, search_query, model, txt_processor, vis_processor, pre_model
) )
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()): for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
@ -410,11 +407,11 @@ def test_parsing_images(
search_query2 = [ search_query2 = [
{"text_input": "A bus"}, {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, {"image": get_path + "IMG_3758.png"},
] ]
similarity, sorted_list = ms.MultimodalSearch.multimodal_search( similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
testdict, get_testdict,
model, model,
vis_processor, vis_processor,
txt_processor, txt_processor,
@ -448,26 +445,26 @@ def test_parsing_images(
@pytest.mark.long @pytest.mark.long
def test_itm(): def test_itm(get_path):
test_my_dict = { test_my_dict = {
"IMG_2746": { "IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png", "filename": get_path + "IMG_2746.png",
"rank A bus": 1, "rank A bus": 1,
"A bus": 0.15640679001808167, "A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1, "rank " + get_path + "IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979, get_path + "IMG_3758.png": 0.7533495426177979,
}, },
"IMG_2809": { "IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png", "filename": get_path + "IMG_2809.png",
"rank A bus": 0, "rank A bus": 0,
"A bus": 0.1970970332622528, "A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0, "rank " + get_path + "IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851, get_path + "IMG_3758.png": 0.8907483816146851,
}, },
} }
search_query3 = [ search_query3 = [
{"text_input": "A bus"}, {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, {"image": get_path + "IMG_3758.png"},
] ]
image_keys = ["IMG_2746", "IMG_2809"] image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]] sorted_list = [[1, 0], [1, 0]]
@ -509,26 +506,26 @@ def test_itm():
@pytest.mark.long @pytest.mark.long
def test_itm_blip2_coco(): def test_itm_blip2_coco(get_path):
test_my_dict = { test_my_dict = {
"IMG_2746": { "IMG_2746": {
"filename": "../misinformation/test/data/IMG_2746.png", "filename": get_path + "IMG_2746.png",
"rank A bus": 1, "rank A bus": 1,
"A bus": 0.15640679001808167, "A bus": 0.15640679001808167,
"rank ../misinformation/test/data/IMG_3758.png": 1, "rank " + get_path + "IMG_3758.png": 1,
"../misinformation/test/data/IMG_3758.png": 0.7533495426177979, get_path + "IMG_3758.png": 0.7533495426177979,
}, },
"IMG_2809": { "IMG_2809": {
"filename": "../misinformation/test/data/IMG_2809.png", "filename": get_path + "IMG_2809.png",
"rank A bus": 0, "rank A bus": 0,
"A bus": 0.1970970332622528, "A bus": 0.1970970332622528,
"rank ../misinformation/test/data/IMG_3758.png": 0, "rank " + get_path + "IMG_3758.png": 0,
"../misinformation/test/data/IMG_3758.png": 0.8907483816146851, get_path + "IMG_3758.png": 0.8907483816146851,
}, },
} }
search_query3 = [ search_query3 = [
{"text_input": "A bus"}, {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, {"image": get_path + "IMG_3758.png"},
] ]
image_keys = ["IMG_2746", "IMG_2809"] image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]] sorted_list = [[1, 0], [1, 0]]