# mirror of https://github.com/vladmandic/automatic
import math
|
|
from typing import Literal
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
try:
|
|
import scipy.stats
|
|
_scipy_available = True
|
|
except ImportError:
|
|
_scipy_available = False
|
|
|
|
def betas_for_alpha_bar(
|
|
num_diffusion_timesteps: int,
|
|
max_beta: float = 0.999,
|
|
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
|
|
dtype: torch.dtype = torch.float32,
|
|
) -> torch.Tensor:
|
|
if alpha_transform_type == "cosine":
|
|
def alpha_bar_fn(t):
|
|
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
|
elif alpha_transform_type == "laplace":
|
|
def alpha_bar_fn(t):
|
|
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
|
snr = math.exp(lmb)
|
|
return math.sqrt(snr / (1 + snr))
|
|
elif alpha_transform_type == "exp":
|
|
def alpha_bar_fn(t):
|
|
return math.exp(t * -12.0)
|
|
else:
|
|
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
|
|
|
betas = []
|
|
for i in range(num_diffusion_timesteps):
|
|
t1 = i / num_diffusion_timesteps
|
|
t2 = (i + 1) / num_diffusion_timesteps
|
|
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
|
return torch.tensor(betas, dtype=dtype)
|
|
|
|
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
|
alphas = 1.0 - betas
|
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
|
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
|
alphas_bar = alphas_bar_sqrt**2
|
|
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
|
alphas = torch.cat([alphas_bar[0:1], alphas])
|
|
betas = 1 - alphas
|
|
return betas
|
|
|
|
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu", dtype: torch.dtype = torch.float32):
|
|
ramp = np.linspace(0, 1, n)
|
|
min_inv_rho = sigma_min ** (1 / rho)
|
|
max_inv_rho = sigma_max ** (1 / rho)
|
|
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
|
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
|
|
|
|
def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
    """Geometric (log-linear) noise schedule from sigma_max down to sigma_min.

    Returns a 1-D tensor of length ``n`` on ``device`` with ``dtype``.
    """
    log_max = math.log(sigma_max)
    log_min = math.log(sigma_min)
    schedule = np.exp(np.linspace(log_max, log_min, n))
    return torch.from_numpy(schedule).to(dtype=dtype, device=device)
|
|
|
|
def get_sigmas_beta(n, sigma_min, sigma_max, alpha=0.6, beta=0.6, device="cpu", dtype: torch.dtype = torch.float32):
|
|
if not _scipy_available:
|
|
raise ImportError("scipy is required for beta sigmas")
|
|
sigmas = np.array(
|
|
[
|
|
sigma_min + (ppf * (sigma_max - sigma_min))
|
|
for ppf in [
|
|
scipy.stats.beta.ppf(timestep, alpha, beta)
|
|
for timestep in 1 - np.linspace(0, 1, n)
|
|
]
|
|
]
|
|
)
|
|
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
|
|
|
|
def get_sigmas_flow(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
|
|
# Linear flow sigmas
|
|
sigmas = np.linspace(sigma_max, sigma_min, n)
|
|
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
|
|
|
|
def apply_shift(sigmas, shift):
|
|
return shift * sigmas / (1 + (shift - 1) * sigmas)
|
|
|
|
def get_dynamic_shift(mu, base_shift, max_shift, base_seq_len, max_seq_len):
|
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
|
b = base_shift - m * base_seq_len
|
|
return m * mu + b
|
|
|
|
def index_for_timestep(timestep, timesteps):
    """Return the index in ``timesteps`` whose value is closest to ``timestep``.

    Both arguments may be torch tensors (any device) or plain scalars/sequences;
    everything is converted to numpy so the argmin is device-agnostic.
    """
    def _as_numpy(value):
        # Detach + CPU transfer makes GPU/grad tensors safe to convert.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.array(value)

    schedule = _as_numpy(timesteps)
    target = _as_numpy(timestep)
    # Nearest match by absolute difference for numerical stability.
    return int(np.abs(schedule - target).argmin())
|
|
|
|
def add_noise_to_sample(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    sigmas: torch.Tensor,
    timestep: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Noise a clean sample for the current step: x = x0 + sigma * eps.

    The sigma is looked up by finding the schedule index nearest to
    ``timestep`` and cast to the sample dtype before mixing.
    """
    idx = index_for_timestep(timestep, timesteps)
    sigma = sigmas[idx].to(original_samples.dtype)
    return original_samples + sigma * noise
|