зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
update deprecated torch.cuda.amp
Этот коммит содержится в:
родитель
731353dfdb
Коммит
8c26a8de5e
@ -8,7 +8,7 @@ from torchcodec.decoders import VideoDecoder
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
from ammico.utils import AnalysisMethod
|
||||
|
||||
from typing import List, Optional, Dict, Any, Generator, Tuple
|
||||
from typing import List, Dict, Any, Generator, Tuple
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
@ -229,7 +229,7 @@ class VideoSummaryDetector(AnalysisMethod):
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
if self.summary_model.device == "cuda":
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
with torch.amp.autocast("cuda", enabled=True):
|
||||
generated_ids = self.summary_model.model.generate(
|
||||
**processor_inputs, generation_config=gen_conf
|
||||
)
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user