diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2726a1864..4e08cb147 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ This release can be considered an LTS release before we kick off the next round
   - select in *scripts -> pulid*
   - compatible with *sdxl*
   - can be used in xyz grid
+  - *note*: this module contains several advanced features on top of the original implementation
 - [InstantIR](https://github.com/instantX-research/InstantIR): Blind Image Restoration with Instant Generative Reference
   - alternative to traditional `img2img` with more control over restoration process
   - select in *image -> scripts -> instantir*
diff --git a/modules/pulid/__init__.py b/modules/pulid/__init__.py
index 000f45293..785b849c2 100644
--- a/modules/pulid/__init__.py
+++ b/modules/pulid/__init__.py
@@ -8,3 +8,4 @@ sys.path.append(os.path.dirname(__file__))
 from pulid_sdxl import StableDiffusionXLPuLIDPipeline
 from pulid_utils import resize_numpy_image_long as resize
 import attention_processor as attention
+import pulid_sampling as sampling
diff --git a/modules/pulid/pulid_sampling.py b/modules/pulid/pulid_sampling.py
new file mode 100644
index 000000000..9996f035a
--- /dev/null
+++ b/modules/pulid/pulid_sampling.py
@@ -0,0 +1,571 @@
+import math
+from scipy import integrate
+import torch
+from torch import nn
+from torchdiffeq import odeint
+import torchsde
+from tqdm.auto import trange
+
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+    """Constructs the noise schedule of Karras et al. (2022)."""
+    ramp = torch.linspace(0, 1, n)
+    min_inv_rho = sigma_min ** (1 / rho)
+    max_inv_rho = sigma_max ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+    return append_zero(sigmas).to(device)
+
+
+def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
+    """Constructs an exponential noise schedule."""
+    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
+    return append_zero(sigmas)
+
+
+def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
+    """Constructs a polynomial in log sigma noise schedule."""
+    ramp = torch.linspace(1, 0, n, device=device) ** rho
+    sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
+    return append_zero(sigmas)
+
+
+def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+    """Constructs a continuous VP noise schedule."""
+    t = torch.linspace(1, eps_s, n, device=device)
+    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+    return append_zero(sigmas)
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def to_d(x, sigma, denoised):
+    """Converts a denoiser output to a Karras ODE derivative."""
+    return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.):
+    """Calculates the noise level (sigma_down) to step down to and the amount
+    of noise to add (sigma_up) when doing an ancestral sampling step."""
+    if not eta:
+        return sigma_to, 0.
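+    # split sigma_to's variance so that sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2;
+    # eta scales the stochastic share (eta=0 is fully deterministic, eta=1 fully ancestral)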
+    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
+    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+    return sigma_down, sigma_up
+
+
+def default_noise_sampler(x):
+    return lambda sigma, sigma_next: torch.randn_like(x)
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get('w0', torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2 ** 63 - 1, []).item()
+        self.batched = True
+        try:
+            assert len(seed) == x.shape[0]
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will
+            use one BrownianTree per batch item, each with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
+    """
+
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+        self.transform = transform
+        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+        self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+    def __call__(self, sigma, sigma_next):
+        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
+@torch.no_grad()
+def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
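+        # gamma > 0 temporarily lifts the noise level to sigma_hat ("churn"), adding fresh
+        # noise before the Euler step, but only while sigma lies inside [s_tmin, s_tmax]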
+        eps = torch.randn_like(x) * s_noise
+        sigma_hat = sigmas[i] * (gamma + 1)
+        if gamma > 0:
+            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+        denoised = model(x, sigma_hat * s_in, **extra_args)
+        d = to_d(x, sigma_hat, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+        dt = sigmas[i + 1] - sigma_hat
+        # Euler method
+        x = x + (d * dt).to(x.dtype)
+    return x
+
+
+@torch.no_grad()
+def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with Euler method steps."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        d = to_d(x, sigmas[i], denoised)
+        # Euler method
+        dt = sigma_down - sigmas[i]
+        x = x + (d * dt).to(x.dtype)
+        if sigmas[i + 1] > 0:
+            x = x + (noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up).to(x.dtype)
+    return x
+
+
+def linear_multistep_coeff(order, t, i, j):
+    if order - 1 > i:
+        raise ValueError(f'Order {order} too high for step {i}')
+    def fn(tau):
+        prod = 1.
+        for k in range(order):
+            if j == k:
+                continue
+            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+        return prod
+    return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
+
+
+@torch.no_grad()
+def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    v = torch.randint_like(x, 2) * 2 - 1
+    fevals = 0
+    def ode_fn(sigma, x):
+        nonlocal fevals
+        with torch.enable_grad():
+            x = x[0].detach().requires_grad_()
+            denoised = model(x, sigma * s_in, **extra_args)
+            d = to_d(x, sigma, denoised)
+            fevals += 1
+            grad = torch.autograd.grad((d * v).sum(), x)[0]
+            d_ll = (v * grad).flatten(1).sum(1)
+        return d.detach(), d_ll
+    x_min = x, x.new_zeros([x.shape[0]])
+    t = x.new_tensor([sigma_min, sigma_max])
+    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
+    latent, delta_ll = sol[0][-1], sol[1][-1]
+    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
+    return ll_prior + delta_ll, {'fevals': fevals}
+
+
+class PIDStepSizeController:
+    """A PID controller for ODE adaptive step size control."""
+    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
+        self.h = h
+        self.b1 = (pcoeff + icoeff + dcoeff) / order
+        self.b2 = -(pcoeff + 2 * dcoeff) / order
+        self.b3 = dcoeff / order
+        self.accept_safety = accept_safety
+        self.eps = eps
+        self.errs = []
+
+    def limiter(self, x):
+        return 1 + math.atan(x - 1)
+
+    def propose_step(self, error):
+        inv_error = 1 / (float(error) + self.eps)
+        if not self.errs:
+            self.errs = [inv_error, inv_error, inv_error]
+        self.errs[0] = inv_error
+        factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
+        factor = self.limiter(factor)
+        accept = factor >= self.accept_safety
+        if accept:
+            self.errs[2] = self.errs[1]
+            self.errs[1] = self.errs[0]
+        self.h *= factor
+        return accept
+
+
+class DPMSolver(nn.Module):
+    """DPM-Solver.
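+
+    Provides first- to third-order solver steps plus fixed-budget (fast) and adaptive drivers.
+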
+    See https://arxiv.org/abs/2206.00927."""
+
+    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
+        super().__init__()
+        self.model = model
+        self.extra_args = {} if extra_args is None else extra_args
+        self.eps_callback = eps_callback
+        self.info_callback = info_callback
+
+    def t(self, sigma):
+        return -sigma.log()
+
+    def sigma(self, t):
+        return t.neg().exp()
+
+    def eps(self, eps_cache, key, x, t, *args, **kwargs):
+        if key in eps_cache:
+            return eps_cache[key], eps_cache
+        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
+        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
+        if self.eps_callback is not None:
+            self.eps_callback()
+        return eps, {key: eps, **eps_cache}
+
+    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
+        eps_cache = {} if eps_cache is None else eps_cache
+        h = t_next - t
+        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+        x_1 = x - self.sigma(t_next) * h.expm1() * eps
+        return x_1, eps_cache
+
+    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
+        eps_cache = {} if eps_cache is None else eps_cache
+        h = t_next - t
+        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+        s1 = t + r1 * h
+        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
+        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
+        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
+        return x_2, eps_cache
+
+    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
+        eps_cache = {} if eps_cache is None else eps_cache
+        h = t_next - t
+        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+        s1 = t + r1 * h
+        s2 = t + r2 * h
+        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
+        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
+        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
+        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
+        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
+        return x_3, eps_cache
+
+    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
+        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+        if not t_end > t_start and eta:
+            raise ValueError('eta must be 0 for reverse sampling')
+
+        m = math.floor(nfe / 3) + 1
+        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
+
+        if nfe % 3 == 0:
+            orders = [3] * (m - 2) + [2, 1]
+        else:
+            orders = [3] * (m - 1) + [nfe % 3]
+
+        for i in range(len(orders)):
+            eps_cache = {}
+            t, t_next = ts[i], ts[i + 1]
+            if eta:
+                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
+                t_next_ = torch.minimum(t_end, self.t(sd))
+                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
+            else:
+                t_next_, su = t_next, 0.
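+
+            # epsilon at (x, t) is computed once here and shared with the solver step
+            # below through eps_cache, so the denoised preview costs no extra model call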
+            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+            denoised = x - self.sigma(t) * eps
+            if self.info_callback is not None:
+                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
+
+            if orders[i] == 1:
+                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
+            elif orders[i] == 2:
+                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
+            else:
+                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
+
+            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
+
+        return x
+
+    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
+        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+        if order not in {2, 3}:
+            raise ValueError('order should be 2 or 3')
+        forward = t_end > t_start
+        if not forward and eta:
+            raise ValueError('eta must be 0 for reverse sampling')
+        h_init = abs(h_init) * (1 if forward else -1)
+        atol = torch.tensor(atol)
+        rtol = torch.tensor(rtol)
+        s = t_start
+        x_prev = x
+        accept = True
+        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
+        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
+
+        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
+            eps_cache = {}
+            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
+            if eta:
+                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
+                t_ = torch.minimum(t_end, self.t(sd))
+                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
+            else:
+                t_, su = t, 0.
+
+            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
+            denoised = x - self.sigma(s) * eps
+
+            if order == 2:
+                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
+                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
+            else:
+                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
+                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
+            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
+            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
+            accept = pid.propose_step(error)
+            if accept:
+                x_prev = x_low
+                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
+                s = t
+                info['n_accept'] += 1
+            else:
+                info['n_reject'] += 1
+            info['nfe'] += order
+            info['steps'] += 1
+
+            if self.info_callback is not None:
+                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
+
+        return x, info
+
+
+@torch.no_grad()
+def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigma_down == 0:
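+            # final step: sigma_down == 0 would make t_fn blow up, so fall back to a
+            # plain Euler step straight to zero noise instead of the exponential update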
+            # Euler method
+            d = to_d(x, sigmas[i], denoised)
+            dt = sigma_down - sigmas[i]
+            x = x + d * dt
+        else:
+            # DPM-Solver++(2S)
+            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
+            r = 1 / 2
+            h = t_next - t
+            s = t + r * h
+            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
+            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
+        # Noise addition
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
+    """DPM-Solver++ (stochastic)."""
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Euler method
+            d = to_d(x, sigmas[i], denoised)
+            dt = sigmas[i + 1] - sigmas[i]
+            x = x + d * dt
+        else:
+            # DPM-Solver++
+            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+            h = t_next - t
+            s = t + h * r
+            fac = 1 / (2 * r)
+
+            # Step 1
+            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
+            s_ = t_fn(sd)
+            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
+            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
+            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+
+            # Step 2
+            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
+            t_next_ = t_fn(sd)
+            denoised_d = (1 - fac) * denoised + fac * denoised_2
+            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
+            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
+    """DPM-Solver++(2M)."""
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+    old_denoised = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+        h = t_next - t
+        if old_denoised is None or sigmas[i + 1] == 0:
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+        else:
+            h_last = t - t_fn(sigmas[i - 1])
+            r = h_last / h
+            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+        old_denoised = denoised
+    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
+    """DPM-Solver++(2M) SDE."""
+
+    if solver_type not in {'heun', 'midpoint'}:
+        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
+
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
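+    # Brownian tree noise ties each draw to its (sigma, sigma_next) interval, so a
+    # fixed seed yields consistent noise even when the number of steps changes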
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+
+    old_denoised = None
+    h_last = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            # DPM-Solver++(2M) SDE
+            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+            h = s - t
+            eta_h = eta * h
+
+            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+
+            if old_denoised is not None:
+                r = h_last / h
+                if solver_type == 'heun':
+                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
+                elif solver_type == 'midpoint':
+                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+
+            if eta:
+                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+
+        old_denoised = denoised
+        h_last = h
+    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """DPM-Solver++(3M) SDE."""
+
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+
+    denoised_1, denoised_2 = None, None
+    h_1, h_2 = None, None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+            h = s - t
+            h_eta = h * (eta + 1)
+
+            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
+
+            if h_2 is not None:
+                r0 = h_1 / h
+                r1 = h_2 / h
+                d1_0 = (denoised - denoised_1) / r0
+                d1_1 = (denoised_1 - denoised_2) / r1
+                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
+                d2 = (d1_0 - d1_1) / (r0 + r1)
+                phi_2 = h_eta.neg().expm1() / h_eta + 1
+                phi_3 = phi_2 / h_eta - 0.5
+                x = x + phi_2 * d1 - phi_3 * d2
+            elif h_1 is not None:
+                r = h_1 / h
+                d = (denoised - denoised_1) / r
+                phi_2 = h_eta.neg().expm1() / h_eta + 1
+                x = x + phi_2 * d
+
+            if eta:
+                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
+
+        denoised_1, denoised_2 = denoised, denoised_1
+        h_1, h_2 = h, h_1
+    return x
diff --git a/modules/pulid/pulid_sdxl.py b/modules/pulid/pulid_sdxl.py
index 0651efab4..de650b839 100644
--- a/modules/pulid/pulid_sdxl.py
+++ b/modules/pulid/pulid_sdxl.py
@@ -19,13 +19,12 @@ from insightface.app import FaceAnalysis
 from eva_clip import create_model_and_transforms
 from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from encoders_transformer import IDFormer
-from pulid_utils import sample_dpmpp_2m, sample_dpmpp_sde
 from attention_processor import AttnProcessor2_0 as AttnProcessor
 from attention_processor import IDAttnProcessor2_0 as IDAttnProcessor
 
 
 class StableDiffusionXLPuLIDPipeline:
-    def __init__(self, pipe: StableDiffusionXLPipeline, device: torch.device, sampler='dpmpp_sde', cache_dir=None):
+    def __init__(self, pipe: StableDiffusionXLPipeline, device: torch.device, sampler=None, cache_dir=None):
         super().__init__()
         self.device = device
         self.pipe = pipe
@@ -90,12 +89,16 @@ class StableDiffusionXLPuLIDPipeline:
         self.log_sigmas = self.sigmas.log()
         self.sigma_data = 1.0
 
+        if sampler is not None:
+            self.sampler = sampler
+        """
         if sampler == 'dpmpp_sde':
             self.sampler = sample_dpmpp_sde
         elif sampler == 'dpmpp_2m':
             self.sampler = sample_dpmpp_2m
         else:
             raise NotImplementedError(f'sampler {sampler} not implemented')
+        """
 
     @property
     def sigma_min(self):
diff --git a/modules/pulid/pulid_utils.py b/modules/pulid/pulid_utils.py
index 1a8d3ff06..fd7338b5e 100644
--- a/modules/pulid/pulid_utils.py
+++ b/modules/pulid/pulid_utils.py
@@ -6,10 +6,7 @@ import random
 import cv2
 import numpy as np
 import torch
-import torch.nn.functional as F
-import torchsde
 from torchvision.utils import make_grid
-from tqdm.auto import trange
 from transformers import PretrainedConfig
 
 
@@ -21,10 +18,6 @@ def seed_everything(seed):
         torch.cuda.manual_seed_all(seed)
 
 
-def is_torch2_available():
-    return hasattr(F, "scaled_dot_product_attention")
-
-
 def instantiate_from_config(config):
     if "target" not in config:
         if config == '__is_first_stage__' or config == "__is_unconditional__":
@@ -166,172 +159,3 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
     if len(result) == 1:
         result = result[0]
     return result
-
-
-# We didn't find a correct configuration to make the diffusers scheduler align with dpm++2m (karras) in ComfyUI,
-# so we copied the ComfyUI code directly.
-
-
-def append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    expanded = x[(...,) + (None,) * dims_to_append]
-    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
-    # https://github.com/pytorch/pytorch/issues/84364
-    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
-
-
-def to_d(x, sigma, denoised):
-    """Converts a denoiser output to a Karras ODE derivative."""
-    return (x - denoised) / append_dims(sigma, x.ndim)
-
-
-def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
-    """Calculates the noise level (sigma_down) to step down to and the amount
-    of noise to add (sigma_up) when doing an ancestral sampling step."""
-    if not eta:
-        return sigma_to, 0.0
-    sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
-    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
-    return sigma_down, sigma_up
-
-
-class BatchedBrownianTree:
-    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
-
-    def __init__(self, x, t0, t1, seed=None, **kwargs):
-        self.cpu_tree = True
-        if "cpu" in kwargs:
-            self.cpu_tree = kwargs.pop("cpu")
-        t0, t1, self.sign = self.sort(t0, t1)
-        w0 = kwargs.get('w0', torch.zeros_like(x))
-        if seed is None:
-            seed = torch.randint(0, 2**63 - 1, []).item()
-        self.batched = True
-        try:
-            assert len(seed) == x.shape[0]
-            w0 = w0[0]
-        except TypeError:
-            seed = [seed]
-            self.batched = False
-        if self.cpu_tree:
-            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
-        else:
-            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
-
-    @staticmethod
-    def sort(a, b):
-        return (a, b, 1) if a < b else (b, a, -1)
-
-    def __call__(self, t0, t1):
-        t0, t1, sign = self.sort(t0, t1)
-        if self.cpu_tree:
-            w = torch.stack(
-                [tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]
-            ) * (self.sign * sign)
-        else:
-            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
-
-        return w if self.batched else w[0]
-
-
-class BrownianTreeNoiseSampler:
-    """A noise sampler backed by a torchsde.BrownianTree.
-
-    Args:
-        x (Tensor): The tensor whose shape, device and dtype to use to generate
-            random samples.
-        sigma_min (float): The low end of the valid interval.
-        sigma_max (float): The high end of the valid interval.
-        seed (int or List[int]): The random seed. If a list of seeds is
-            supplied instead of a single integer, then the noise sampler will
-            use one BrownianTree per batch item, each with its own seed.
-        transform (callable): A function that maps sigma to the sampler's
-            internal timestep.
- """ - - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): - self.transform = transform - t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) - self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) - - def __call__(self, sigma, sigma_next): - t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) - return self.tree(t0, t1) / (t1 - t0).abs().sqrt() - - -@torch.no_grad() -def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): - """DPM-Solver++(2M).""" - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - sigma_fn = lambda t: t.neg().exp() # pylint: disable=unnecessary-lambda-assignment - t_fn = lambda sigma: sigma.log().neg() # pylint: disable=unnecessary-lambda-assignment - old_denoised = None - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) - h = t_next - t - if old_denoised is None or sigmas[i + 1] == 0: - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised - else: - h_last = t - t_fn(sigmas[i - 1]) - r = h_last / h - denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised - x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d - old_denoised = denoised - return x - - -@torch.no_grad() -def sample_dpmpp_sde( - model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None, r=1 / 2 -): - """DPM-Solver++ (stochastic).""" - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - seed = extra_args.get("seed", None) - noise_sampler = ( - BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=False) - if noise_sampler is None - else noise_sampler - ) - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - sigma_fn = lambda t: t.neg().exp() # pylint: disable=unnecessary-lambda-assignment - t_fn = lambda sigma: sigma.log().neg() # pylint: disable=unnecessary-lambda-assignment - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - # Euler method - d = to_d(x, sigmas[i], denoised) - dt = sigmas[i + 1] - sigmas[i] - x = x + d * dt - else: - # DPM-Solver++ - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) - h = t_next - t - s = t + h * r - fac = 1 / (2 * r) - - # Step 1 - sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) - s_ = t_fn(sd) - x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised - x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su - denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) - - # Step 2 - sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) - t_next_ = t_fn(sd) - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d - x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su - return x diff --git a/modules/sd_models.py b/modules/sd_models.py index 3ca081d6c..9dd87cec1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -291,10 +291,13 @@ def 
 
 
 def set_accelerate_to_module(model):
-    for k in model._internal_dict.keys(): # pylint: disable=protected-access
-        component = getattr(model, k, None)
-        if isinstance(component, torch.nn.Module):
-            component.has_accelerate = True
+    if hasattr(model, "pipe"):
+        set_accelerate_to_module(model.pipe)
+    if hasattr(model, "_internal_dict"):
+        for k in model._internal_dict.keys(): # pylint: disable=protected-access
+            component = getattr(model, k, None)
+            if isinstance(component, torch.nn.Module):
+                component.has_accelerate = True
 
 
 def set_accelerate(sd_model):
@@ -397,6 +400,10 @@ def apply_balanced_offload(sd_model):
         return module
 
     def apply_balanced_offload_to_module(pipe):
+        if hasattr(pipe, "pipe"):
+            apply_balanced_offload_to_module(pipe.pipe)
+        if not hasattr(pipe, "_internal_dict"):
+            return
         for module_name in pipe._internal_dict.keys(): # pylint: disable=protected-access
             module = getattr(pipe, module_name, None)
             if isinstance(module, torch.nn.Module):
diff --git a/scripts/pulid_ext.py b/scripts/pulid_ext.py
index d31c18164..54189cc05 100644
--- a/scripts/pulid_ext.py
+++ b/scripts/pulid_ext.py
@@ -30,8 +30,6 @@ class Script(scripts.Script):
         install('insightface', 'insightface', ignore=False)
         install('albumentations==1.4.3', 'albumentations', ignore=False, reinstall=True)
         install('pydantic==1.10.15', 'pydantic', ignore=False, reinstall=True)
-        # if not installed('apex', reload=False, quiet=True):
-        #     install('apex', 'apex', ignore=False)
 
     def register(self): # register xyz grid elements
         def apply_field(field):
@@ -74,8 +72,8 @@
                 strength = gr.Slider(label = 'Strength', value = 0.8, minimum = 0, maximum = 1, step = 0.01)
                 zero = gr.Slider(label = 'Zero', value = 20, minimum = 0, maximum = 80, step = 1)
             with gr.Row():
-                sampler = gr.Dropdown(label="Sampler", choices=['dpmpp_sde', 'dpmpp_2m'], value='dpmpp_sde', visible=True)
-                ortho = gr.Dropdown(label="Ortho", choices=['off', 'v1', 'v2'], value='v2', visible=True)
+                sampler = gr.Dropdown(label="Sampler", value='dpmpp_sde', choices=['dpmpp_2m', 'dpmpp_2m_sde', 'dpmpp_2s_ancestral', 'dpmpp_3m_sde', 'dpmpp_sde', 'euler', 'euler_ancestral'])
+                ortho = gr.Dropdown(label="Ortho", choices=['off', 'v1', 'v2'], value='v2')
             with gr.Row():
                 files = gr.File(label='Input images', file_count='multiple', file_types=['image'], type='file', interactive=True, height=100)
             with gr.Row():
@@ -124,16 +122,17 @@
         strength = getattr(p, 'pulid_strength', strength)
         zero = getattr(p, 'pulid_zero', zero)
         ortho = getattr(p, 'pulid_ortho', ortho)
+        sampler = getattr(p, 'pulid_sampler', sampler)
+        sampler_fn = getattr(self.pulid.sampling, f'sample_{sampler}', None)
 
         if shared.sd_model_type == 'sdxl' and not hasattr(shared.sd_model, 'pipe'):
            try:
                 stdout = io.StringIO()
-                ctx = contextlib.nullcontext if debug else contextlib.redirect_stdout(stdout)
+                ctx = contextlib.nullcontext() if debug else contextlib.redirect_stdout(stdout)
                 with ctx:
                     shared.sd_model = self.pulid.StableDiffusionXLPuLIDPipeline(
                         pipe=shared.sd_model,
                         device=devices.device,
-                        sampler=sampler,
                         cache_dir=shared.opts.hfcache_dir,
                     )
                     shared.sd_model.no_recurse = True
@@ -146,6 +145,7 @@
                 errors.display(e, 'PuLID')
                 return None
 
+        shared.sd_model.sampler = sampler_fn
         shared.log.info(f'PuLID: class={shared.sd_model.__class__.__name__} strength={strength} zero={zero} ortho={ortho} sampler={sampler} images={[i.shape for i in images]}')
         self.pulid.attention.NUM_ZERO = zero
         self.pulid.attention.ORTHO = ortho == 'v1'
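
Net effect of the script changes: the sampler dropdown value is resolved by name against the new sampling module instead of being baked into the pipeline constructor. A minimal sketch of that lookup, under stated assumptions (an already-wrapped `StableDiffusionXLPuLIDPipeline` in `shared.sd_model`, sampler names matching the dropdown above; illustrative only, not part of the diff):

    import modules.pulid as pulid   # package __init__ re-exports pulid_sampling as `sampling`
    from modules import shared

    # any sample_* function in pulid_sampling.py with the k-diffusion signature
    # (model, x, sigmas, extra_args=None, callback=None, disable=None, ...) qualifies
    sampler_fn = getattr(pulid.sampling, 'sample_dpmpp_2m_sde', None)
    if sampler_fn is None:
        sampler_fn = pulid.sampling.sample_dpmpp_sde   # previous hard-coded default
    shared.sd_model.sampler = sampler_fn               # consumed by the pipeline during denoising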