Apply amp.autocast only when CUDA is available
parent
03a1fabb16
commit
ae6cc075f6
|
|
@ -1,26 +1,30 @@
|
|||
import os
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
import torch.nn as nn
|
||||
|
||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
||||
from lama_cleaner.model.plms_sampler import PLMSSampler
|
||||
from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
|
||||
from lama_cleaner.schema import Config, LDMSampler
|
||||
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from lama_cleaner.helper import (
|
||||
download_model,
|
||||
norm_img,
|
||||
get_cache_path_by_url,
|
||||
load_jit_model,
|
||||
)
|
||||
from lama_cleaner.model.utils import (
|
||||
make_beta_schedule,
|
||||
timestep_embedding,
|
||||
)
|
||||
# torch.manual_seed(42)
|
||||
|
||||
|
||||
def conditional_autocast(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.amp.autocast():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||
"LDM_ENCODE_MODEL_URL",
|
||||
|
|
@ -110,7 +114,7 @@ class DDPM(nn.Module):
|
|||
alphas_cumprod.shape[0] == self.num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
|
||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||
def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
|
|
@ -269,7 +273,7 @@ class LDM(InpaintModel):
|
|||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
@torch.cuda.amp.autocast()
|
||||
@conditional_autocast
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
|
|
|
|||
Loading…
Reference in New Issue