This commit is contained in:
Petr Andriushchenko 2023-03-28 16:35:46 +02:00
parent fc04ee12d8
commit e34adf0530
No known key found for this signature
GPG key ID: 4C4A5DCF634115B6
2 changed files with 67 additions and 26 deletions

View file

@@ -374,10 +374,9 @@ class MultimodalSearch(AnalysisMethod):
         tokenized_text,
         block_num=6,
     ):
-        if itm_model_type != "blip2_coco":
-            model.text_encoder.base_model.base_model.encoder.layer[
-                block_num
-            ].crossattention.self.save_attention = True
+        model.text_encoder.base_model.base_model.encoder.layer[
+            block_num
+        ].crossattention.self.save_attention = True

         output = model(
             {"image": visual_input, "text_input": text_input}, match_head="itm"
@@ -493,15 +492,23 @@ class MultimodalSearch(AnalysisMethod):
         batch_size=1,
         need_grad_cam=False,
     ):
+        if itm_model_type == "blip2_coco" and need_grad_cam is True:
+            raise SyntaxError(
+                "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
+            )
+
         choose_model = {
             "blip_base": MultimodalSearch.upload_model_blip_base,
             "blip_large": MultimodalSearch.upload_model_blip_large,
             "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
         }

-        itm_model, vis_processor = choose_model[itm_model_type](self)
+        itm_model, vis_processor_itm = choose_model[itm_model_type](self)
         text_processor = load_processor("blip_caption")
         tokenizer = BlipBase.init_tokenizer()
+
+        if itm_model_type == "blip2_coco":
+            need_grad_cam = False

         text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
         avg_gradcams = []
@@ -524,7 +531,7 @@ class MultimodalSearch(AnalysisMethod):
             filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
             current_len = len(filenames_in_batch)
             raw_images, images = MultimodalSearch.read_and_process_images_itm(
-                self, filenames_in_batch, vis_processor
+                self, filenames_in_batch, vis_processor_itm
             )
             queries_batch = [text_processor(query["text_input"])] * current_len
             queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
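
The rename from `vis_processor` to `vis_processor_itm` keeps the ITM model's processor distinct from the one used elsewhere in the search pipeline, and `choose_model` dispatches on the model-type string. A self-contained sketch of that dispatch-dict pattern, with stand-in loader stubs rather than the real `upload_model_*` methods:

def load_blip_base():
    # stand-in stub for MultimodalSearch.upload_model_blip_base
    return "blip_base_model", "blip_base_processor"

def load_blip_large():
    return "blip_large_model", "blip_large_processor"

def load_blip2_coco():
    return "blip2_coco_model", "blip2_coco_processor"

choose_model = {
    "blip_base": load_blip_base,
    "blip_large": load_blip_large,
    "blip2_coco": load_blip2_coco,
}

def load_itm(itm_model_type: str):
    # Fail with a clear message on unknown keys instead of a bare KeyError.
    try:
        loader = choose_model[itm_model_type]
    except KeyError:
        raise ValueError(f"Unknown itm_model_type: {itm_model_type!r}")
    return loader()  # -> (itm_model, vis_processor_itm)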

View file

@@ -370,20 +370,55 @@ sorted_clip_vitl14_336 = [
     [6, 5, 4, 10, 2, 3, 0, 1, 11, 7, 9, 8],
 ]

-itm_scores_for_blib_base = [
-    0.09277121722698212,
-    0.020782141014933586,
-    0.020832309499382973,
-    0.004225197248160839,
-    0.00020702541223727167,
-    0.00410003075376153,
-    0.0009893759852275252,
-    0.00015318581426981837,
-    1.9936736862291582e-05,
-    4.0083985368255526e-05,
-    0.0006117734010331333,
-    4.1486648115096614e-05,
-]
+dict_itm_scores_for_blib = {
+    "blip_base": [
+        0.07107225805521011,
+        0.02078203856945038,
+        0.02083236537873745,
+        0.0042252070270478725,
+        0.0002070252230623737,
+        0.004100032616406679,
+        0.0009893750539049506,
+        0.00015318625082727522,
+        1.9936736862291582e-05,
+        4.0084025386022404e-05,
+        0.0006117739249020815,
+        4.1486648115096614e-05,
+    ],
+    "blip_large": [
+        0.07890705019235611,
+        0.04954551160335541,
+        0.05564938113093376,
+        0.002710158471018076,
+        0.0026644798927009106,
+        0.01277624536305666,
+        0.003585426602512598,
+        0.0019450040999799967,
+        0.0036240608897060156,
+        0.0013280785642564297,
+        0.015366943553090096,
+        0.0030039174016565084,
+    ],
+    "blip2_coco": [
+        0.0833505243062973,
+        0.046232130378484726,
+        0.04996354877948761,
+        0.004187352955341339,
+        2.5233526685042307e-05,
+        0.002679687924683094,
+        2.4826533262967132e-05,
+        5.1878203521482646e-05,
+        1.3434584616334178e-05,
+        9.76747560343938e-06,
+        7.34204331820365e-06,
+        1.1423194337112363e-05,
+    ],
+}
+
+dict_image_gradcam_with_itm_for_blip = {
+    "blip_base": [125.12124404, 132.07243145, 65.43589668],
+    "blip_large": [118.75610679, 125.35366997, 69.63849807],
+}


 @pytest.mark.parametrize(
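
Note that `dict_image_gradcam_with_itm_for_blip` above deliberately has no `blip2_coco` entry: gradcam is disabled for that model, so there are no reference gradcam sums for it. A hedged sketch of a lookup that tolerates the missing key (the dict literal repeats the values from the hunk above):

reference_gradcams = {
    "blip_base": [125.12124404, 132.07243145, 65.43589668],
    "blip_large": [118.75610679, 125.35366997, 69.63849807],
}

for model_name in ("blip_base", "blip_large", "blip2_coco"):
    expected = reference_gradcams.get(model_name)
    if expected is None:
        # blip2_coco: gradcam unsupported, nothing to compare against
        continue
    print(model_name, expected)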
@@ -630,7 +665,7 @@ def test_parsing_images(
             cuda.empty_cache()

     if pre_model == "blip":
-        for itm_model in ["blip_base", "blip_large", "blip2_coco"]:
+        for itm_model in ["blip_base","blip_large","blip2_coco"]:
             (
                 itm_scores,
                 image_gradcam_with_itm,
@@ -641,14 +676,13 @@ def test_parsing_images(
                 image_keys,
                 sorted_list,
                 batch_size=1,
-                need_grad_cam=True,
+                need_grad_cam=False,
             )
             for i, itm in zip(
-                range(len(itm_scores_for_blib_base)), itm_scores_for_blib_base
+                range(len(dict_itm_scores_for_blib[itm_model])),
+                dict_itm_scores_for_blib[itm_model],
             ):
                 assert (
-                    math.isclose(
-                        itm_scores[0].tolist()[i], itm, rel_tol=100 * related_error
-                    )
+                    math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=related_error)
                     is True
                 )
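
With `blip2_coco` now included in the loop, the test switches to `need_grad_cam=False`, since the new entry-point check raises for that model otherwise. The assertion is also tightened: the old code widened the relative tolerance a hundredfold (`rel_tol=100 * related_error`), while the new one passes `related_error` directly. A self-contained illustration of what `math.isclose` accepts at each tolerance; the value of `related_error` here is an assumption, the real one is defined elsewhere in the test module.

import math

related_error = 1e-3  # assumption: the actual value lives elsewhere in the test file
want = 0.07107225805521011  # first blip_base reference score from the dict above
got_close = want * (1 + 5e-4)  # within rel_tol=related_error
got_far = want * (1 + 5e-2)  # only within the old, looser 100 * related_error

assert math.isclose(got_close, want, rel_tol=related_error)
assert not math.isclose(got_far, want, rel_tol=related_error)
assert math.isclose(got_far, want, rel_tol=100 * related_error)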