Apply amp.autocast only when CUDA is available

main
Uminosachi 2024-06-30 12:14:47 +09:00
parent 03a1fabb16
commit ae6cc075f6
1 changed file with 19 additions and 15 deletions


@@ -1,26 +1,30 @@
 import os
+from functools import wraps

 import numpy as np
 import torch
-from loguru import logger
+import torch.nn as nn
+from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img

 from lama_cleaner.model.base import InpaintModel
 from lama_cleaner.model.ddim_sampler import DDIMSampler
 from lama_cleaner.model.plms_sampler import PLMSSampler
+from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
 from lama_cleaner.schema import Config, LDMSampler

-torch.manual_seed(42)
-import torch.nn as nn
-from lama_cleaner.helper import (
-    download_model,
-    norm_img,
-    get_cache_path_by_url,
-    load_jit_model,
-)
-from lama_cleaner.model.utils import (
-    make_beta_schedule,
-    timestep_embedding,
-)
+# torch.manual_seed(42)
+
+
+def conditional_autocast(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if torch.cuda.is_available():
+            with torch.cuda.amp.autocast():
+                return func(*args, **kwargs)
+        else:
+            return func(*args, **kwargs)
+    return wrapper
+

 LDM_ENCODE_MODEL_URL = os.environ.get(
     "LDM_ENCODE_MODEL_URL",
@@ -110,7 +114,7 @@ class DDPM(nn.Module):
             alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
+        def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device)

         self.register_buffer("betas", to_torch(betas))
         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -269,7 +273,7 @@ class LDM(InpaintModel):
         ]
         return all([os.path.exists(it) for it in model_paths])

-    @torch.cuda.amp.autocast()
+    @conditional_autocast
     def forward(self, image, mask, config: Config):
        """
        image: [H, W, C] RGB
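Swapping @torch.cuda.amp.autocast() for @conditional_autocast means forward only enters the autocast context on CUDA machines; the unconditional decorator has nothing useful to do on a CPU-only install. Assuming a PyTorch new enough to ship torch.autocast (1.10+), the same guard could arguably be expressed with its enabled flag instead of a wrapper; a hypothetical sketch, not what this commit does:

import torch

class LDMSketch:  # hypothetical stand-in for the LDM class above
    def forward(self, image, mask, config):
        # enabled=False should make the context a no-op on CPU-only machines.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            return image  # real inpainting logic elided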