зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
fix code smells, test notebook
Этот коммит содержится в:
родитель
ecc0d814bb
Коммит
0ae872e750
@ -113,7 +113,7 @@
|
||||
" image_keys,\n",
|
||||
" image_names,\n",
|
||||
" features_image_stacked,\n",
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -128,7 +128,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "56c6d488-f093-4661-835a-5c73a329c874",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# (\n",
|
||||
@ -319,7 +321,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@ -238,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
}
|
||||
|
||||
for query in search_query:
|
||||
if not (len(query) == 1) and (query in ("image", "text_input")):
|
||||
if len(query) != 1 and (query in ("image", "text_input")):
|
||||
raise SyntaxError(
|
||||
'Each query must contain either an "image" or a "text_input"'
|
||||
)
|
||||
@ -444,26 +444,26 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return resized_image
|
||||
|
||||
def getAttMap(self, img, attMap, blur=True, overlap=True):
|
||||
attMap -= attMap.min()
|
||||
if attMap.max() > 0:
|
||||
attMap /= attMap.max()
|
||||
attMap = skimage_transform.resize(
|
||||
attMap, (img.shape[:2]), order=3, mode="constant"
|
||||
def get_att_map(self, img, att_map, blur=True, overlap=True):
|
||||
att_map -= att_map.min()
|
||||
if att_map.max() > 0:
|
||||
att_map /= att_map.max()
|
||||
att_map = skimage_transform.resize(
|
||||
att_map, (img.shape[:2]), order=3, mode="constant"
|
||||
)
|
||||
if blur:
|
||||
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
||||
attMap -= attMap.min()
|
||||
attMap /= attMap.max()
|
||||
att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
|
||||
att_map -= att_map.min()
|
||||
att_map /= att_map.max()
|
||||
cmap = plt.get_cmap("jet")
|
||||
attMapV = cmap(attMap)
|
||||
attMapV = np.delete(attMapV, 3, 2)
|
||||
att_mapV = cmap(att_map)
|
||||
att_mapV = np.delete(att_mapV, 3, 2)
|
||||
if overlap:
|
||||
attMap = (
|
||||
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
||||
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
||||
att_map = (
|
||||
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
|
||||
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapV
|
||||
)
|
||||
return attMap
|
||||
return att_map
|
||||
|
||||
def upload_model_blip2_coco(self):
|
||||
|
||||
@ -566,7 +566,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
|
||||
|
||||
for norm_img, grad_cam in zip(norm_imgs, gradcam):
|
||||
avg_gradcam = MultimodalSearch.getAttMap(
|
||||
avg_gradcam = MultimodalSearch.get_att_map(
|
||||
self, norm_img, np.float32(grad_cam[0]), blur=True
|
||||
)
|
||||
local_avg_gradcams.append(avg_gradcam)
|
||||
|
||||
75
notebooks/multimodal_search.ipynb
сгенерированный
75
notebooks/multimodal_search.ipynb
сгенерированный
@ -47,8 +47,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"images = misinformation.utils.find_files(\n",
|
||||
" path=\"../data/images/\",\n",
|
||||
" limit=1000,\n",
|
||||
" path=\"../data/Image_some_text/\",\n",
|
||||
" limit=10,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -64,18 +64,6 @@
|
||||
"mydict = misinformation.utils.initialize_dict(images)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mydict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "987540a8-d800-4c70-a76b-7bfabaf123fa",
|
||||
@ -125,9 +113,7 @@
|
||||
" image_keys,\n",
|
||||
" image_names,\n",
|
||||
" features_image_stacked,\n",
|
||||
") = ms.MultimodalSearch.parsing_images(\n",
|
||||
" mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
|
||||
")"
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -181,15 +167,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_query3 = [\n",
|
||||
" {\"text_input\": \"politician press conference\"},\n",
|
||||
" {\"text_input\": \"a world map\"},\n",
|
||||
" {\"image\": \"../data/haos.png\"},\n",
|
||||
" {\"image\": \"../data/image-34098-800.png\"},\n",
|
||||
" {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
|
||||
" {\"text_input\": \"a dog\"},\n",
|
||||
"]"
|
||||
]
|
||||
@ -209,7 +194,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
|
||||
@ -237,10 +224,12 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mydict[\"100127S_ara\"]"
|
||||
"mydict[\"109237S_spa\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -255,12 +244,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ms.MultimodalSearch.show_results(\n",
|
||||
" mydict,\n",
|
||||
" search_query3[5],\n",
|
||||
" search_query3[2],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -276,7 +267,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"itm_model = \"blip_base\"\n",
|
||||
@ -288,7 +281,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
|
||||
@ -314,7 +309,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ms.MultimodalSearch.show_results(\n",
|
||||
@ -369,7 +366,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.head(10)"
|
||||
@ -387,16 +386,26 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.to_csv(\"./data_out.csv\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -410,7 +419,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.0"
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user