diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 88836d9..666c373 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -8,7 +8,7 @@ from torchcodec.decoders import VideoDecoder from ammico.model import MultimodalSummaryModel from ammico.utils import AnalysisMethod -from typing import List, Optional, Dict, Any, Generator, Tuple +from typing import List, Dict, Any, Generator, Tuple from transformers import GenerationConfig @@ -229,7 +229,7 @@ class VideoSummaryDetector(AnalysisMethod): with torch.inference_mode(): try: if self.summary_model.device == "cuda": - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast("cuda", enabled=True): generated_ids = self.summary_model.model.generate( **processor_inputs, generation_config=gen_conf )