Mirror of https://github.com/ssciwr/AMMICO.git (synced 2025-10-31 05:56:05 +02:00)
	add model change function to summary
This commit is contained in:
parent c208039b7c
commit c136b91fba
@@ -15,9 +15,23 @@ class SummaryDetector(AnalysisMethod):
 
     summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    summary_model, summary_vis_processors, _ = load_model_and_preprocess(
-        name="blip_caption", model_type="base_coco", is_eval=True, device=summary_device
-    )
+    def load_model_base(self):
+        summary_model, summary_vis_processors, _ = load_model_and_preprocess(
+            name="blip_caption",
+            model_type="base_coco",
+            is_eval=True,
+            device=SummaryDetector.summary_device,
+        )
+        return summary_model, summary_vis_processors
+
+    def load_model_large(self):
+        summary_model, summary_vis_processors, _ = load_model_and_preprocess(
+            name="blip_caption",
+            model_type="large_coco",
+            is_eval=True,
+            device=SummaryDetector.summary_device,
+        )
+        return summary_model, summary_vis_processors
 
     def set_keys(self) -> dict:
         params = {
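The two new loaders are thin wrappers around LAVIS's load_model_and_preprocess, differing only in the checkpoint (base_coco vs. large_coco). A minimal sketch of calling one directly; the import path, the SummaryDetector constructor signature, and the "train"/"eval" processor keys are assumptions taken from the base class and LAVIS conventions, not from this diff:

    from ammico.summary import SummaryDetector  # hypothetical import path

    # Any per-image dict with a "filename" key, as used elsewhere in the class.
    detector = SummaryDetector({"filename": "data/example.jpg"})

    # Loads the BLIP "large_coco" captioning checkpoint via LAVIS on first use.
    summary_model, summary_vis_processors = detector.load_model_large()
    print(type(summary_model).__name__)         # LAVIS BLIP captioning model class
    print(list(summary_vis_processors.keys()))  # typically ["train", "eval"]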
@@ -26,19 +40,27 @@ class SummaryDetector(AnalysisMethod):
         }
         return params
 
-    def analyse_image(self):
+    def analyse_image(self, model_type):
 
+        select_model = {
+            "base": self.load_model_base,
+            "large": self.load_model_large,
+        }
+        summary_model, summary_vis_processors = select_model[model_type]()
         path = self.subdict["filename"]
         raw_image = Image.open(path).convert("RGB")
         image = (
-            self.summary_vis_processors["eval"](raw_image)
+            summary_vis_processors["eval"](raw_image)
             .unsqueeze(0)
             .to(self.summary_device)
         )
-        self.image_summary["const_image_summary"] = self.summary_model.generate(
+        with torch.no_grad():
+            self.image_summary["const_image_summary"] = summary_model.generate(
                 {"image": image}
             )[0]
-        self.image_summary["3_non-deterministic summary"] = self.summary_model.generate(
+            self.image_summary[
+                "3_non-deterministic summary"
+            ] = summary_model.generate(
                 {"image": image}, use_nucleus_sampling=True, num_captions=3
             )
         for key in self.image_summary:
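After this change the captioning model is chosen per call through the select_model dispatch dict rather than taken from a class attribute. A usage sketch, assuming SummaryDetector is still constructed from a per-image dict (the constructor and the import path are not shown in this diff):

    from ammico.summary import SummaryDetector  # hypothetical import path

    mydict = {"filename": "data/example.jpg"}   # "filename" key taken from the diff
    sd = SummaryDetector(mydict)                # constructor signature assumed

    # "base" -> blip_caption/base_coco, "large" -> blip_caption/large_coco
    sd.analyse_image(model_type="base")
    print(sd.image_summary["const_image_summary"])          # deterministic caption
    print(sd.image_summary["3_non-deterministic summary"])  # 3 nucleus-sampled captions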
@@ -53,12 +75,13 @@ class SummaryDetector(AnalysisMethod):
         name="blip_vqa", model_type="vqav2", is_eval=True, device=summary_device
     )
 
-    def analyse_questions(self, list_of_questions):
+    def analyse_questions(self, model_type, list_of_questions):
+
         if len(list_of_questions) > 0:
             path = self.subdict["filename"]
             raw_image = Image.open(path).convert("RGB")
             image = (
-                self.summary_VQA_vis_processors["eval"](raw_image)
+                summary_VQA_vis_processors["eval"](raw_image)
                 .unsqueeze(0)
                 .to(self.summary_device)
             )
@@ -68,6 +91,7 @@ class SummaryDetector(AnalysisMethod):
             batch_size = len(list_of_questions)
             image_batch = image.repeat(batch_size, 1, 1, 1)
 
+            with torch.no_grad():
                 answers_batch = self.summary_VQA_model.predict_answers(
                     samples={"image": image_batch, "text_input": question_batch},
                     inference_method="generate",
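analyse_questions gains the same model_type parameter, though these hunks only show the new signature and the switch to a local summary_VQA_vis_processors; the answers still come from the class-level blip_vqa/vqav2 model. A hedged call sketch, with the import path and constructor assumed as above:

    from ammico.summary import SummaryDetector  # hypothetical import path

    sd = SummaryDetector({"filename": "data/example.jpg"})  # constructor assumed
    questions = [
        "How many persons are in the picture?",
        "What is the predominant color?",
    ]
    # model_type is passed positionally, as in analyse_image; what it controls
    # inside analyse_questions is not visible in the hunks above.
    sd.analyse_questions("base", questions)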
Petr Andriushchenko