"""CFG-guided sampling helpers built on top of the k-diffusion library."""
from typing import Any, Callable, Optional

import torch
from torch import nn

from k_diffusion import sampling
from k_diffusion.external import CompVisDenoiser
|
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a k-diffusion denoiser.

    Runs the unconditional and conditional predictions in a single batched
    forward pass and blends them with the guidance scale.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying denoiser; must accept (x, sigma, cond=...) and return
        # a tensor with the same leading batch dimension as x.
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Duplicate the latent and sigma so both conditionings share one call.
        batched_x = torch.cat((x, x))
        batched_sigma = torch.cat((sigma, sigma))
        batched_cond = torch.cat((uncond, cond))
        denoised = self.inner_model(batched_x, batched_sigma, cond=batched_cond)
        denoised_uncond, denoised_cond = denoised.chunk(2)
        # CFG: extrapolate from the unconditional prediction toward the
        # conditional one by cond_scale.
        return denoised_uncond + cond_scale * (denoised_cond - denoised_uncond)
|
|
|
|
|
|
def sampler_fn(
    c: torch.Tensor,
    uc: torch.Tensor,
    args,
    model_wrap: CompVisDenoiser,
    init_latent: Optional[torch.Tensor] = None,
    t_enc: Optional[torch.Tensor] = None,
    device=torch.device("cpu")
    if not torch.cuda.is_available()
    else torch.device("cuda"),
    cb: Optional[Callable[[Any], None]] = None,
) -> torch.Tensor:
    """Sample latents with one of the k-diffusion samplers.

    Args:
        c: Conditional (prompt) conditioning tensor.
        uc: Unconditional (negative/empty prompt) conditioning tensor.
        args: Options namespace; reads C, H, W, f, steps, use_init,
            n_samples, scale, and sampler.
        model_wrap: CompVis denoiser wrapper providing ``get_sigmas``.
        init_latent: Starting latent for img2img; required when
            ``args.use_init`` is set.
        t_enc: Number of noising steps already applied to ``init_latent``;
            required when ``args.use_init`` is set.
        device: Device for the initial noise tensor. NOTE: the default is
            evaluated once at import time, not per call.
        cb: Optional per-step callback passed through to the sampler.

    Returns:
        The sampled latent batch of shape [n_samples, C, H // f, W // f].

    Raises:
        ValueError: if ``args.use_init`` is set without init_latent/t_enc.
        KeyError: if ``args.sampler`` names an unknown sampler.
    """
    shape = [args.C, args.H // args.f, args.W // args.f]
    sigmas: torch.Tensor = model_wrap.get_sigmas(args.steps)
    if args.use_init:
        # img2img: resume partway down the noise schedule, starting from the
        # init latent plus noise at the first remaining sigma.
        if init_latent is None or t_enc is None:
            raise ValueError(
                "init_latent and t_enc are required when args.use_init is set"
            )
        sigmas = sigmas[len(sigmas) - t_enc - 1 :]
        x = (
            init_latent
            + torch.randn([args.n_samples, *shape], device=device) * sigmas[0]
        )
    else:
        # txt2img: start from pure noise scaled to the largest sigma.
        x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]
    sampler_args = {
        "model": CFGDenoiser(model_wrap),
        "x": x,
        "sigmas": sigmas,
        "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale},
        "disable": False,
        "callback": cb,
    }
    sampler_map = {
        "klms": sampling.sample_lms,
        "dpm2": sampling.sample_dpm_2,
        "dpm2_ancestral": sampling.sample_dpm_2_ancestral,
        "heun": sampling.sample_heun,
        "euler": sampling.sample_euler,
        "euler_ancestral": sampling.sample_euler_ancestral,
    }

    try:
        sample_fn = sampler_map[args.sampler]
    except KeyError:
        # Same exception type as the bare lookup, but with an actionable
        # message instead of just the missing key.
        raise KeyError(
            f"unknown sampler {args.sampler!r}; expected one of "
            f"{sorted(sampler_map)}"
        ) from None
    samples = sample_fn(**sampler_args)
    return samples
|