diff --git a/docs/source/notebooks/Example multimodal.ipynb b/docs/source/notebooks/Example multimodal.ipynb
index c091b84..2c943f5 100644
--- a/docs/source/notebooks/Example multimodal.ipynb
+++ b/docs/source/notebooks/Example multimodal.ipynb
@@ -113,7 +113,7 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
    ]
   },
   {
@@ -128,7 +128,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
@@ -319,7 +321,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 531f977..3212d7c 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -238,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
         }
 
         for query in search_query:
-            if not (len(query) == 1) and (query in ("image", "text_input")):
+            if len(query) != 1 or not ("image" in query or "text_input" in query):
                 raise SyntaxError(
                     'Each query must contain either an "image" or a "text_input"'
                 )
@@ -444,26 +444,26 @@ class MultimodalSearch(AnalysisMethod):
         )
         return resized_image
 
-    def getAttMap(self, img, attMap, blur=True, overlap=True):
-        attMap -= attMap.min()
-        if attMap.max() > 0:
-            attMap /= attMap.max()
-        attMap = skimage_transform.resize(
-            attMap, (img.shape[:2]), order=3, mode="constant"
+    def get_att_map(self, img, att_map, blur=True, overlap=True):
+        att_map -= att_map.min()
+        if att_map.max() > 0:
+            att_map /= att_map.max()
+        att_map = skimage_transform.resize(
+            att_map, (img.shape[:2]), order=3, mode="constant"
         )
         if blur:
-            attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
-            attMap -= attMap.min()
-            attMap /= attMap.max()
+            att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
+            att_map -= att_map.min()
+            att_map /= att_map.max()
         cmap = plt.get_cmap("jet")
-        attMapV = cmap(attMap)
-        attMapV = np.delete(attMapV, 3, 2)
+        att_map_v = cmap(att_map)
+        att_map_v = np.delete(att_map_v, 3, 2)
         if overlap:
-            attMap = (
-                1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
-                + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+            att_map = (
+                1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+                + (att_map**0.7).reshape(att_map.shape + (1,)) * att_map_v
             )
-        return attMap
+        return att_map
 
 
     def upload_model_blip2_coco(self):
@@ -566,7 +566,7 @@ class MultimodalSearch(AnalysisMethod):
         norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
 
         for norm_img, grad_cam in zip(norm_imgs, gradcam):
-            avg_gradcam = MultimodalSearch.getAttMap(
+            avg_gradcam = MultimodalSearch.get_att_map(
                 self, norm_img, np.float32(grad_cam[0]), blur=True
             )
             local_avg_gradcams.append(avg_gradcam)
diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb
index 407dae7..c49096c 100644
--- a/notebooks/multimodal_search.ipynb
+++ b/notebooks/multimodal_search.ipynb
@@ -47,8 +47,8 @@
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
-    "    path=\"../data/images/\",\n",
-    "    limit=1000,\n",
+    "    path=\"../data/Image_some_text/\",\n",
+    "    limit=10,\n",
     ")"
    ]
   },
@@ -64,18 +64,6 @@
     "mydict = misinformation.utils.initialize_dict(images)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
-   "source": [
-    "mydict"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
@@ -125,9 +113,7 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(\n",
-    "    mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
-    ")"
+    ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
    ]
   },
   {
@@ -181,15 +167,14 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "search_query3 = [\n",
     "    {\"text_input\": \"politician press conference\"},\n",
     "    {\"text_input\": \"a world map\"},\n",
-    "    {\"image\": \"../data/haos.png\"},\n",
-    "    {\"image\": \"../data/image-34098-800.png\"},\n",
-    "    {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
     "    {\"text_input\": \"a dog\"},\n",
     "]"
    ]
@@ -209,7 +194,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
@@ -237,10 +224,12 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
-    "mydict[\"100127S_ara\"]"
+    "mydict[\"109237S_spa\"]"
    ]
   },
   {
@@ -255,12 +244,14 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "ms.MultimodalSearch.show_results(\n",
     "    mydict,\n",
-    "    search_query3[5],\n",
+    "    search_query3[2],\n",
     ")"
    ]
   },
@@ -276,7 +267,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "itm_model = \"blip_base\"\n",
@@ -288,7 +281,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
@@ -314,7 +309,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "ms.MultimodalSearch.show_results(\n",
@@ -369,7 +366,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "df.head(10)"
@@ -387,16 +386,26 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "df.to_csv(\"./data_out.csv\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -410,7 +419,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.0"
+   "version": "3.9.16"
   }
  },
  "nbformat": 4,