зеркало из
				https://github.com/ssciwr/AMMICO.git
				synced 2025-10-30 21:46:04 +02:00 
			
		
		
		
	optimize validation of analysis type
Этот коммит содержится в:
		
							родитель
							
								
									d1a4954669
								
							
						
					
					
						Коммит
						ece132fe14
					
				| @ -6,7 +6,7 @@ import torch | |||||||
| from PIL import Image | from PIL import Image | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| from typing import List, Optional, Union, Dict, Any | from typing import List, Optional, Union, Dict, Any, Tuple | ||||||
| from collections.abc import Sequence as _Sequence | from collections.abc import Sequence as _Sequence | ||||||
| from transformers import GenerationConfig | from transformers import GenerationConfig | ||||||
| from qwen_vl_utils import process_vision_info | from qwen_vl_utils import process_vision_info | ||||||
| @ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod): | |||||||
| 
 | 
 | ||||||
|         return inputs |         return inputs | ||||||
| 
 | 
 | ||||||
|  |     def _validate_analysis_type( | ||||||
|  |         self, | ||||||
|  |         analysis_type: Union["AnalysisType", str], | ||||||
|  |         list_of_questions: Optional[List[str]], | ||||||
|  |         max_questions_per_image: int, | ||||||
|  |     ) -> Tuple[str, List[str], bool, bool]: | ||||||
|  |         if isinstance(analysis_type, AnalysisType): | ||||||
|  |             analysis_type = analysis_type.value | ||||||
|  | 
 | ||||||
|  |         allowed = {"summary", "questions", "summary_and_questions"} | ||||||
|  |         if analysis_type not in allowed: | ||||||
|  |             raise ValueError(f"analysis_type must be one of {allowed}") | ||||||
|  | 
 | ||||||
|  |         if list_of_questions is None: | ||||||
|  |             list_of_questions = [ | ||||||
|  |                 "Are there people in the image?", | ||||||
|  |                 "What is this picture about?", | ||||||
|  |             ] | ||||||
|  | 
 | ||||||
|  |         if analysis_type in ("questions", "summary_and_questions"): | ||||||
|  |             if len(list_of_questions) > max_questions_per_image: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |         is_summary = analysis_type in ("summary", "summary_and_questions") | ||||||
|  |         is_questions = analysis_type in ("questions", "summary_and_questions") | ||||||
|  | 
 | ||||||
|  |         return analysis_type, list_of_questions, is_summary, is_questions | ||||||
|  | 
 | ||||||
|     def analyse_images( |     def analyse_images( | ||||||
|         self, |         self, | ||||||
|         analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, |         analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, | ||||||
| @ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod): | |||||||
|             self.subdict (dict): dictionary with analysis results. |             self.subdict (dict): dictionary with analysis results. | ||||||
|         """ |         """ | ||||||
|         # TODO: add option to ask multiple questions per image as one batch. |         # TODO: add option to ask multiple questions per image as one batch. | ||||||
|         if isinstance(analysis_type, AnalysisType): |         analysis_type, list_of_questions, is_summary, is_questions = ( | ||||||
|             analysis_type = analysis_type.value |             self._validate_analysis_type( | ||||||
| 
 |                 analysis_type, list_of_questions, max_questions_per_image | ||||||
|         allowed = {"summary", "questions", "summary_and_questions"} |             ) | ||||||
|         if analysis_type not in allowed: |         ) | ||||||
|             raise ValueError(f"analysis_type must be one of {allowed}") |  | ||||||
| 
 |  | ||||||
|         if list_of_questions is None: |  | ||||||
|             list_of_questions = [ |  | ||||||
|                 "Are there people in the image?", |  | ||||||
|                 "What is this picture about?", |  | ||||||
|             ] |  | ||||||
| 
 | 
 | ||||||
|         keys = list(self.subdict.keys()) |         keys = list(self.subdict.keys()) | ||||||
|         for batch_start in range(0, len(keys), keys_batch_size): |         for batch_start in range(0, len(keys), keys_batch_size): | ||||||
|             batch_keys = keys[batch_start : batch_start + keys_batch_size] |             batch_keys = keys[batch_start : batch_start + keys_batch_size] | ||||||
|             for key in batch_keys: |             for key in batch_keys: | ||||||
|                 entry = self.subdict[key] |                 entry = self.subdict[key] | ||||||
|                 if analysis_type in ("summary", "summary_and_questions"): |                 if is_summary: | ||||||
|                     try: |                     try: | ||||||
|                         caps = self.generate_caption( |                         caps = self.generate_caption( | ||||||
|                             entry, |                             entry, | ||||||
| @ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod): | |||||||
|                             "Caption generation failed for key %s: %s", key, e |                             "Caption generation failed for key %s: %s", key, e | ||||||
|                         ) |                         ) | ||||||
| 
 | 
 | ||||||
|                 if analysis_type in ("questions", "summary_and_questions"): |                 if is_questions: | ||||||
|                     if len(list_of_questions) > max_questions_per_image: |  | ||||||
|                         raise ValueError( |  | ||||||
|                             f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})." |  | ||||||
|                             " Reduce questions or increase max_questions_per_image." |  | ||||||
|                         ) |  | ||||||
|                     try: |                     try: | ||||||
|                         vqa_map = self.answer_questions( |                         vqa_map = self.answer_questions( | ||||||
|                             list_of_questions, entry, is_concise_answer |                             list_of_questions, entry, is_concise_answer | ||||||
|  | |||||||
		Загрузка…
	
	
			
			x
			
			
		
	
		Ссылка в новой задаче
	
	Block a user
	 DimasfromLavoisier
						DimasfromLavoisier