Apply amp.autocast only when CUDA is available
parent
03a1fabb16
commit
ae6cc075f6
|
|
@ -1,26 +1,30 @@
|
||||||
import os
|
import os
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
||||||
from lama_cleaner.model.plms_sampler import PLMSSampler
|
from lama_cleaner.model.plms_sampler import PLMSSampler
|
||||||
|
from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
|
||||||
from lama_cleaner.schema import Config, LDMSampler
|
from lama_cleaner.schema import Config, LDMSampler
|
||||||
|
|
||||||
torch.manual_seed(42)
|
# torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
|
||||||
from lama_cleaner.helper import (
|
|
||||||
download_model,
|
def conditional_autocast(func):
|
||||||
norm_img,
|
@wraps(func)
|
||||||
get_cache_path_by_url,
|
def wrapper(*args, **kwargs):
|
||||||
load_jit_model,
|
if torch.cuda.is_available():
|
||||||
)
|
with torch.cuda.amp.autocast():
|
||||||
from lama_cleaner.model.utils import (
|
return func(*args, **kwargs)
|
||||||
make_beta_schedule,
|
else:
|
||||||
timestep_embedding,
|
return func(*args, **kwargs)
|
||||||
)
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||||
"LDM_ENCODE_MODEL_URL",
|
"LDM_ENCODE_MODEL_URL",
|
||||||
|
|
@ -110,7 +114,7 @@ class DDPM(nn.Module):
|
||||||
alphas_cumprod.shape[0] == self.num_timesteps
|
alphas_cumprod.shape[0] == self.num_timesteps
|
||||||
), "alphas have to be defined for each timestep"
|
), "alphas have to be defined for each timestep"
|
||||||
|
|
||||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
self.register_buffer("betas", to_torch(betas))
|
self.register_buffer("betas", to_torch(betas))
|
||||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||||
|
|
@ -269,7 +273,7 @@ class LDM(InpaintModel):
|
||||||
]
|
]
|
||||||
return all([os.path.exists(it) for it in model_paths])
|
return all([os.path.exists(it) for it in model_paths])
|
||||||
|
|
||||||
@torch.cuda.amp.autocast()
|
@conditional_autocast
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""
|
"""
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue