Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00
added test, fixed dependencies
This commit is contained in:
parent b0cfab05e9
commit 18ecf4888b
@@ -28,7 +28,7 @@ def test_read_img():
     assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
 
 
-@pytest.mark.skipif(not cuda.is_available(), reason="model for gpu only")
+@pytest.mark.skipif(gpu_is_not_available, reason="model for gpu only")
 def test_load_feature_extractor_model_blip2():
     my_dict = {}
     multimodal_device = device("cuda" if cuda.is_available() else "cpu")
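The decorator above now references gpu_is_not_available, whose definition is not part of this diff. A minimal sketch of how such a module-level flag is typically set up (an assumption, not shown in the commit):

# hypothetical module-level helper; the real definition lives outside this hunk
gpu_is_not_available = not cuda.is_available()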
@@ -46,14 +46,12 @@ def test_load_feature_extractor_model_blip2():
     processed_pic = vis_processor["eval"](test_pic).unsqueeze(0).to(multimodal_device)
     processed_text = txt_processor["eval"](test_querry)
 
-    with no_grad():
-        with cuda.amp.autocast(enabled=(device != device("cpu"))):
-            extracted_feature_img = model.extract_features(
-                {"image": processed_pic, "text_input": ""}, mode="image"
-            )
-            extracted_feature_text = model.extract_features(
-                {"image": "", "text_input": processed_text}, mode="text"
-            )
+    extracted_feature_img = model.extract_features(
+        {"image": processed_pic, "text_input": ""}, mode="image"
+    )
+    extracted_feature_text = model.extract_features(
+        {"image": "", "text_input": processed_text}, mode="text"
+    )
     check_list_processed_pic = [
         -1.0039474964141846,
         -1.0039474964141846,
@@ -122,10 +120,13 @@ def test_load_feature_extractor_model_blip2():
     )
 
     image_paths = [TEST_IMAGE_2, TEST_IMAGE_3]
 
     raw_images, images_tensors = ms.MultimodalSearch.read_and_process_images(
         my_dict, image_paths, vis_processor
     )
+
+    assert list(numpy.array(raw_images[0])[257][34]) == [70, 66, 63]
+
     check_list_images_tensors = [
         -1.0039474964141846,
         -1.0039474964141846,
@@ -657,3 +658,138 @@ def test_load_feature_extractor_model_clip_vitl14_336(multimodal_device):
 
     del model, vis_processor, txt_processor
     cuda.empty_cache()
+
+
+model_type = "blip"
+# model_type = "blip2"
+# model_type = "albef"
+# model_type = "clip_base"
+# model_type = "clip_vitl14"
+# model_type = "clip_vitl14_336"
+
+
+@pytest.mark.parametrize(
+    (
+        "pre_multimodal_device",
+        "pre_model",
+        "pre_proc_pic",
+        "pre_proc_text",
+        "pre_extracted_feature_img",
+        "pre_extracted_feature_text",
+    ),
+    [
+        pytest.param(
+            device("cuda"),
+            "blip2",
+            [
+                -1.0039474964141846,
+                -1.0039474964141846,
+                -0.8433647751808167,
+                -0.6097899675369263,
+                -0.5951915383338928,
+                -0.6243883967399597,
+                -0.6827820539474487,
+                -0.6097899675369263,
+                -0.7119789123535156,
+                -1.0623412132263184,
+            ],
+            "the bird sat on a tree located at the intersection of 23rd and 43rd streets",
+            [
+                0.04566730558872223,
+                -0.042554520070552826,
+                -0.06970272958278656,
+                -0.009771779179573059,
+                0.01446065679192543,
+                0.10173682868480682,
+                0.007092420011758804,
+                -0.020045937970280647,
+                0.12923966348171234,
+                0.006452132016420364,
+            ],
+            [
+                -0.1384204626083374,
+                -0.008662976324558258,
+                0.006269007455557585,
+                0.03151319921016693,
+                0.060558050870895386,
+                -0.03230040520429611,
+                0.015861615538597107,
+                -0.11856459826231003,
+                -0.058296192437410355,
+                0.03699290752410889,
+            ],
+            marks=pytest.mark.skipif(
+                gpu_is_not_available, reason="gpu_is_not_availible"
+            ),
+        ),
+        # (device("cpu"),"blip"),
+        # (device("cpu"),"albef"),
+        # (device("cpu"),"clip_base"),
+        # (device("cpu"),"clip_vitl14"),
+        # (device("cpu"),"clip_vitl14_336"),
+        # pytest.param( device("cuda"),"blip", marks=pytest.mark.skipif(gpu_is_not_available, reason="gpu_is_not_availible"),),
+        # pytest.param( device("cuda"),"albef", marks=pytest.mark.skipif(gpu_is_not_available, reason="gpu_is_not_availible"),),
+        # pytest.param( device("cuda"),"clip_base", marks=pytest.mark.skipif(gpu_is_not_available, reason="gpu_is_not_availible"),),
+        # pytest.param( device("cuda"),"clip_vitl14", marks=pytest.mark.skipif(gpu_is_not_available, reason="gpu_is_not_availible"),),
+        # pytest.param( device("cuda"),"clip_vitl14_336", marks=pytest.mark.skipif(gpu_is_not_available, reason="gpu_is_not_availible"),),
+    ],
+)
+def test_parsing_images(
+    pre_multimodal_device,
+    pre_model,
+    pre_proc_pic,
+    pre_proc_text,
+    pre_extracted_feature_img,
+    pre_extracted_feature_text,
+):
+    mydict = {
+        "IMG_2746": {"filename": "./test/data/IMG_2746.png"},
+        "IMG_2750": {"filename": "./test/data/IMG_2750.png"},
+    }
+    ms.MultimodalSearch.multimodal_device = pre_multimodal_device
+    (
+        model,
+        vis_processor,
+        txt_processor,
+        image_keys,
+        image_names,
+        features_image_stacked,
+    ) = ms.MultimodalSearch.parsing_images(mydict, pre_model)
+
+    for i, num in zip(range(10), features_image_stacked[0, 10:20].tolist()):
+        assert (
+            math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
+            is True
+        )
+
+    test_pic = Image.open(TEST_IMAGE_2).convert("RGB")
+    test_querry = (
+        "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
+    )
+    processed_pic = (
+        vis_processor["eval"](test_pic).unsqueeze(0).to(pre_multimodal_device)
+    )
+    processed_text = txt_processor["eval"](test_querry)
+
+    for i, num in zip(range(10), processed_pic[0, 0, 0, 25:35].tolist()):
+        assert math.isclose(num, pre_proc_pic[i], rel_tol=related_error) is True
+
+    assert processed_text == pre_proc_text
+
+    search_query = [
+        {
+            "text_input": "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
+        }
+    ]
+    multi_features_stacked = ms.MultimodalSearch.querys_processing(
+        mydict, search_query, model, txt_processor, vis_processor, pre_model
+    )
+
+    for i, num in zip(range(10), multi_features_stacked[0, 10:20].tolist()):
+        assert (
+            math.isclose(num, pre_extracted_feature_text[i], rel_tol=related_error)
+            is True
+        )
+
+    del model, vis_processor, txt_processor
+    cuda.empty_cache()
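To exercise only the new test locally one could run pytest with a -k filter; this is a sketch, not part of the commit, and the GPU parameter set is skipped automatically wherever gpu_is_not_available is true:

import pytest

# select only the new parametrized test by name; -v prints each parameter set
pytest.main(["-v", "-k", "test_parsing_images"])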
@@ -47,7 +47,7 @@ dependencies = [
     "spacytextblob",
     "textblob",
     "torch",
-    "salesforce-lavis @ git+https://github.com/salesforce/LAVIS.git@main",
+    "salesforce-lavis",
     "bertopic",
     "grpcio",
 ]
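With the git pin removed, salesforce-lavis is now expected to resolve from PyPI rather than the LAVIS main branch. A minimal sanity check after installation (just a sketch; it assumes the package is installed in the current environment):

from importlib.metadata import version

# prints whatever salesforce-lavis release was resolved from PyPI
print(version("salesforce-lavis"))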