зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 13:36:04 +02:00
update deprecated torch.cuda.amp
Этот коммит содержится в:
родитель
731353dfdb
Коммит
8c26a8de5e
@ -8,7 +8,7 @@ from torchcodec.decoders import VideoDecoder
|
|||||||
from ammico.model import MultimodalSummaryModel
|
from ammico.model import MultimodalSummaryModel
|
||||||
from ammico.utils import AnalysisMethod
|
from ammico.utils import AnalysisMethod
|
||||||
|
|
||||||
from typing import List, Optional, Dict, Any, Generator, Tuple
|
from typing import List, Dict, Any, Generator, Tuple
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
|
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ class VideoSummaryDetector(AnalysisMethod):
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
try:
|
try:
|
||||||
if self.summary_model.device == "cuda":
|
if self.summary_model.device == "cuda":
|
||||||
with torch.cuda.amp.autocast(enabled=True):
|
with torch.amp.autocast("cuda", enabled=True):
|
||||||
generated_ids = self.summary_model.model.generate(
|
generated_ids = self.summary_model.model.generate(
|
||||||
**processor_inputs, generation_config=gen_conf
|
**processor_inputs, generation_config=gen_conf
|
||||||
)
|
)
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user