Merge branch 'main' into model_storage

pull/186/head
kabachuha 2023-07-02 17:37:43 +03:00
commit 412909b8eb
12 changed files with 1874 additions and 81 deletions

View File

@@ -67,7 +67,7 @@ def t2v_api(_, app: FastAPI):
return JSONResponse(content={"version": get_t2v_version()})
@app.post("/t2v/run")
async def t2v_run(prompt: str, n_prompt: Union[str, None] = None, steps: Union[int, None] = None, frames: Union[int, None] = None, seed: Union[int, None] = None, \
async def t2v_run(prompt: str, n_prompt: Union[str, None] = None, sampler: Union[str, None] = None, steps: Union[int, None] = None, frames: Union[int, None] = None, seed: Union[int, None] = None, \
cfg_scale: Union[int, None] = None, width: Union[int, None] = None, height: Union[int, None] = None, eta: Union[float, None] = None, batch_count: Union[int, None] = None, \
do_vid2vid:bool = False, vid2vid_input: Union[UploadFile, None] = None,strength: Union[float, None] = None,vid2vid_startFrame: Union[int, None] = None, \
inpainting_image: Union[UploadFile, None] = None, inpainting_frames: Union[int, None] = None, inpainting_weights: Union[str, None] = None, \
@@ -131,6 +131,7 @@ def t2v_api(_, app: FastAPI):
d.prompt,#prompt
d.n_prompt,#n_prompt
d.sampler,#sampler
d.steps,#steps
d.frames,#frames
d.seed,#seed
@@ -143,6 +144,7 @@ def t2v_api(_, app: FastAPI):
# The same, but for vid2vid. Will deduplicate later
d.prompt,#prompt
d.n_prompt,#n_prompt
d.sampler,#sampler
d.steps,#steps
d.frames,#frames
d.seed,#seed
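
For context, the new sampler argument can be exercised end to end with a plain HTTP call. A minimal sketch (the host/port and the use of query parameters are assumptions; the field names come from the signature above):

import requests

# Hypothetical smoke test for the extended /t2v/run endpoint; host/port are
# assumptions, the parameter names come from the new signature above.
resp = requests.post(
    "http://127.0.0.1:7860/t2v/run",
    params={
        "prompt": "a corgi surfing a wave",
        "sampler": "UniPC",  # new in this change
        "steps": 30,
        "frames": 24,
        "seed": -1,
        "cfg_scale": 17,
    },
)
print(resp.json())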

View File

@@ -146,7 +146,7 @@ def process_modelscope(args_dict):
state.job_count = args.batch_count
for batch in pbar:
state.job_no = batch + 1
state.job_no = batch
if state.skipped:
state.skipped = False
@@ -207,7 +207,7 @@ def process_modelscope(args_dict):
args.strength = 1
samples, _ = pipe.infer(args.prompt, args.n_prompt, args.steps, args.frames, args.seed + batch if args.seed != -1 else -1, args.cfg_scale,
args.width, args.height, args.eta, cpu_vae, device, latents, skip_steps=skip_steps, mask=mask)
args.width, args.height, args.eta, cpu_vae, device, latents, strength=args.strength, skip_steps=skip_steps, mask=mask, is_vid2vid=args.do_vid2vid, sampler=args.sampler)
if batch > 0:
outdir_current = os.path.join(get_outdir(), f"{init_timestring}_{batch}")

View File

@@ -23,16 +23,19 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from os import path as osp
from modules.shared import opts
from functools import partial
from tqdm import tqdm
from modules.prompt_parser import reconstruct_cond_batch
from modules.shared import state
from modules.sd_samplers_common import InterruptedException
from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
from ldm.modules.diffusionmodules.util import make_beta_schedule
__all__ = ['UNetSD']
@@ -112,7 +115,8 @@ class UNetSD(nn.Module):
use_checkpoint=False,
use_image_dataset=False,
use_fps_condition=False,
use_sim_mask=False):
use_sim_mask=False,
parameterization="eps"):
embed_dim = dim * 4
num_heads = num_heads if num_heads else dim // 32
super(UNetSD, self).__init__()
@@ -135,6 +139,8 @@ class UNetSD(nn.Module):
self.use_image_dataset = use_image_dataset
self.use_fps_condition = use_fps_condition
self.use_sim_mask = use_sim_mask
self.parameterization = parameterization
self.v_posterior = 0
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
@@ -319,6 +325,64 @@ class UNetSD(nn.Module):
# zero out the last layer params
nn.init.zeros_(self.out[-1].weight)
# Taken from DDPM
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
def forward(
self,
x,
@@ -488,7 +552,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(x.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
if getattr(cmd_opts, "force_enable_xformers", False) or (getattr(cmd_opts, "xformers", False) and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
import xformers
out = xformers.ops.memory_efficient_attention(
@@ -1266,8 +1330,8 @@ class GaussianDiffusion(object):
r"""Distribution of p(x_{t-1} | x_t).
"""
# predict distribution
if guide_scale is None:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
if guide_scale is None or guide_scale == 1:
out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
else:
# classifier-free guidance
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
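
As a sanity check on register_schedule above, the registered buffers reduce to a few lines of NumPy. A standalone sketch using the default linear endpoints (make_beta_schedule's 'linear' interpolates in sqrt-space) and v_posterior = 0:

import numpy as np

# Mirror of the quantities register_schedule() derives from betas (a sketch).
betas = np.linspace(1e-4 ** 0.5, 2e-2 ** 0.5, 1000) ** 2  # 'linear' schedule
alphas_cumprod = np.cumprod(1. - betas)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

# posterior q(x_{t-1} | x_t, x_0) variance with v_posterior = 0, as above
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
assert posterior_variance.shape == betas.shape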

View File

@@ -19,8 +19,11 @@ import torch.cuda.amp as amp
from einops import rearrange
import cv2
from modelscope.t2v_model import UNetSD, AutoencoderKL, GaussianDiffusion, beta_schedule
from modules import devices
from modules import devices, shared
from modules import prompt_parser
from samplers.uni_pc.sampler import UniPCSampler
from samplers.samplers_common import Txt2VideoSampler
from samplers.samplers_common import available_samplers
__all__ = ['TextToVideoSynthesis']
@@ -86,6 +89,7 @@ class TextToVideoSynthesis():
num_res_blocks=cfg['unet_res_blocks'],
attn_scales=cfg['unet_attn_scales'],
dropout=cfg['unet_dropout'],
parameterization=cfg['mean_type'],
temporal_attention=cfg['temporal_attention'])
self.sd_model.load_state_dict(
torch.load(
@@ -97,20 +101,17 @@ class TextToVideoSynthesis():
self.sd_model.eval()
if not devices.has_mps() or torch.cuda.is_available() == True:
self.sd_model.half()
# Initialize diffusion
betas = beta_schedule(
'linear_sd',
cfg['num_timesteps'],
init_beta=0.00085,
last_beta=0.0120)
self.diffusion = GaussianDiffusion(
betas=betas,
mean_type=cfg['mean_type'],
var_type=cfg['var_type'],
loss_type=cfg['loss_type'],
rescale_timesteps=False)
self.sd_model.register_schedule(given_betas=betas.numpy())
self.diffusion = Txt2VideoSampler(self.sd_model, shared.device, betas=betas)
# Initialize autoencoder
ddconfig = {
'double_z': True,
@@ -192,7 +193,26 @@ class TextToVideoSynthesis():
return out
# @torch.compile()
def infer(self, prompt, n_prompt, steps, frames, seed, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device=torch.device('cpu'), latents=None, skip_steps=0,strength=0,mask=None):
def infer(
self,
prompt,
n_prompt,
steps,
frames,
seed,
scale,
width=256,
height=256,
eta=0.0,
cpu_vae='GPU (half precision)',
device=torch.device('cpu'),
latents=None,
skip_steps=0,
strength=0,
mask=None,
is_vid2vid=False,
sampler=available_samplers[0].name
):
vars = locals()
vars.pop('self')
vars.pop('latents')
@@ -227,6 +247,7 @@ class TextToVideoSynthesis():
self.device = device
self.clip_encoder.to(self.device)
self.clip_encoder.device = self.device
steps = steps - skip_steps
c, uc = self.preprocess(prompt, n_prompt, steps)
if self.keep_in_vram != "All":
self.clip_encoder.to("cpu")
@@ -236,36 +257,37 @@ class TextToVideoSynthesis():
latents=latents.half() if 'half precision' in cpu_vae and latents is not None else latents
# synthesis
strength = None if strength == 0.0 else strength
strength = None if (strength == 0.0 and not is_vid2vid) else strength
with torch.no_grad():
num_sample = 1
max_frames = frames
latent_h, latent_w = height // 8, width // 8
self.sd_model.to(self.device)
if latents == None:
self.noise_gen.manual_seed(seed)
latents = torch.randn(num_sample, 4, max_frames, latent_h,
latent_w, generator=self.noise_gen).to(
self.device)
else:
latents.to(self.device)
print("latents", latents.shape, torch.mean(
latents), torch.std(latents))
channels = 4
max_frames= frames
latents, noise, shape = self.diffusion.get_noise(
num_sample,
channels,
max_frames,
height,
width,
seed=seed,
latents=latents
)
with amp.autocast(enabled=True):
self.sd_model.to(self.device)
x0 = self.diffusion.ddim_sample_loop(
noise=latents, # shape: b c f h w
model=self.sd_model,
c=c,
uc=uc,
num_sample=1,
guide_scale=scale,
ddim_timesteps=steps,
self.diffusion.get_sampler(sampler, return_sampler=False)
x0 = self.diffusion.sample_loop(
steps=steps,
strength=strength,
eta=eta,
conditioning=c,
unconditional_conditioning=uc,
batch_size=num_sample,
guidance_scale=scale,
latents=latents,
shape=shape,
noise=noise,
is_vid2vid=is_vid2vid,
sampler_name=sampler,
mask=mask
)
@@ -435,3 +457,4 @@ def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
images = [(image.numpy() * 255).astype('uint8')
for image in images] # f h w c
return images
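
A hedged sketch of driving the refactored infer() directly (the initialized pipeline instance is an assumption; the argument names come from the new signature above):

import torch

# pipe is an already-loaded TextToVideoSynthesis instance (assumption).
samples, _ = pipe.infer(
    "a corgi surfing a wave",  # prompt
    "",                        # n_prompt
    30, 24, 42, 17,            # steps, frames, seed, cfg scale
    width=256, height=256, eta=0.0,
    device=torch.device("cuda"),
    sampler="UniPC",           # routed through Txt2VideoSampler
)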

View File

@@ -0,0 +1,300 @@
import torch
from modelscope.t2v_model import _i
from t2v_helpers.general_utils import reconstruct_conds
class GaussianDiffusion(object):
r""" Diffusion Model for DDIM.
"Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
See https://arxiv.org/abs/2010.02502
"""
def __init__(self,
model,
betas,
mean_type='eps',
var_type='learned_range',
loss_type='mse',
epsilon=1e-12,
rescale_timesteps=False,
**kwargs):
# check input
self.check_input_vars(betas, mean_type, var_type, loss_type)
self.model = model
self.betas = betas
self.num_timesteps = len(betas)
self.mean_type = mean_type
self.var_type = var_type
self.loss_type = loss_type
self.epsilon = epsilon
self.rescale_timesteps = rescale_timesteps
# alphas
alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:],alphas.new_zeros([1])])
# q(x_t | x_{t-1})
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
# q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
def check_input_vars(self, betas, mean_type, var_type, loss_type):
mean_types = ['x0', 'x_{t-1}', 'eps']
var_types = ['learned', 'learned_range', 'fixed_large', 'fixed_small']
loss_types = ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
if not isinstance(betas, torch.DoubleTensor):
betas = torch.tensor(betas, dtype=torch.float64)
assert min(betas) > 0 and max(betas) <= 1
assert mean_type in mean_types
assert var_type in var_types
assert loss_type in loss_types
def validate_model_kwargs(self, model_kwargs):
"""
Use the original implementation of passing model kwargs to the model,
e.g. model_kwargs=[{'y': c_i}, {'y': uc_i}]
"""
if len(model_kwargs) > 0:
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
def get_time_steps(self, ddim_timesteps, batch_size=1, step=None):
b = batch_size
# Get the full timestep range
arange_steps = (1 + torch.arange(0, self.num_timesteps, ddim_timesteps))
steps = arange_steps.clamp(0, self.num_timesteps - 1)
timesteps = steps.flip(0).to(self.model.device)
if step is not None:
# Get the current timestep during a sample loop
timesteps = torch.full((b, ), timesteps[step], dtype=torch.long, device=self.model.device)
return timesteps
def add_noise(self, xt, noise, t):
noisy_sample = self.sqrt_alphas_cumprod[t.cpu()].to(self.model.device) * \
xt + noise * self.sqrt_one_minus_alphas_cumprod[t.cpu()].to(self.model.device)
return noisy_sample
def get_dim(self, y_out):
is_fixed = self.var_type.startswith('fixed')
return y_out.size(1) if is_fixed else y_out.size(1) // 2
def fixed_small_variance(self, xt, t):
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return var, log_var
def mean_x0(self, xt, t, x_out):
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * x_out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
return x0, mu
def restrict_range_x0(self, percentile, x0, clamp=None):
if clamp is None:
assert percentile > 0 and percentile <= 1 # e.g., 0.995
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
s = s.clamp_(1.0).view(-1, 1, 1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s
else:
x0 = x0.clamp(-clamp, clamp)
return x0
def is_unconditional(self, guide_scale):
return guide_scale is None or guide_scale == 1
def do_classifier_guidance(self, y_out, u_out, guidance_scale):
"""
y_out: Condition
u_out: Unconditional
"""
dim = self.get_dim(y_out)
a = u_out[:, :dim]
b = guidance_scale * (y_out[:, :dim] - u_out[:, :dim])
c = y_out[:, dim:]
out = torch.cat([a + b, c], dim=1)
return out
def p_mean_variance(self,
xt,
t,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
conditioning=None,
unconditional_conditioning=None,
only_x0=True,
**kwargs):
r"""Distribution of p(x_{t-1} | x_t)."""
# predict distribution
if self.is_unconditional(guide_scale):
out = self.model(xt, self._scale_timesteps(t), conditioning)
else:
# classifier-free guidance
if model_kwargs != {}:
self.validate_model_kwargs(model_kwargs)
conditioning = model_kwargs[0]
unconditional_conditioning = model_kwargs[1]
y_out = self.model(xt, self._scale_timesteps(t), conditioning)
u_out = self.model(xt, self._scale_timesteps(t), unconditional_conditioning)
out = self.do_classifier_guidance(y_out, u_out, guide_scale)
# compute variance
if self.var_type == 'fixed_small':
var, log_var = self.fixed_small_variance(xt, t)
# compute mean and x0
if self.mean_type == 'eps':
x0, mu = self.mean_x0(xt, t, out)
# restrict the range of x0
if percentile is not None:
x0 = self.restrict_range_x0(percentile, x0)
elif clamp is not None:
x0 = self.restrict_range_x0(percentile, x0, clamp=clamp)
if only_x0:
return x0
else:
return mu, var, log_var, x0
def q_posterior_mean_variance(self, x0, xt, t):
r"""Distribution of q(x_{t-1} | x_t, x_0).
"""
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
self.posterior_mean_coef2, t, xt) * xt
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return mu, var, log_var
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t
def get_eps(self, xt, x0, t, alpha, condition_fn, model_kwargs={}):
# x0 -> eps
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
if condition_fn is not None:
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
return eps, x0
@torch.no_grad()
def sample(self,
x_T=None,
S=5,
shape=None,
conditioning=None,
unconditional_conditioning=None,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
unconditional_guidance_scale=None,
eta=0.0,
callback=None,
mask=None,
**kwargs):
r"""Sample from p(x_{t-1} | x_t) using DDIM.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
# Shape must exist to sample
if shape is None and x_T is None:
assert "Shape must exists to sample from noise"
# Assign variables for sampling
steps = S
stride = self.num_timesteps // steps
guide_scale = unconditional_guidance_scale
original_latents = None
if x_T is None:
xt = torch.randn(shape, device=self.model.device)
else:
xt = x_T.clone()
original_latents = xt
timesteps = self.get_time_steps(stride, xt.shape[0])
for step in range(0, steps):
c, uc = reconstruct_conds(conditioning, unconditional_conditioning, step)
t = self.get_time_steps(stride, xt.shape[0], step=step)
# predict distribution of p(x_{t-1} | x_t)
x0 = self.p_mean_variance(
xt,
t,
model_kwargs,
clamp,
percentile,
guide_scale,
conditioning=c,
unconditional_conditioning=uc,
**kwargs
)
alphas = _i(self.alphas_cumprod, t, xt)
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
eps, x0 = self.get_eps(xt, x0, t, alphas, condition_fn)
a = (1 - alphas_prev) / (1 - alphas)
b = (1 - alphas / alphas_prev)
sigmas = eta * torch.sqrt(a * b)
# random sample
noise = torch.randn_like(xt)
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
nonzero_mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + nonzero_mask * sigmas * noise
xt = xt_1
if hasattr(self, 'inpaint_masking') and mask is not None:
add_noise_args = {
"xt": xt,
"noise": torch.randn_like(xt),
"t": timesteps[step]
}
xt = self.inpaint_masking(xt, step, steps, mask, self.add_noise, add_noise_args)
if callback is not None:
callback(step)
return xt
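
The loop body above is the standard DDIM update. Isolated, the per-step arithmetic looks like this (a sketch, with alphas/alphas_prev the cumulative products selected for t and t - stride, and the t != 0 mask omitted):

import torch

def ddim_step(x0, eps, alphas, alphas_prev, eta, noise):
    # sigma from eta, exactly as computed in sample() above
    sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
    return torch.sqrt(alphas_prev) * x0 + direction + sigmas * noise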

View File

@@ -0,0 +1,306 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from modules import shared
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
from t2v_helpers.general_utils import reconstruct_conds
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, e.g. as encoded tokens
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
B, C, F, H, W = shape
size = (B, C, F, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = shared.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
#print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, disable=True)
for i, step in enumerate(iterator):
c, uc = reconstruct_conds(cond, unconditional_conditioning, step)
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
#if mask is not None:
# assert x0 is not None
# img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
# img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs, pred_x0 = self.p_sample_ddim(img, c, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
dynamic_threshold=dynamic_threshold)
img = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
return outs
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model(x, t, c)
else:
noise = self.model(x, t, c)
noise_uncond = self.model(x, t, unconditional_conditioning)
model_output = noise_uncond + unconditional_guidance_scale * (noise - noise_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None, *args, **kwargs):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc='Decoding image', total=total_steps, disable=True)
x_dec = x_latent
for i, step in enumerate(iterator):
c, uc = reconstruct_conds(cond, unconditional_conditioning, step)
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, c, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc)
if callback: callback(i)
return x_dec
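
A quick numerical check of the eps/x0 identity p_sample_ddim relies on (pure arithmetic, no model involved):

import torch

a_t = torch.tensor(0.7)  # stand-in for alphas[index]
x0 = torch.randn(4)
eps = torch.randn(4)
xt = a_t.sqrt() * x0 + (1 - a_t).sqrt() * eps
# pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t), as in p_sample_ddim
assert torch.allclose((xt - (1 - a_t).sqrt() * eps) / a_t.sqrt(), x0, atol=1e-6)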

View File

@@ -0,0 +1,207 @@
import torch
from samplers.ddim.sampler import DDIMSampler
from samplers.ddim.gaussian_sampler import GaussianDiffusion
from samplers.uni_pc.sampler import UniPCSampler
from tqdm import tqdm
from modules.shared import state
from modules.sd_samplers_common import InterruptedException
def get_height_width(h, w, divisor):
return h // divisor, w // divisor
def get_tensor_shape(batch_size, channels, frames, h, w, latents=None):
if latents is None:
return (batch_size, channels, frames, h, w)
return latents.shape
def inpaint_masking(xt, step, steps, mask, add_noise_cb, noise_cb_args):
if mask is not None and step < steps - 1:
# convert mask to a 0/1 mask based on the current step
v = (steps - step - 1) / steps
binary_mask = torch.where(mask <= v, torch.zeros_like(mask), torch.ones_like(mask))
noise_to_add = add_noise_cb(**noise_cb_args)
xt = noise_to_add * (1 - binary_mask) + xt * binary_mask
return xt
class SamplerStepCallback(object):
def __init__(self, sampler_name: str, total_steps: int):
self.sampler_name = sampler_name
self.total_steps = total_steps
self.current_step = 0
self.progress_bar = tqdm(desc=self.progress_msg(sampler_name, total_steps), total=total_steps)
def progress_msg(self, name, total_steps=None):
total_steps = total_steps if total_steps is not None else self.total_steps
state.sampling_steps = total_steps
return f"Sampling Using {name} for {total_steps} steps."
def set_webui_step(self, step):
state.sampling_step = step
def is_finished(self, step):
if step >= self.total_steps:
self.progress_bar.close()
self.current_step = 0
def interrupt(self):
return state.interrupted or state.skipped
def cancel(self):
raise InterruptedException
def update(self, step):
self.set_webui_step(step)
if self.interrupt():
self.cancel()
self.progress_bar.set_description(self.progress_msg(self.sampler_name))
self.progress_bar.update(1)
self.is_finished(step)
def __call__(self,*args, **kwargs):
self.current_step += 1
step = self.current_step
self.update(step)
class SamplerBase(object):
def __init__(self, name: str, Sampler, frame_inpaint_support=False):
self.name = name
self.Sampler = Sampler
self.frame_inpaint_support = frame_inpaint_support
def register_buffers_to_model(self, sd_model, betas, device):
self.alphas = 1. - betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
setattr(sd_model, 'device', device)
setattr(sd_model, 'betas', betas)
setattr(sd_model, 'alphas_cumprod', self.alphas_cumprod)
def init_sampler(self, sd_model, betas, device, **kwargs):
self.register_buffers_to_model(sd_model, betas, device)
return self.Sampler(sd_model, betas=betas, **kwargs)
available_samplers = [
SamplerBase("DDIM_Gaussian", GaussianDiffusion, True),
SamplerBase("DDIM", DDIMSampler),
SamplerBase("UniPC", UniPCSampler),
]
class Txt2VideoSampler(object):
def __init__(self, sd_model, device, betas=None, sampler_name="UniPC"):
self.sd_model = sd_model
self.device = device
self.noise_gen = torch.Generator(device='cpu')
self.sampler_name = sampler_name
self.betas = betas
self.sampler = self.get_sampler(sampler_name, betas=self.betas)
def get_noise(self, num_sample, channels, frames, height, width, latents=None, seed=1):
if latents is not None:
latents = latents.to(self.device)
print(f"Using input latents. Shape: {latents.shape}, Mean: {torch.mean(latents)}, Std: {torch.std(latents)}")
else:
print("Sampling random noise.")
num_sample = 1
max_frames = frames
latent_h, latent_w = get_height_width(height, width, 8)
shape = get_tensor_shape(num_sample, channels, max_frames, latent_h, latent_w, latents)
self.noise_gen.manual_seed(seed)
noise = torch.randn(shape, generator=self.noise_gen).to(self.device)
return latents, noise, shape
def encode_latent(self, latent, noise, strength, steps):
encoded_latent = None
denoise_steps = None
if hasattr(self.sampler, 'unipc_encode'):
encoded_latent = self.sampler.unipc_encode(latent, self.device, strength, steps, noise=noise)
if hasattr(self.sampler, 'stochastic_encode'):
denoise_steps = int(strength * steps)
timestep = torch.tensor([denoise_steps] * int(latent.shape[0])).to(self.device)
self.sampler.make_schedule(steps)
encoded_latent = self.sampler.stochastic_encode(latent, timestep, noise=noise).to(dtype=latent.dtype)
self.sampler.sample = self.sampler.decode
if hasattr(self.sampler, 'add_noise'):
denoise_steps = int(strength * steps)
timestep = self.sampler.get_time_steps(denoise_steps, latent.shape[0])
encoded_latent = self.sampler.add_noise(latent, noise, timestep[0].cpu())
if encoded_latent is None:
assert "Could not find the appropriate function to encode the input latents"
return encoded_latent, denoise_steps
def get_sampler(self, sampler_name: str, betas=None, return_sampler=True):
betas = betas if betas is not None else self.betas
for Sampler in available_samplers:
if sampler_name == Sampler.name:
sampler = Sampler.init_sampler(self.sd_model, betas=betas, device=self.device)
if Sampler.frame_inpaint_support:
setattr(sampler, 'inpaint_masking', inpaint_masking)
if return_sampler:
return sampler
else:
self.sampler = sampler
return
raise ValueError(f"Sample {sampler_name} does not exist.")
def sample_loop(
self,
steps,
strength,
conditioning,
unconditional_conditioning,
batch_size,
latents=None,
shape=None,
noise=None,
is_vid2vid=False,
guidance_scale=1,
eta=0,
mask=None,
sampler_name="DDIM"
):
denoise_steps = None
# Assume that we are adding noise to existing latents (Image, Video, etc.)
if latents is not None and is_vid2vid:
latents, denoise_steps = self.encode_latent(latents, noise, strength, steps)
# Create a callback that handles counting each step
sampler_callback = SamplerStepCallback(sampler_name, steps)
# Predict the noise sample
x0 = self.sampler.sample(
S=steps,
conditioning=conditioning,
strength=strength,
unconditional_conditioning=unconditional_conditioning,
batch_size=batch_size,
x_T=latents if latents is not None else noise,
x_latent=latents,
t_start=denoise_steps,
unconditional_guidance_scale=guidance_scale,
shape=shape,
callback=sampler_callback,
cond=conditioning,
eta=eta,
mask=mask
)
return x0
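
Putting the pieces together, a hedged usage sketch of the new plumbing (sd_model, betas, c and uc are assumed to come from TextToVideoSynthesis, as in the pipeline change above):

from samplers.samplers_common import Txt2VideoSampler

sampler = Txt2VideoSampler(sd_model, device, betas=betas, sampler_name="UniPC")
latents, noise, shape = sampler.get_noise(1, 4, frames=24, height=256, width=256, seed=42)
x0 = sampler.sample_loop(
    steps=30, strength=None,
    conditioning=c, unconditional_conditioning=uc,
    batch_size=1, latents=latents, shape=shape, noise=noise,
    guidance_scale=17, sampler_name="UniPC",
)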

View File

@@ -0,0 +1,90 @@
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
class UniPCSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def unipc_encode(self, latent, device, strength, steps, noise=None):
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
uni_pc = UniPC(None, ns, predict_x0=True, thresholding=False, variant='bh1')
t_0 = 1. / ns.total_N
timesteps = uni_pc.get_time_steps("time_uniform", strength, t_0, steps, device)
timesteps = timesteps[0].expand((latent.shape[0]))
noisy_latent = uni_pc.unipc_encode(latent, timesteps, noise=noise)
return noisy_latent
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
strength=None,
eta=0.,
mask=None,
x0=None,
temperature=1.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens
**kwargs
):
# sampling
B, C, F, H, W = shape
size = (B, C, F, H, W)
if x_T is None:
img = torch.randn(size, device=self.model.device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model(x, t, c),
ns,
model_type="noise",
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant='bh1')
x = uni_pc.sample(
img,
steps=S,
t_start=strength,
skip_type="time_uniform",
method="multistep",
order=3,
lower_order_final=True,
initial_corrector=True,
callback=callback
)
return x.to(self.model.device)
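
For vid2vid, encode_latent in samplers_common routes through unipc_encode above. A one-line sketch of that path (the constructed sampler and the b c f h w latent tensor are assumptions):

# Add noise to an existing latent at a given strength (hedged sketch).
noisy_latent = sampler.unipc_encode(latent, device, strength=0.6, steps=30)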

View File

@@ -0,0 +1,800 @@
import torch
import torch.nn.functional as F
import math
from einops import rearrange,repeat
from modules.shared import state
from t2v_helpers.general_utils import reconstruct_conds
class NoiseScheduleVP:
def __init__(
self,
schedule='discrete',
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention for the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
schedule are the default settings in DDPM and improved-DDPM:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
self.schedule = schedule
if schedule == 'discrete':
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
self.log_alpha_array = log_alphas.reshape((1, -1,))
else:
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
self.cosine_s = 0.008
self.cosine_beta_max = 999.
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
self.schedule = schedule
if schedule == 'cosine':
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
# Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
self.T = 0.9946
else:
self.T = 1.
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == 'discrete':
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
elif self.schedule == 'linear':
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == 'cosine':
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == 'linear':
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == 'discrete':
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
return t.reshape((-1,))
else:
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
t = t_fn(log_alpha)
return t
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
first wrap the model function into a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follow a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == 'discrete':
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, None, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return -expand_dims(sigma_t, dims) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise prediction model function that is used for DPM-Solver.
"""
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
elif guidance_type == "classifier-free":
c, uc = reconstruct_conds(condition, unconditional_condition, state.sampling_step)
if guidance_scale == 1. or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=c)
else:
noise = noise_pred_fn(x, t_continuous, cond=c)
noise_uncond = noise_pred_fn(x, t_continuous, cond=uc)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
class UniPC:
def __init__(
self,
model_fn,
noise_schedule,
predict_x0=True,
thresholding=False,
max_val=1.,
variant='bh1'
):
"""Construct a UniPC.
We support both data_prediction and noise_prediction.
"""
self.model = model_fn
self.noise_schedule = noise_schedule
self.variant = variant
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with thresholding).
"""
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def unipc_encode(self, x, t, noise=None):
"""
Encode a latent by mixing the input with noise at the given timestep.
"""
noise = torch.randn_like(x) if noise is None else noise
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (expand_dims(sigma_t, dims) * noise) + expand_dims(alpha_t, dims) * x
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
"""
if skip_type == 'logSNR':
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == 'time_uniform':
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == 'time_quadratic':
t_order = 2
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
return t
else:
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3,] * (K - 2) + [2, 1]
elif steps % 3 == 1:
orders = [3,] * (K - 1) + [1]
else:
orders = [3,] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2,] * K
else:
K = steps // 2 + 1
orders = [2,] * (K - 1) + [1]
elif order == 1:
K = steps
orders = [1,] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == 'logSNR':
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
return timesteps_outer, orders
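    # Editor's note: e.g. steps = 7, order = 3 gives K = 3 and orders = [3, 3, 1]
    # (the per-step orders always sum to `steps`), so the final step drops to first
    # order near t_0.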
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
        """
        return self.data_prediction_fn(x, s)
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
if len(t.shape) == 0:
t = t.view(-1)
if 'bh' in self.variant:
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
else:
assert self.variant == 'vary_coeff'
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_t = ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
rk = (lambda_prev_i - lambda_prev_0) / h
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
rks.append(1.)
rks = torch.tensor(rks, device=x.device)
K = len(rks)
# build C matrix
C = []
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K - 1, C, ...)
C_inv_p = torch.linalg.inv(C[:-1, :-1])
A_p = C_inv_p
if use_corrector:
#print('using corrector')
C_inv = torch.linalg.inv(C)
A_c = C_inv
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh)
h_phi_ks = []
factorial_k = 1
h_phi_k = h_phi_1
for k in range(1, K + 2):
h_phi_ks.append(h_phi_k)
h_phi_k = h_phi_k / hh - 1 / factorial_k
factorial_k *= (k + 1)
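        # Editor's note: after this loop, h_phi_ks[k] holds h * phi_{k+1}(hh), built via the
        # standard recurrence for exponential-integrator phi functions:
        #     phi_1(h) = (e^h - 1) / h,    phi_{k+1}(h) = (phi_k(h) - 1/k!) / h.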
model_t = None
if self.predict_x0:
x_t_ = (
sigma_t / sigma_prev_0 * x
- alpha_t * h_phi_1 * model_prev_0
)
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = (model_t - model_prev_0)
x_t = x_t_
k = 0
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
else:
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
x_t_ = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * h_phi_1) * model_prev_0
)
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = (model_t - model_prev_0)
x_t = x_t_
k = 0
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
return x_t, model_t
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
ns = self.noise_schedule
assert order <= len(model_prev_list)
dims = x.dim()
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
rks.append(1.)
rks = torch.tensor(rks, device=x.device)
R = []
b = []
hh = -h[0] if self.predict_x0 else h[0]
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.variant == 'bh1':
B_h = hh
elif self.variant == 'bh2':
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
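        # Editor's note: R is a Vandermonde-style matrix in the step ratios rks (row i is
        # rks ** (i - 1)); solving R @ rho = b below yields the predictor weights rhos_p
        # and the corrector weights rhos_c.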
# now predictor
use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K - 1, C, ...)
if len(D1s.shape) > 5:
D1s = rearrange(D1s, 'b k c f h w -> (b f) k c h w')
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_corrector:
#print('using corrector')
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], device=b.device)
else:
rhos_c = torch.linalg.solve(R, b)
model_t = None
if self.predict_x0:
x_t_ = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
)
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
pred_res = repeat(pred_res, 'f c h w -> b c f h w', b=x.shape[0])
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
corr_res = repeat(corr_res, 'f c h w -> b c f h w', b=x.shape[0])
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
pred_res = repeat(pred_res, 'f c h w -> b c f h w', b=x.shape[0])
else:
pred_res = 0
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
corr_res = repeat(corr_res, 'f c h w -> b c f h w', b=x.shape[0])
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
return x_t, model_t
def handle_callback(self, callback):
if callback is not None:
callback()
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, initial_corrector=True, callback=None
):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == 'multistep':
assert steps >= order
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=initial_corrector)
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
self.handle_callback(callback)
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
self.handle_callback(callback)
else:
raise NotImplementedError()
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
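# Editor's sketch of intended usage (hypothetical; assumes the NoiseScheduleVP and
# model_wrapper helpers defined earlier in this file follow the reference UniPC API,
# and that `unet`, `betas`, `cond`, `uncond`, `frames`, `h`, `w` and `device` come
# from the caller):
#   ns = NoiseScheduleVP('discrete', betas=betas)
#   fn = model_wrapper(unet, ns, model_type="noise", guidance_type="classifier-free",
#                      condition=cond, unconditional_condition=uncond, guidance_scale=7.5)
#   solver = UniPC(fn, ns, predict_x0=True, variant='bh1')
#   x = torch.randn(1, 4, frames, h // 8, w // 8, device=device)
#   sample = solver.sample(x, steps=20, method='multistep', skip_type='time_uniform', order=2)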
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we extrapolate using the outermost keypoints of xp.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
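# Editor's example: with keypoints xp = torch.tensor([[0., 1., 2.]]) and
# yp = torch.tensor([[0., 1., 4.]]), interpolate_fn(torch.tensor([[1.5]]), xp, yp)
# interpolates linearly between (1, 1) and (2, 4) and returns tensor([[2.5]]).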
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,)*(dims - 1)]
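# Editor's example: expand_dims(torch.tensor([1., 2.]), 4) has shape [2, 1, 1, 1],
# which lets a per-sample scalar broadcast against a [B, C, H, W] tensor.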

View File

@@ -4,6 +4,7 @@
import gradio as gr
from types import SimpleNamespace
from t2v_helpers.video_audio_utils import find_ffmpeg_binary
from samplers.samplers_common import available_samplers
import os
from modules.shared import opts
@@ -32,27 +33,34 @@ ModelScope:
i1_store_t2v = f"<p style=\"text-align:center;font-weight:bold;margin-bottom:0em\">text2video extension for auto1111 — version 1.2b. The video will be shown below this label when ready</p>"
def enable_sampler_dropdown(model_type):
is_visible = model_type == "ModelScope"
return gr.update(visible=is_visible)
def setup_common_values(mode, d):
with gr.Row(elem_id=f'{mode}_prompt_toprow'):
prompt = gr.Textbox(label='Prompt', lines=3, interactive=True, elem_id=f"{mode}_prompt", placeholder="Enter your prompt here...")
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
n_prompt = gr.Textbox(label='Negative prompt', lines=2, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
with gr.Row():
sampler = gr.Dropdown(label="Sampling method (ModelScope)", choices=[x.name for x in available_samplers], value=available_samplers[0].name, elem_id="model-sampler", visible=True)
steps = gr.Slider(label='Steps', minimum=1, maximum=100, step=1, value=d.steps)
with gr.Row():
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale)
with gr.Row():
width = gr.Slider(label='Width', minimum=64, maximum=1024, step=64, value=d.width)
height = gr.Slider(label='Height', minimum=64, maximum=1024, step=64, value=d.height)
with gr.Row():
            seed = gr.Number(label='Seed', value=d.seed, interactive=True, precision=0)
eta = gr.Number(label="ETA", value=d.eta, interactive=True)
eta = gr.Number(label="ETA (DDIM Only)", value=d.eta, interactive=True)
with gr.Row():
        gr.Markdown('256x256 Benchmarks: 24 frames peak at 5.7 GB of VRAM and 125 frames peak at 11.5 GB with Torch2 installed')
with gr.Row():
frames = gr.Slider(label="Frames", value=d.frames, minimum=2, maximum=250, step=1, interactive=True, precision=0)
batch_count = gr.Slider(label="Batch count", value=d.batch_count, minimum=1, maximum=100, step=1, interactive=True)
return prompt, n_prompt, steps, seed, cfg_scale, width, height, eta, frames, batch_count
return prompt, n_prompt, sampler, steps, seed, cfg_scale, width, height, eta, frames, batch_count
refresh_symbol = '\U0001f504' # 🔄
class ToolButton(gr.Button, gr.components.FormComponent):
@@ -82,11 +90,12 @@ def setup_text2video_settings_dictionary():
do_vid2vid = gr.State(value=0)
with gr.Tab('txt2vid') as tab_txt2vid:
# TODO: make it how it's done in Deforum/WebUI, so we won't have to track individual vars
prompt, n_prompt, steps, seed, cfg_scale, width, height, eta, frames, batch_count = setup_common_values('txt2vid', d)
prompt, n_prompt, sampler, steps, seed, cfg_scale, width, height, eta, frames, batch_count = setup_common_values('txt2vid', d)
model_type.change(fn=enable_sampler_dropdown, inputs=[model_type], outputs=[sampler])
with gr.Accordion('img2vid', open=False):
inpainting_image = gr.File(label="Inpainting image", interactive=True, file_count="single", file_types=["image"], elem_id="inpainting_chosen_file")
# TODO: should be tied to the total frame count dynamically
inpainting_frames=gr.Slider(label='inpainting frames',value=d.inpainting_frames,minimum=0, maximum=24, step=1)
inpainting_frames=gr.Slider(label='inpainting frames',value=d.inpainting_frames,minimum=0, maximum=250, step=1)
with gr.Row():
gr.Markdown('''`inpainting frames` is the number of frames inpainting is applied to (counting from the beginning)
@@ -99,11 +108,6 @@ To *loop it back*, set the weight to 0 for the first and for the last frame
Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
with gr.Row():
inpainting_weights = gr.Textbox(label="Inpainting weights", value=d.inpainting_weights, interactive=True)
def update_max_inp_frames(f, i_frames): # Show video
return gr.update(value=min(f, i_frames), maximum=f, visible=True)
frames.change(fn=update_max_inp_frames, inputs=[frames, inpainting_frames], outputs=[inpainting_frames])
with gr.Tab('vid2vid') as tab_vid2vid:
with gr.Row():
gr.HTML('Put your video here')
@@ -114,15 +118,11 @@ Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
with gr.Row():
vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
# TODO: here too
prompt_v, n_prompt_v, steps_v, seed_v, cfg_scale_v, width_v, height_v, eta_v, frames_v, batch_count_v = setup_common_values('vid2vid', d)
prompt_v, n_prompt_v, sampler_v, steps_v, seed_v, cfg_scale_v, width_v, height_v, eta_v, frames_v, batch_count_v = setup_common_values('vid2vid', d)
model_type.change(fn=enable_sampler_dropdown, inputs=[model_type], outputs=[sampler_v])
with gr.Row():
strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
vid2vid_startFrame=gr.Slider(label='vid2vid start frame',value=d.vid2vid_startFrame, minimum=0, maximum=23)
def update_max_vid_frames(v2v_frames, sFrame): # Show video
return gr.update(value=min(sFrame, v2v_frames-1), maximum=v2v_frames-1, visible=True)
frames_v.change(fn=update_max_vid_frames, inputs=[frames_v, vid2vid_startFrame], outputs=[vid2vid_startFrame])
vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
tab_txt2vid.select(fn=lambda: 0, inputs=[], outputs=[do_vid2vid])
tab_vid2vid.select(fn=lambda: 1, inputs=[], outputs=[do_vid2vid])
@@ -148,7 +148,7 @@ Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
t2v_video_args_names = str('skip_video_creation, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, fps, add_soundtrack, soundtrack_path').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
common_values_names = str('''prompt, n_prompt, steps, frames, seed, cfg_scale, width, height, eta, batch_count''').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
common_values_names = str('''prompt, n_prompt, sampler, steps, frames, seed, cfg_scale, width, height, eta, batch_count''').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
v2v_values_names = str('''
do_vid2vid, vid2vid_frames, vid2vid_frames_path, strength,vid2vid_startFrame,
@@ -199,6 +199,7 @@ def T2VArgs():
vid2vid_startFrame = 0
inpainting_weights = '0:(t/max_i_f), "max_i_f":(1)' # linear growth weights (as they used to be in the original variant)
inpainting_frames = 0
sampler = "DDIM"
return locals()
def T2VArgs_sanity_check(t2v_args):
@@ -219,6 +220,8 @@ def T2VArgs_sanity_check(t2v_args):
raise ValueError('vid2vid start frame cannot be greater than the number of frames!')
if t2v_args.inpainting_frames < 0 or t2v_args.inpainting_frames > t2v_args.frames:
raise ValueError('inpainting frames count should lie between 0 and the frames number!')
        if not any(x.name == t2v_args.sampler for x in available_samplers):
            raise ValueError(f"Sampler '{t2v_args.sampler}' does not exist.")
except Exception as e:
print(t2v_args)
raise e

View File

@@ -1,5 +1,6 @@
# Copyright (C) 2023 by Artem Khrapov (kabachuha)
# Read LICENSE for usage terms.
from modules.prompt_parser import reconstruct_cond_batch
def get_t2v_version():
from modules import extensions as mext
@@ -9,4 +10,9 @@ def get_t2v_version():
return ext.version
return "Unknown"
except:
return "Unknown"
return "Unknown"
def reconstruct_conds(cond, uncond, step):
c = reconstruct_cond_batch(cond, step)
uc = reconstruct_cond_batch(uncond, step)
return c, uc
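# Editor's sketch (hypothetical usage): a sampler loop would typically re-resolve the
# scheduled conds once per denoising step before applying classifier-free guidance:
#   c, uc = reconstruct_conds(cond, uncond, step)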

View File

@@ -8,44 +8,36 @@ from videocrafter.process_videocrafter import process_videocrafter
from modules.shared import opts
from .error_hardcode import get_error
from modules import lowvram, devices, sd_hijack
import logging
import logging
import gc
import t2v_helpers.args as t2v_helpers_args
def run(*args):
dataurl = get_error()
vids_pack = [dataurl]
component_names = t2v_helpers_args.get_component_names()
# api check
affected_args = args[2:] if len(args) > 36 else args
# TODO: change to i+2 when we will add the progress bar
args_dict = {
component_names[i]: affected_args[i] for i in range(0, len(component_names))
}
model_type = args_dict["model_type"]
t2v_helpers_args.i1_store_t2v = f'<p style="font-weight:bold;margin-bottom:0em">text2video extension for auto1111 — version 1.2b </p><video controls loop><source src="{dataurl}" type="video/mp4"></video>'
keep_pipe_in_vram = (
opts.data.get("modelscope_deforum_keep_model_in_vram")
if opts.data is not None
and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None
else "None"
)
args_dict = {component_names[i]: args[i+2] for i in range(0, len(component_names))}
model_type = args_dict['model_type']
t2v_helpers_args.i1_store_t2v = f'<p style=\"font-weight:bold;margin-bottom:0em\">text2video extension for auto1111 — version 1.2b </p><video controls loop><source src="{dataurl}" type="video/mp4"></video>'
keep_pipe_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None'
try:
print("text2video — The model selected is: ", args_dict["model_type"])
if model_type == "ModelScope":
print('text2video — The model selected is: ', args_dict['model_type'])
if model_type == 'ModelScope':
vids_pack = process_modelscope(args_dict)
elif model_type == "VideoCrafter (WIP)":
elif model_type == 'VideoCrafter (WIP)':
vids_pack = process_videocrafter(args_dict)
else:
raise NotImplementedError(f"Unknown model type: {model_type}")
except Exception as e:
traceback.print_exc()
print("Exception occurred:", e)
print('Exception occurred:', e)
finally:
# optionally store pipe in global between runs, if not, remove it
if keep_pipe_in_vram == "None":
#optionally store pipe in global between runs, if not, remove it
if keep_pipe_in_vram == 'None':
pm.pipe = None
devices.torch_gc()
gc.collect()
return vids_pack