зеркало из
				https://github.com/ssciwr/AMMICO.git
				synced 2025-10-29 21:16:06 +02:00 
			
		
		
		
	add Model class
Этот коммит содержится в:
		
							родитель
							
								
									34afed5375
								
							
						
					
					
						Коммит
						bd63be4693
					
				| @ -1,5 +1,6 @@ | ||||
| from ammico.display import AnalysisExplorer | ||||
| from ammico.faces import EmotionDetector, ethical_disclosure | ||||
| from ammico.model import MultimodalSummaryModel | ||||
| from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure | ||||
| from ammico.utils import find_files, get_dataframe | ||||
| 
 | ||||
| @ -14,6 +15,7 @@ except ImportError: | ||||
| __all__ = [ | ||||
|     "AnalysisExplorer", | ||||
|     "EmotionDetector", | ||||
|     "MultimodalSummaryModel", | ||||
|     "TextDetector", | ||||
|     "TextAnalyzer", | ||||
|     "find_files", | ||||
|  | ||||
							
								
								
									
										111
									
								
								ammico/model.py
									
									
									
									
									
										Обычный файл
									
								
							
							
						
						
									
										111
									
								
								ammico/model.py
									
									
									
									
									
										Обычный файл
									
								
							| @ -0,0 +1,111 @@ | ||||
| import torch | ||||
| import warnings | ||||
| from transformers import ( | ||||
|     Qwen2_5_VLForConditionalGeneration, | ||||
|     AutoProcessor, | ||||
|     BitsAndBytesConfig, | ||||
|     AutoTokenizer, | ||||
| ) | ||||
| from typing import Optional | ||||
| 
 | ||||
| 
 | ||||
class MultimodalSummaryModel:
    """Loader and holder for a Qwen2.5-VL model, its processor and tokenizer.

    Auto-selects a smaller checkpoint on CPU and a larger, 4-bit-quantized
    checkpoint on CUDA. After construction the instance exposes ``model``,
    ``processor`` and ``tokenizer``; call :meth:`close` to release the model
    in long-running processes.
    """

    # Default checkpoints per device class; the CPU default is deliberately
    # the smaller 3B variant so inference stays feasible without a GPU.
    DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
    DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """
        Class for QWEN-2.5-VL model loading and inference.

        Args:
            model_id: Type of model to load, defaults to a smaller version
                for CPU if device is "cpu".
            device: "cuda" or "cpu" (auto-detected when None).
            cache_dir: huggingface cache dir (optional).

        Raises:
            ValueError: if ``device`` is neither "cuda" nor "cpu".
        """
        self.device = self._resolve_device(device)
        self.model_id = model_id or (
            self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
        )

        self.cache_dir = cache_dir
        self._trust_remote_code = True
        self._quantize = True

        # Populated by _load_model_and_processor() below.
        self.model = None
        self.processor = None
        self.tokenizer = None

        self._load_model_and_processor()

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        """Normalize/auto-detect the device string.

        Returns "cuda" or "cpu"; falls back to CPU (with a RuntimeWarning)
        when CUDA is requested but unavailable.
        """
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        normalized = device.lower()
        if normalized not in ("cuda", "cpu"):
            raise ValueError("device must be 'cuda' or 'cpu'")
        if normalized == "cuda" and not torch.cuda.is_available():
            warnings.warn(
                "Although 'cuda' was requested, no CUDA device is available. Using CPU instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            return "cpu"
        return normalized

    def _load_model_and_processor(self) -> None:
        """Instantiate processor, tokenizer and model for ``self.model_id``.

        On CUDA the model is loaded 4-bit-quantized (NF4, double quant) with
        bf16/fp16 compute depending on hardware support, and sharded via
        ``device_map="auto"``.
        """
        # Kwargs valid for processor, tokenizer AND model loading.
        common_kwargs = {"trust_remote_code": self._trust_remote_code}
        if self.cache_dir:
            common_kwargs["cache_dir"] = self.cache_dir

        # Left padding so batched generation aligns prompts at the right edge.
        self.processor = AutoProcessor.from_pretrained(
            self.model_id, padding_side="left", **common_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **common_kwargs)

        # use_cache (KV-cache during generation) is a model-config option, so
        # it belongs only in the model kwargs — not in the processor/tokenizer
        # calls above.
        model_kwargs = {**common_kwargs, "use_cache": True}
        if self.device == "cuda":
            compute_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            model_kwargs["device_map"] = "auto"

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_id, **model_kwargs
        )
        # Inference-only: disable dropout etc.
        self.model.eval()

    def _close(self) -> None:
        """Drop the model reference and best-effort empty the CUDA cache."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
        finally:
            # Cache emptying is best-effort: failure only means some GPU
            # memory may linger, so warn instead of raising.
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        self._close()
		Загрузка…
	
	
			
			x
			
			
		
	
		Ссылка в новой задаче
	
	Block a user
	 DimasfromLavoisier
						DimasfromLavoisier