зеркало из
				https://github.com/ssciwr/AMMICO.git
				synced 2025-10-30 13:36:04 +02:00 
			
		
		
		
	fixed new model in summary.py
Этот коммит содержится в:
		
							родитель
							
								
									28014f563f
								
							
						
					
					
						Коммит
						f4e47c105e
					
				| @ -15,7 +15,7 @@ class SummaryDetector(AnalysisMethod): | ||||
| 
 | ||||
|     summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| 
 | ||||
| -    def load_model_base(self): | ||||
| +    def load_model_base(): | ||||
|         summary_model, summary_vis_processors, _ = load_model_and_preprocess( | ||||
|             name="blip_caption", | ||||
|             model_type="base_coco", | ||||
| @ -24,7 +24,7 @@ class SummaryDetector(AnalysisMethod): | ||||
|         ) | ||||
|         return summary_model, summary_vis_processors | ||||
| 
 | ||||
| -    def load_model_large(self): | ||||
| +    def load_model_large(): | ||||
|         summary_model, summary_vis_processors, _ = load_model_and_preprocess( | ||||
|             name="blip_caption", | ||||
|             model_type="large_coco", | ||||
| @ -33,6 +33,14 @@ class SummaryDetector(AnalysisMethod): | ||||
|         ) | ||||
|         return summary_model, summary_vis_processors | ||||
| 
 | ||||
| +    def load_model(model_type): | ||||
| +        select_model = { | ||||
| +            "base": SummaryDetector.load_model_base, | ||||
| +            "large": SummaryDetector.load_model_large, | ||||
| +        } | ||||
| +        summary_model, summary_vis_processors = select_model[model_type]() | ||||
| +        return summary_model, summary_vis_processors | ||||
| 
 | ||||
|     def set_keys(self) -> dict: | ||||
|         params = { | ||||
|             "const_image_summary": None, | ||||
| @ -40,13 +48,8 @@ class SummaryDetector(AnalysisMethod): | ||||
|         } | ||||
|         return params | ||||
| 
 | ||||
| -    def analyse_image(self, model_type): | ||||
| +    def analyse_image(self, summary_model, summary_vis_processors): | ||||
| 
 | ||||
| -        select_model = { | ||||
| -            "base": self.load_model_base, | ||||
| -            "large": self.load_model_large, | ||||
| -        } | ||||
| -        summary_model, summary_vis_processors = select_model[model_type]() | ||||
|         path = self.subdict["filename"] | ||||
|         raw_image = Image.open(path).convert("RGB") | ||||
|         image = ( | ||||
| @ -73,13 +76,13 @@ class SummaryDetector(AnalysisMethod): | ||||
|         name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device | ||||
|     ) | ||||
| 
 | ||||
| -    def analyse_questions(self, model_type, list_of_questions): | ||||
| +    def analyse_questions(self, list_of_questions): | ||||
| 
 | ||||
|         if len(list_of_questions) > 0: | ||||
|             path = self.subdict["filename"] | ||||
|             raw_image = Image.open(path).convert("RGB") | ||||
|             image = ( | ||||
| -                summary_VQA_vis_processors["eval"](raw_image) | ||||
| +                self.summary_VQA_vis_processors["eval"](raw_image) | ||||
|                 .unsqueeze(0) | ||||
|                 .to(self.summary_device) | ||||
|             ) | ||||
|  | ||||
							
								
								
									
										17
									
								
								notebooks/image_summary.ipynb
									
									
									
										сгенерированный
									
									
									
								
							
							
						
						
									
										17
									
								
								notebooks/image_summary.ipynb
									
									
									
										сгенерированный
									
									
									
								
							| @ -40,7 +40,7 @@ | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "images = mutils.find_files(\n", | ||||
| -    "    path=\"../data/Image_some_text/\",\n", | ||||
| +    "    path=\"../misinformation/test/data/\",\n", | ||||
|     "    limit=1000,\n", | ||||
|     ")" | ||||
|    ] | ||||
| @ -70,6 +70,15 @@ | ||||
|     "## Create captions for images and directly write to csv" | ||||
|    ] | ||||
|   }, | ||||
| +  { | ||||
| +   "cell_type": "code", | ||||
| +   "execution_count": null, | ||||
| +   "metadata": {}, | ||||
| +   "outputs": [], | ||||
| +   "source": [ | ||||
| +    "summary_model, summary_vis_processors = sm.SummaryDetector.load_model(\"base\")" | ||||
| +   ] | ||||
| +  }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
| @ -77,7 +86,9 @@ | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "for key in mydict:\n", | ||||
| -    "    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image()" | ||||
| +    "    mydict[key] = sm.SummaryDetector(mydict[key]).analyse_image(\n", | ||||
| +    "        summary_model, summary_vis_processors\n", | ||||
| +    "    )" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
| @ -260,7 +271,7 @@ | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
| -   "version": "3.9.5" | ||||
| +   "version": "3.10.8" | ||||
|   }, | ||||
|   "vscode": { | ||||
|    "interpreter": { | ||||
|  | ||||
		Загрузка…
	
	
			
			x
			
			
		
	
		Ссылка в новой задаче
	
	Block a user
	 Petr Andriushchenko
						Petr Andriushchenko