fix code smells, test notebook

Этот коммит содержится в:
Inga Ulusoy 2023-03-31 11:21:08 +02:00
родитель ecc0d814bb
Коммит 0ae872e750
3 изменённых файлов: 64 добавлений и 53 удалений

Просмотреть файл

@ -113,7 +113,7 @@
" image_keys,\n",
" image_names,\n",
" features_image_stacked,\n",
") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
]
},
{
@ -128,7 +128,9 @@
"cell_type": "code",
"execution_count": null,
"id": "56c6d488-f093-4661-835a-5c73a329c874",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# (\n",
@ -319,7 +321,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},

Просмотреть файл

@ -238,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
}
for query in search_query:
if not (len(query) == 1) and (query in ("image", "text_input")):
if len(query) != 1 and (query in ("image", "text_input")):
raise SyntaxError(
'Each query must contain either an "image" or a "text_input"'
)
@ -444,26 +444,26 @@ class MultimodalSearch(AnalysisMethod):
)
return resized_image
def getAttMap(self, img, attMap, blur=True, overlap=True):
attMap -= attMap.min()
if attMap.max() > 0:
attMap /= attMap.max()
attMap = skimage_transform.resize(
attMap, (img.shape[:2]), order=3, mode="constant"
def get_att_map(self, img, att_map, blur=True, overlap=True):
att_map -= att_map.min()
if att_map.max() > 0:
att_map /= att_map.max()
att_map = skimage_transform.resize(
att_map, (img.shape[:2]), order=3, mode="constant"
)
if blur:
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
attMap -= attMap.min()
attMap /= attMap.max()
att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
att_map -= att_map.min()
att_map /= att_map.max()
cmap = plt.get_cmap("jet")
attMapV = cmap(attMap)
attMapV = np.delete(attMapV, 3, 2)
att_mapV = cmap(att_map)
att_mapV = np.delete(att_mapV, 3, 2)
if overlap:
attMap = (
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
att_map = (
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapV
)
return attMap
return att_map
def upload_model_blip2_coco(self):
@ -566,7 +566,7 @@ class MultimodalSearch(AnalysisMethod):
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
for norm_img, grad_cam in zip(norm_imgs, gradcam):
avg_gradcam = MultimodalSearch.getAttMap(
avg_gradcam = MultimodalSearch.get_att_map(
self, norm_img, np.float32(grad_cam[0]), blur=True
)
local_avg_gradcams.append(avg_gradcam)

75
notebooks/multimodal_search.ipynb сгенерированный
Просмотреть файл

@ -47,8 +47,8 @@
"outputs": [],
"source": [
"images = misinformation.utils.find_files(\n",
" path=\"../data/images/\",\n",
" limit=1000,\n",
" path=\"../data/Image_some_text/\",\n",
" limit=10,\n",
")"
]
},
@ -64,18 +64,6 @@
"mydict = misinformation.utils.initialize_dict(images)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mydict"
]
},
{
"cell_type": "markdown",
"id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
@ -125,9 +113,7 @@
" image_keys,\n",
" image_names,\n",
" features_image_stacked,\n",
") = ms.MultimodalSearch.parsing_images(\n",
" mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
")"
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
]
},
{
@ -181,15 +167,14 @@
"cell_type": "code",
"execution_count": null,
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"search_query3 = [\n",
" {\"text_input\": \"politician press conference\"},\n",
" {\"text_input\": \"a world map\"},\n",
" {\"image\": \"../data/haos.png\"},\n",
" {\"image\": \"../data/image-34098-800.png\"},\n",
" {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
" {\"text_input\": \"a dog\"},\n",
"]"
]
@ -209,7 +194,9 @@
"cell_type": "code",
"execution_count": null,
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
@ -237,10 +224,12 @@
"cell_type": "code",
"execution_count": null,
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mydict[\"100127S_ara\"]"
"mydict[\"109237S_spa\"]"
]
},
{
@ -255,12 +244,14 @@
"cell_type": "code",
"execution_count": null,
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ms.MultimodalSearch.show_results(\n",
" mydict,\n",
" search_query3[5],\n",
" search_query3[2],\n",
")"
]
},
@ -276,7 +267,9 @@
"cell_type": "code",
"execution_count": null,
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"itm_model = \"blip_base\"\n",
@ -288,7 +281,9 @@
"cell_type": "code",
"execution_count": null,
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
@ -314,7 +309,9 @@
"cell_type": "code",
"execution_count": null,
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ms.MultimodalSearch.show_results(\n",
@ -369,7 +366,9 @@
"cell_type": "code",
"execution_count": null,
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"df.head(10)"
@ -387,16 +386,26 @@
"cell_type": "code",
"execution_count": null,
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"df.to_csv(\"./data_out.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -410,7 +419,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.9.16"
}
},
"nbformat": 4,