From ae6cc075f6c0f0dc360f2998181f4f8cbebd3622 Mon Sep 17 00:00:00 2001
From: Uminosachi <49424133+Uminosachi@users.noreply.github.com>
Date: Sun, 30 Jun 2024 12:14:47 +0900
Subject: [PATCH] Apply amp.autocast only when CUDA is available

---
 lama_cleaner/model/ldm.py | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py
index a5b6d12..7d265de 100644
--- a/lama_cleaner/model/ldm.py
+++ b/lama_cleaner/model/ldm.py
@@ -1,26 +1,30 @@
 import os
+from functools import wraps
 
 import numpy as np
 import torch
-from loguru import logger
+import torch.nn as nn
 
+from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img
 from lama_cleaner.model.base import InpaintModel
 from lama_cleaner.model.ddim_sampler import DDIMSampler
 from lama_cleaner.model.plms_sampler import PLMSSampler
+from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
 from lama_cleaner.schema import Config, LDMSampler
 
-torch.manual_seed(42)
-import torch.nn as nn
-from lama_cleaner.helper import (
-    download_model,
-    norm_img,
-    get_cache_path_by_url,
-    load_jit_model,
-)
-from lama_cleaner.model.utils import (
-    make_beta_schedule,
-    timestep_embedding,
-)
+# torch.manual_seed(42)
+
+
+def conditional_autocast(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if torch.cuda.is_available():
+            with torch.cuda.amp.autocast():
+                return func(*args, **kwargs)
+        else:
+            return func(*args, **kwargs)
+    return wrapper
+
 
 LDM_ENCODE_MODEL_URL = os.environ.get(
     "LDM_ENCODE_MODEL_URL",
@@ -110,7 +114,7 @@ class DDPM(nn.Module):
             alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"
 
-        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
+        def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device)
 
         self.register_buffer("betas", to_torch(betas))
         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -269,7 +273,7 @@ class LDM(InpaintModel):
         ]
         return all([os.path.exists(it) for it in model_paths])
 
-    @torch.cuda.amp.autocast()
+    @conditional_autocast
     def forward(self, image, mask, config: Config):
         """
         image: [H, W, C] RGB
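
Note: for reference, a minimal standalone sketch of what the patch changes. The decorator body is copied from the hunk above; the matmul function, the device selection, and the sample tensors are hypothetical, added only to illustrate the CPU/CUDA behavior, and are not part of the patch:

from functools import wraps

import torch


def conditional_autocast(func):
    # Enter torch.cuda.amp.autocast only when CUDA is available;
    # on CPU-only machines, call the wrapped function unchanged.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast():
                return func(*args, **kwargs)
        else:
            return func(*args, **kwargs)
    return wrapper


@conditional_autocast
def matmul(a, b):
    # Autocast-eligible CUDA ops (such as matmul) run in float16 under
    # autocast; without CUDA the wrapper is a plain pass-through.
    return a @ b


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 4, device=device)
print(matmul(x, x).dtype)  # torch.float16 under CUDA autocast, torch.float32 on CPU

This avoids the failure the original code had on CPU-only hosts, where the unconditional @torch.cuda.amp.autocast() decorator wrapped LDM.forward even though no CUDA device was present.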