Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00
	fixed new model in summary.py
Parent: 28014f563f
Commit: f4e47c105e
summary.py
@@ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod):
 
     summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_model_base(self):
+    def load_model_base():
         summary_model, summary_vis_processors, _ = load_model_and_preprocess(
             name="blip_caption",
             model_type="base_coco",
@@ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod):
         )
         return summary_model, summary_vis_processors
 
-    def load_model_large(self):
+    def load_model_large():
         summary_model, summary_vis_processors, _ = load_model_and_preprocess(
             name="blip_caption",
             model_type="large_coco",
@@ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod):
         )
         return summary_model, summary_vis_processors
 
+    def load_model(model_type):
+        select_model = {
+            "base": SummaryDetector.load_model_base,
+            "large": SummaryDetector.load_model_large,
+        }
+        summary_model, summary_vis_processors = select_model[model_type]()
+        return summary_model, summary_vis_processors
+
     def set_keys(self) -> dict:
         params = {
             "const_image_summary": None,
@@ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod):
         }
         return params
 
-    def analyse_image(self, model_type):
+    def analyse_image(self, summary_model, summary_vis_processors):
 
-        select_model = {
-            "base": self.load_model_base,
-            "large": self.load_model_large,
-        }
-        summary_model, summary_vis_processors = select_model[model_type]()
         path = self.subdict["filename"]
         raw_image = Image.open(path).convert("RGB")
         image = (
@@ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod):
         name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
     )
 
-    def analyse_questions(self, model_type, list_of_questions):
+    def analyse_questions(self, list_of_questions):
 
         if len(list_of_questions) > 0:
             path = self.subdict["filename"]
             raw_image = Image.open(path).convert("RGB")
             image = (
-                summary_VQA_vis_processors["eval"](raw_image)
+                self.summary_VQA_vis_processors["eval"](raw_image)
                 .unsqueeze(0)
                 .to(self.summary_device)
             )
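For context, load_model_base and load_model_large wrap LAVIS's load_model_and_preprocess, which returns a (model, visual processors, text processors) triple. A minimal sketch of that call, assuming the arguments elided by the hunk boundaries above mirror the blip_vqa load shown later in the same file (is_eval and device are not visible in this diff):

import torch
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BLIP captioning model once; "large" would use model_type="large_coco".
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
    name="blip_caption",
    model_type="base_coco",
    is_eval=True,   # assumption: matches the blip_vqa call in summary.py
    device=device,  # assumption: same device selection as summary_device
)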
							
								
								
									
notebooks/image_summary.ipynb (generated, 17 changed lines)
@@ -40,7 +40,7 @@
    "outputs": [],
    "source": [
     "images = mutils.find_files(\n",
-    "    path=\"../data/Image_some_text/\",\n",
+    "    path=\"../misinformation/test/data/\",\n",
     "    limit=1000,\n",
     ")"
    ]
@@ -70,6 +70,15 @@
     "## Create captions for images and directly write to csv"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -77,7 +86,9 @@
    "outputs": [],
    "source": [
     "for key in mydict:\n",
-    "    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()"
+    "    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n",
+    "        summary_model, summary_vis_processors\n",
+    "    )"
    ]
   },
   {
@@ -260,7 +271,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.5"
+   "version": "3.10.8"
   },
   "vscode": {
    "interpreter": {
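Taken together, the two changed notebook cells amount to one workflow change: the captioning model is loaded once and then reused for every image instead of being reloaded inside analyse_image. A minimal end-to-end sketch of that workflow, where the import path and the hand-built mydict are assumptions; load_model, analyse_image and analyse_questions come from the diff above:

# Sketch of the updated notebook workflow, not a verbatim copy of the notebook.
from misinformation import summary as sm  # assumption: module imported as sm

# Load the captioning model once, outside the per-image loop.
summary_model, summary_vis_processors = sm.SummaryDetector.load_model("base")

# Hypothetical stand-in for the dictionary the notebook builds from
# mutils.find_files(path="../misinformation/test/data/", limit=1000).
mydict = {
    "img1": {"filename": "../misinformation/test/data/img1.png"},
}

# Caption every image with the preloaded model and processors.
for key in mydict:
    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(
        summary_model, summary_vis_processors
    )

# Visual question answering no longer takes a model_type argument; that it
# returns the updated subdict is assumed here by analogy with analyse_image.
questions = ["Are there people in the picture?"]
for key in mydict:
    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_questions(questions)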