fixed variable name, uncommented test, excluded it from CI, fixed error in multimodal_search

This commit is contained in:
Petr Andriushchenko 2023-03-31 13:35:04 +02:00
parent 0ae872e750
commit f1aeeabd18
No key found matching this signature
GPG key ID: 4C4A5DCF634115B6
4 changed files with 141 additions and 131 deletions

.github/workflows/ci.yml

@@ -32,7 +32,7 @@ jobs:
       - name: Run pytest
         run: |
           cd misinformation
-          python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
+          python -m pytest -m "not gcv mem_cons" -svv --cov=. --cov-report=xml
       - name: Upload coverage
         if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
         uses: codecov/codecov-action@v3
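
The CI step above selects tests by pytest marker (-m "not gcv mem_cons"), and the test file further down marks the re-enabled tests with @pytest.mark.long. Custom marker names normally have to be registered, otherwise pytest warns about unknown marks. A minimal conftest.py sketch of such a registration; only the marker names come from this commit, the descriptions are assumptions:

# conftest.py -- sketch only; marker names taken from this commit, descriptions assumed
def pytest_configure(config):
    # registering markers silences pytest's unknown-mark warnings and documents them
    config.addinivalue_line("markers", "gcv: tests that need Google Cloud Vision credentials (assumed)")
    config.addinivalue_line("markers", "mem_cons: memory-consuming tests (assumed)")
    config.addinivalue_line("markers", "long: long-running tests, e.g. the ITM tests below (assumed)")

# the long-marked tests can still be run locally with: python -m pytest -m long -svv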


@@ -13,8 +13,6 @@ from IPython.display import display
 from lavis.models import load_model_and_preprocess, load_model, BlipBase
 from lavis.processors import load_processor
-# from memory_profiler import profile
 class MultimodalSearch(AnalysisMethod):
     def __init__(self, subdict: dict) -> None:
@@ -382,7 +380,6 @@ class MultimodalSearch(AnalysisMethod):
     def compute_gradcam_batch(
         self,
-        itm_model_type,
         model,
         visual_input,
         text_input,
@@ -456,12 +453,12 @@ class MultimodalSearch(AnalysisMethod):
         att_map -= att_map.min()
         att_map /= att_map.max()
         cmap = plt.get_cmap("jet")
-        att_mapV = cmap(att_map)
-        att_mapV = np.delete(att_mapV, 3, 2)
+        att_mapv = cmap(att_map)
+        att_mapv = np.delete(att_mapv, 3, 2)
         if overlap:
             att_map = (
                 1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
-                + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapV
+                + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
             )
         return att_map
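
The hunk above renames att_mapV to att_mapv in the Grad-CAM overlay step: the attention map is normalized to [0, 1], pushed through matplotlib's "jet" colormap, stripped of its alpha channel, and blended with the image. A self-contained sketch of that overlay logic; the function name and the gamma parameter are illustrative, not from the source:

import numpy as np
import matplotlib.pyplot as plt

def overlay_attention(att_map, img, gamma=0.7):
    # normalize the attention map to [0, 1]
    att_map = att_map - att_map.min()
    att_map = att_map / att_map.max()
    # map attention values to RGBA via the "jet" colormap, then drop the alpha channel
    att_mapv = plt.get_cmap("jet")(att_map)   # (H, W, 4)
    att_mapv = np.delete(att_mapv, 3, 2)      # (H, W, 3)
    # blend: strong attention shows the heatmap, weak attention keeps the image
    weight = (att_map ** gamma).reshape(att_map.shape + (1,))
    return (1 - weight) * img + weight * att_mapv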
@@ -498,7 +495,6 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=384)
         return itm_model, vis_processor
-    # @profile
     def image_text_match_reordering(
         self,
         search_query,
@@ -518,6 +514,7 @@ class MultimodalSearch(AnalysisMethod):
             "blip_large": MultimodalSearch.upload_model_blip_large,
             "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
         }
         itm_model, vis_processor_itm = choose_model[itm_model_type](self)
         text_processor = load_processor("blip_caption")
         tokenizer = BlipBase.init_tokenizer()
@@ -557,7 +554,6 @@ class MultimodalSearch(AnalysisMethod):
             if need_grad_cam:
                 gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
                     self,
-                    itm_model_type,
                     itm_model,
                     images,
                     queries_batch,
@@ -618,20 +614,20 @@ class MultimodalSearch(AnalysisMethod):
                 image_gradcam_with_itm[
                     list(search_query[index_text_query].values())[0]
                 ] = localimage_gradcam_with_itm
         del (
             itm_model,
             vis_processor_itm,
             text_processor,
             raw_images,
             images,
             tokenizer,
             queries_batch,
             queries_tok_batch,
             itm_score,
         )
         if need_grad_cam:
             del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
         torch.cuda.empty_cache()
         return itm_scores2, image_gradcam_with_itm

     def show_results(self, query, itm=False, image_gradcam_with_itm=False):
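
The reordering method above ends by del-ing its large references and calling torch.cuda.empty_cache(). Dropping the Python references makes the tensors collectable, and empty_cache() then asks PyTorch's caching allocator to hand the now-unused blocks back to the driver. A tiny, self-contained illustration of that pattern; the tensor names and sizes are made up for the example:

import torch

def release_example():
    # stand-ins for the model weights / image features held during reordering (hypothetical sizes)
    weights = torch.zeros(1024, 1024)
    features = torch.zeros(512, 768)
    if torch.cuda.is_available():
        weights, features = weights.cuda(), features.cuda()
    # drop the references so the tensors can be garbage-collected ...
    del weights, features
    # ... then release cached, no-longer-used GPU blocks back to the driver
    if torch.cuda.is_available():
        torch.cuda.empty_cache()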


@@ -447,109 +447,111 @@ def test_parsing_images(
     cuda.empty_cache()

+@pytest.mark.long
+def test_itm():
+    test_my_dict = {
+        "IMG_2746": {
+            "filename": "../misinformation/test/data/IMG_2746.png",
+            "rank A bus": 1,
+            "A bus": 0.15640679001808167,
+            "rank ../misinformation/test/data/IMG_3758.png": 1,
+            "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
+        },
+        "IMG_2809": {
+            "filename": "../misinformation/test/data/IMG_2809.png",
+            "rank A bus": 0,
+            "A bus": 0.1970970332622528,
+            "rank ../misinformation/test/data/IMG_3758.png": 0,
+            "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
+        },
+    }
+    search_query3 = [
+        {"text_input": "A bus"},
+        {"image": "../misinformation/test/data/IMG_3758.png"},
+    ]
+    image_keys = ["IMG_2746", "IMG_2809"]
+    sorted_list = [[1, 0], [1, 0]]
+    for itm_model in ["blip_base", "blip_large"]:
+        (
+            itm_scores,
+            image_gradcam_with_itm,
+        ) = ms.MultimodalSearch.image_text_match_reordering(
+            test_my_dict,
+            search_query3,
+            itm_model,
+            image_keys,
+            sorted_list,
+            batch_size=1,
+            need_grad_cam=True,
+        )
+        for i, itm in zip(
+            range(len(dict_itm_scores_for_blib[itm_model])),
+            dict_itm_scores_for_blib[itm_model],
+        ):
+            assert (
+                math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
+                is True
+            )
+        for i, grad_cam in zip(
+            range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
+            dict_image_gradcam_with_itm_for_blip[itm_model],
+        ):
+            assert (
+                math.isclose(
+                    image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
+                    grad_cam,
+                    rel_tol=10 * related_error,
+                )
+                is True
+            )
+        del itm_scores, image_gradcam_with_itm
+    cuda.empty_cache()

+@pytest.mark.long
+def test_itm_blip2_coco():
+    test_my_dict = {
+        "IMG_2746": {
+            "filename": "../misinformation/test/data/IMG_2746.png",
+            "rank A bus": 1,
+            "A bus": 0.15640679001808167,
+            "rank ../misinformation/test/data/IMG_3758.png": 1,
+            "../misinformation/test/data/IMG_3758.png": 0.7533495426177979,
+        },
+        "IMG_2809": {
+            "filename": "../misinformation/test/data/IMG_2809.png",
+            "rank A bus": 0,
+            "A bus": 0.1970970332622528,
+            "rank ../misinformation/test/data/IMG_3758.png": 0,
+            "../misinformation/test/data/IMG_3758.png": 0.8907483816146851,
+        },
+    }
+    search_query3 = [
+        {"text_input": "A bus"},
+        {"image": "../misinformation/test/data/IMG_3758.png"},
+    ]
+    image_keys = ["IMG_2746", "IMG_2809"]
+    sorted_list = [[1, 0], [1, 0]]
+    (
+        itm_scores,
+        image_gradcam_with_itm,
+    ) = ms.MultimodalSearch.image_text_match_reordering(
+        test_my_dict,
+        search_query3,
+        "blip2_coco",
+        image_keys,
+        sorted_list,
+        batch_size=1,
+        need_grad_cam=False,
+    )
+    for i, itm in zip(
+        range(len(dict_itm_scores_for_blib["blip2_coco"])),
+        dict_itm_scores_for_blib["blip2_coco"],
+    ):
+        assert (
+            math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
+            is True
+        )
+    del itm_scores, image_gradcam_with_itm
+    cuda.empty_cache()
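
The re-enabled tests compare freshly computed scores against stored reference values with math.isclose and a relative tolerance. A compact sketch of that comparison pattern; the helper name and its arguments are illustrative, not from the test module:

import math

def assert_scores_close(actual, expected, rel_tol):
    # element-wise regression check, mirroring the assertions in the tests above
    for a, e in zip(actual, expected):
        assert math.isclose(a, e, rel_tol=rel_tol)

# e.g. assert_scores_close([0.156, 0.197], [0.1564, 0.1971], rel_tol=1e-2)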

notebooks/multimodal_search.ipynb

@@ -47,7 +47,7 @@
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
-    "    path=\"../data/Image_some_text/\",\n",
+    "    path=\"../data/images/\",\n",
     "    limit=10,\n",
     ")"
    ]
@@ -64,6 +64,18 @@
     "mydict = misinformation.utils.initialize_dict(images)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c66aec87-ede7-4985-912e-3ca29245ebf2",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "mydict"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
@@ -143,7 +155,7 @@
     "# ) = ms.MultimodalSearch.parsing_images(\n",
     "#     mydict,\n",
     "#     model_type,\n",
-    "#     path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "#     path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
     "# )"
    ]
   },
@@ -251,7 +263,7 @@
    "source": [
     "ms.MultimodalSearch.show_results(\n",
     "    mydict,\n",
-    "    search_query3[2],\n",
+    "    search_query3[0],\n",
     ")"
    ]
   },
@@ -405,7 +417,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -419,7 +431,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.16"
+   "version": "3.9.0"
   }
  },
  "nbformat": 4,