зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
after merge
Этот коммит содержится в:
родитель
fc04ee12d8
Коммит
e34adf0530
@ -374,10 +374,9 @@ class MultimodalSearch(AnalysisMethod):
|
||||
tokenized_text,
|
||||
block_num=6,
|
||||
):
|
||||
if itm_model_type != "blip2_coco":
|
||||
model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.save_attention = True
|
||||
model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.save_attention = True
|
||||
|
||||
output = model(
|
||||
{"image": visual_input, "text_input": text_input}, match_head="itm"
|
||||
@ -493,15 +492,23 @@ class MultimodalSearch(AnalysisMethod):
|
||||
batch_size=1,
|
||||
need_grad_cam=False,
|
||||
):
|
||||
if itm_model_type == "blip2_coco" and need_grad_cam is True:
|
||||
raise SyntaxError(
|
||||
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
|
||||
)
|
||||
|
||||
choose_model = {
|
||||
"blip_base": MultimodalSearch.upload_model_blip_base,
|
||||
"blip_large": MultimodalSearch.upload_model_blip_large,
|
||||
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
|
||||
}
|
||||
itm_model, vis_processor = choose_model[itm_model_type](self)
|
||||
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
|
||||
text_processor = load_processor("blip_caption")
|
||||
tokenizer = BlipBase.init_tokenizer()
|
||||
|
||||
if itm_model_type == "blip2_coco":
|
||||
need_grad_cam = False
|
||||
|
||||
text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
|
||||
|
||||
avg_gradcams = []
|
||||
@ -524,7 +531,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
|
||||
current_len = len(filenames_in_batch)
|
||||
raw_images, images = MultimodalSearch.read_and_process_images_itm(
|
||||
self, filenames_in_batch, vis_processor
|
||||
self, filenames_in_batch, vis_processor_itm
|
||||
)
|
||||
queries_batch = [text_processor(query["text_input"])] * current_len
|
||||
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
|
||||
|
||||
@ -370,20 +370,55 @@ sorted_clip_vitl14_336 = [
|
||||
[6, 5, 4, 10, 2, 3, 0, 1, 11, 7, 9, 8],
|
||||
]
|
||||
|
||||
itm_scores_for_blib_base = [
|
||||
0.09277121722698212,
|
||||
0.020782141014933586,
|
||||
0.020832309499382973,
|
||||
0.004225197248160839,
|
||||
0.00020702541223727167,
|
||||
0.00410003075376153,
|
||||
0.0009893759852275252,
|
||||
0.00015318581426981837,
|
||||
1.9936736862291582e-05,
|
||||
4.0083985368255526e-05,
|
||||
0.0006117734010331333,
|
||||
4.1486648115096614e-05,
|
||||
]
|
||||
dict_itm_scores_for_blib = {
|
||||
"blip_base": [
|
||||
0.07107225805521011,
|
||||
0.02078203856945038,
|
||||
0.02083236537873745,
|
||||
0.0042252070270478725,
|
||||
0.0002070252230623737,
|
||||
0.004100032616406679,
|
||||
0.0009893750539049506,
|
||||
0.00015318625082727522,
|
||||
1.9936736862291582e-05,
|
||||
4.0084025386022404e-05,
|
||||
0.0006117739249020815,
|
||||
4.1486648115096614e-05,
|
||||
],
|
||||
"blip_large": [
|
||||
0.07890705019235611,
|
||||
0.04954551160335541,
|
||||
0.05564938113093376,
|
||||
0.002710158471018076,
|
||||
0.0026644798927009106,
|
||||
0.01277624536305666,
|
||||
0.003585426602512598,
|
||||
0.0019450040999799967,
|
||||
0.0036240608897060156,
|
||||
0.0013280785642564297,
|
||||
0.015366943553090096,
|
||||
0.0030039174016565084,
|
||||
],
|
||||
"blip2_coco": [
|
||||
0.0833505243062973,
|
||||
0.046232130378484726,
|
||||
0.04996354877948761,
|
||||
0.004187352955341339,
|
||||
2.5233526685042307e-05,
|
||||
0.002679687924683094,
|
||||
2.4826533262967132e-05,
|
||||
5.1878203521482646e-05,
|
||||
1.3434584616334178e-05,
|
||||
9.76747560343938e-06,
|
||||
7.34204331820365e-06,
|
||||
1.1423194337112363e-05,
|
||||
],
|
||||
}
|
||||
|
||||
dict_image_gradcam_with_itm_for_blip = {
|
||||
"blip_base": [125.12124404, 132.07243145, 65.43589668],
|
||||
"blip_large": [118.75610679, 125.35366997, 69.63849807],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -630,7 +665,7 @@ def test_parsing_images(
|
||||
cuda.empty_cache()
|
||||
|
||||
if pre_model == "blip":
|
||||
for itm_model in ["blip_base", "blip_large", "blip2_coco"]:
|
||||
for itm_model in ["blip_base","blip_large","blip2_coco"]:
|
||||
(
|
||||
itm_scores,
|
||||
image_gradcam_with_itm,
|
||||
@ -641,14 +676,13 @@ def test_parsing_images(
|
||||
image_keys,
|
||||
sorted_list,
|
||||
batch_size=1,
|
||||
need_grad_cam=True,
|
||||
need_grad_cam=False,
|
||||
)
|
||||
for i, itm in zip(
|
||||
range(len(itm_scores_for_blib_base)), itm_scores_for_blib_base
|
||||
range(len(dict_itm_scores_for_blib[itm_model])),
|
||||
dict_itm_scores_for_blib[itm_model],
|
||||
):
|
||||
assert (
|
||||
math.isclose(
|
||||
itm_scores[0].tolist()[i], itm, rel_tol=100 * related_error
|
||||
)
|
||||
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=related_error)
|
||||
is True
|
||||
)
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user