Merge branch 'main' into model_storage
commit
412909b8eb
|
|
@ -67,7 +67,7 @@ def t2v_api(_, app: FastAPI):
|
|||
return JSONResponse(content={"version": get_t2v_version()})
|
||||
|
||||
@app.post("/t2v/run")
|
||||
async def t2v_run(prompt: str, n_prompt: Union[str, None] = None, steps: Union[int, None] = None, frames: Union[int, None] = None, seed: Union[int, None] = None, \
|
||||
async def t2v_run(prompt: str, n_prompt: Union[str, None] = None, sampler: Union[str, None] = None, steps: Union[int, None] = None, frames: Union[int, None] = None, seed: Union[int, None] = None, \
|
||||
cfg_scale: Union[int, None] = None, width: Union[int, None] = None, height: Union[int, None] = None, eta: Union[float, None] = None, batch_count: Union[int, None] = None, \
|
||||
do_vid2vid:bool = False, vid2vid_input: Union[UploadFile, None] = None,strength: Union[float, None] = None,vid2vid_startFrame: Union[int, None] = None, \
|
||||
inpainting_image: Union[UploadFile, None] = None, inpainting_frames: Union[int, None] = None, inpainting_weights: Union[str, None] = None, \
|
||||
|
|
@ -131,6 +131,7 @@ def t2v_api(_, app: FastAPI):
|
|||
|
||||
d.prompt,#prompt
|
||||
d.n_prompt,#n_prompt
|
||||
d.sampler,#sampler
|
||||
d.steps,#steps
|
||||
d.frames,#frames
|
||||
d.seed,#seed
|
||||
|
|
@ -143,6 +144,7 @@ def t2v_api(_, app: FastAPI):
|
|||
# The same, but for vid2vid. Will deduplicate later
|
||||
d.prompt,#prompt
|
||||
d.n_prompt,#n_prompt
|
||||
d.sampler,#sampler
|
||||
d.steps,#steps
|
||||
d.frames,#frames
|
||||
d.seed,#seed
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ def process_modelscope(args_dict):
|
|||
state.job_count = args.batch_count
|
||||
|
||||
for batch in pbar:
|
||||
state.job_no = batch + 1
|
||||
state.job_no = batch
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
|
|
@ -207,7 +207,7 @@ def process_modelscope(args_dict):
|
|||
args.strength = 1
|
||||
|
||||
samples, _ = pipe.infer(args.prompt, args.n_prompt, args.steps, args.frames, args.seed + batch if args.seed != -1 else -1, args.cfg_scale,
|
||||
args.width, args.height, args.eta, cpu_vae, device, latents, skip_steps=skip_steps, mask=mask)
|
||||
args.width, args.height, args.eta, cpu_vae, device, latents, strength=args.strength, skip_steps=skip_steps, mask=mask, is_vid2vid=args.do_vid2vid, sampler=args.sampler)
|
||||
|
||||
if batch > 0:
|
||||
outdir_current = os.path.join(get_outdir(), f"{init_timestring}_{batch}")
|
||||
|
|
|
|||
|
|
@ -23,16 +23,19 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from os import path as osp
|
||||
from modules.shared import opts
|
||||
|
||||
from functools import partial
|
||||
from tqdm import tqdm
|
||||
from modules.prompt_parser import reconstruct_cond_batch
|
||||
from modules.shared import state
|
||||
from modules.sd_samplers_common import InterruptedException
|
||||
|
||||
from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
|
||||
__all__ = ['UNetSD']
|
||||
|
||||
|
|
@ -112,7 +115,8 @@ class UNetSD(nn.Module):
|
|||
use_checkpoint=False,
|
||||
use_image_dataset=False,
|
||||
use_fps_condition=False,
|
||||
use_sim_mask=False):
|
||||
use_sim_mask=False,
|
||||
parameterization="eps"):
|
||||
embed_dim = dim * 4
|
||||
num_heads = num_heads if num_heads else dim // 32
|
||||
super(UNetSD, self).__init__()
|
||||
|
|
@ -135,6 +139,8 @@ class UNetSD(nn.Module):
|
|||
self.use_image_dataset = use_image_dataset
|
||||
self.use_fps_condition = use_fps_condition
|
||||
self.use_sim_mask = use_sim_mask
|
||||
self.parameterization = parameterization
|
||||
self.v_posterior = 0
|
||||
use_linear_in_temporal = False
|
||||
transformer_depth = 1
|
||||
disabled_sa = False
|
||||
|
|
@ -319,6 +325,64 @@ class UNetSD(nn.Module):
|
|||
# zero out the last layer params
|
||||
nn.init.zeros_(self.out[-1].weight)
|
||||
|
||||
# Taken from DDPM
|
||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
|
||||
if exists(given_betas):
|
||||
betas = given_betas
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
||||
1. - alphas_cumprod) + self.v_posterior * betas
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
||||
self.register_buffer('posterior_mean_coef1', to_torch(
|
||||
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
||||
self.register_buffer('posterior_mean_coef2', to_torch(
|
||||
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
||||
|
||||
if self.parameterization == "eps":
|
||||
lvlb_weights = self.betas ** 2 / (
|
||||
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
||||
elif self.parameterization == "x0":
|
||||
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
||||
elif self.parameterization == "v":
|
||||
lvlb_weights = torch.ones_like(self.betas ** 2 / (
|
||||
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
|
||||
else:
|
||||
raise NotImplementedError("mu not supported")
|
||||
lvlb_weights[0] = lvlb_weights[1]
|
||||
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
||||
assert not torch.isnan(self.lvlb_weights).all()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
|
|
@ -488,7 +552,7 @@ class CrossAttention(nn.Module):
|
|||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
|
||||
|
||||
if getattr(cmd_opts, "force_enable_xformers", False) or (getattr(cmd_opts, "xformers", False) and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||
import xformers
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
|
|
@ -1266,8 +1330,8 @@ class GaussianDiffusion(object):
|
|||
r"""Distribution of p(x_{t-1} | x_t).
|
||||
"""
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
if guide_scale is None or guide_scale == 1:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
|
|
|
|||
|
|
@ -19,8 +19,11 @@ import torch.cuda.amp as amp
|
|||
from einops import rearrange
|
||||
import cv2
|
||||
from modelscope.t2v_model import UNetSD, AutoencoderKL, GaussianDiffusion, beta_schedule
|
||||
from modules import devices
|
||||
from modules import devices, shared
|
||||
from modules import prompt_parser
|
||||
from samplers.uni_pc.sampler import UniPCSampler
|
||||
from samplers.samplers_common import Txt2VideoSampler
|
||||
from samplers.samplers_common import available_samplers
|
||||
|
||||
__all__ = ['TextToVideoSynthesis']
|
||||
|
||||
|
|
@ -86,6 +89,7 @@ class TextToVideoSynthesis():
|
|||
num_res_blocks=cfg['unet_res_blocks'],
|
||||
attn_scales=cfg['unet_attn_scales'],
|
||||
dropout=cfg['unet_dropout'],
|
||||
parameterization=cfg['mean_type'],
|
||||
temporal_attention=cfg['temporal_attention'])
|
||||
self.sd_model.load_state_dict(
|
||||
torch.load(
|
||||
|
|
@ -97,20 +101,17 @@ class TextToVideoSynthesis():
|
|||
self.sd_model.eval()
|
||||
if not devices.has_mps() or torch.cuda.is_available() == True:
|
||||
self.sd_model.half()
|
||||
|
||||
|
||||
# Initialize diffusion
|
||||
betas = beta_schedule(
|
||||
'linear_sd',
|
||||
cfg['num_timesteps'],
|
||||
init_beta=0.00085,
|
||||
last_beta=0.0120)
|
||||
self.diffusion = GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=cfg['mean_type'],
|
||||
var_type=cfg['var_type'],
|
||||
loss_type=cfg['loss_type'],
|
||||
rescale_timesteps=False)
|
||||
|
||||
|
||||
self.sd_model.register_schedule(given_betas=betas.numpy())
|
||||
self.diffusion = Txt2VideoSampler(self.sd_model, shared.device, betas=betas)
|
||||
|
||||
# Initialize autoencoder
|
||||
ddconfig = {
|
||||
'double_z': True,
|
||||
|
|
@ -192,7 +193,26 @@ class TextToVideoSynthesis():
|
|||
return out
|
||||
|
||||
# @torch.compile()
|
||||
def infer(self, prompt, n_prompt, steps, frames, seed, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device=torch.device('cpu'), latents=None, skip_steps=0,strength=0,mask=None):
|
||||
def infer(
|
||||
self,
|
||||
prompt,
|
||||
n_prompt,
|
||||
steps,
|
||||
frames,
|
||||
seed,
|
||||
scale,
|
||||
width=256,
|
||||
height=256,
|
||||
eta=0.0,
|
||||
cpu_vae='GPU (half precision)',
|
||||
device=torch.device('cpu'),
|
||||
latents=None,
|
||||
skip_steps=0,
|
||||
strength=0,
|
||||
mask=None,
|
||||
is_vid2vid=False,
|
||||
sampler=available_samplers[0].name
|
||||
):
|
||||
vars = locals()
|
||||
vars.pop('self')
|
||||
vars.pop('latents')
|
||||
|
|
@ -227,6 +247,7 @@ class TextToVideoSynthesis():
|
|||
self.device = device
|
||||
self.clip_encoder.to(self.device)
|
||||
self.clip_encoder.device = self.device
|
||||
steps = steps - skip_steps
|
||||
c, uc = self.preprocess(prompt, n_prompt, steps)
|
||||
if self.keep_in_vram != "All":
|
||||
self.clip_encoder.to("cpu")
|
||||
|
|
@ -236,36 +257,37 @@ class TextToVideoSynthesis():
|
|||
latents=latents.half() if 'half precision' in cpu_vae and latents is not None else latents
|
||||
|
||||
# synthesis
|
||||
strength = None if strength == 0.0 else strength
|
||||
strength = None if (strength == 0.0 and not is_vid2vid) else strength
|
||||
with torch.no_grad():
|
||||
num_sample = 1
|
||||
max_frames = frames
|
||||
latent_h, latent_w = height // 8, width // 8
|
||||
self.sd_model.to(self.device)
|
||||
if latents == None:
|
||||
self.noise_gen.manual_seed(seed)
|
||||
latents = torch.randn(num_sample, 4, max_frames, latent_h,
|
||||
latent_w, generator=self.noise_gen).to(
|
||||
self.device)
|
||||
else:
|
||||
latents.to(self.device)
|
||||
|
||||
print("latents", latents.shape, torch.mean(
|
||||
latents), torch.std(latents))
|
||||
|
||||
channels = 4
|
||||
max_frames= frames
|
||||
latents, noise, shape = self.diffusion.get_noise(
|
||||
num_sample,
|
||||
channels,
|
||||
max_frames,
|
||||
height,
|
||||
width,
|
||||
seed=seed,
|
||||
latents=latents
|
||||
)
|
||||
with amp.autocast(enabled=True):
|
||||
self.sd_model.to(self.device)
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=latents, # shape: b c f h w
|
||||
model=self.sd_model,
|
||||
c=c,
|
||||
uc=uc,
|
||||
num_sample=1,
|
||||
guide_scale=scale,
|
||||
ddim_timesteps=steps,
|
||||
self.diffusion.get_sampler(sampler, return_sampler=False)
|
||||
|
||||
x0 = self.diffusion.sample_loop(
|
||||
steps=steps,
|
||||
strength=strength,
|
||||
eta=eta,
|
||||
percentile=strength,
|
||||
skip_steps=skip_steps,
|
||||
conditioning=c,
|
||||
unconditional_conditioning=uc,
|
||||
batch_size=num_sample,
|
||||
guidance_scale=scale,
|
||||
latents=latents,
|
||||
shape=shape,
|
||||
noise=noise,
|
||||
is_vid2vid=is_vid2vid,
|
||||
sampler_name=sampler,
|
||||
mask=mask
|
||||
)
|
||||
|
||||
|
|
@ -435,3 +457,4 @@ def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
|||
images = [(image.numpy() * 255).astype('uint8')
|
||||
for image in images] # f h w c
|
||||
return images
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,300 @@
|
|||
import torch
|
||||
from modelscope.t2v_model import _i
|
||||
from t2v_helpers.general_utils import reconstruct_conds
|
||||
|
||||
class GaussianDiffusion(object):
|
||||
r""" Diffusion Model for DDIM.
|
||||
"Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
|
||||
See https://arxiv.org/abs/2010.02502
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
epsilon=1e-12,
|
||||
rescale_timesteps=False,
|
||||
**kwargs):
|
||||
|
||||
# check input
|
||||
self.check_input_vars(betas, mean_type, var_type, loss_type)
|
||||
|
||||
self.model = model
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.epsilon = epsilon
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:],alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
|
||||
def check_input_vars(self, betas, mean_type, var_type, loss_type):
|
||||
mean_types = ['x0', 'x_{t-1}', 'eps']
|
||||
var_types = ['learned', 'learned_range', 'fixed_large', 'fixed_small']
|
||||
loss_types = ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
|
||||
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in mean_types
|
||||
assert var_type in var_types
|
||||
assert loss_type in loss_types
|
||||
|
||||
def validate_model_kwargs(self, model_kwargs):
|
||||
"""
|
||||
Use the original implementation of passing model kwargs to the model.
|
||||
eg: model_kwargs=[{'y':c_i}, {'y':uc_i,}]
|
||||
"""
|
||||
if len(model_kwargs) > 0:
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
|
||||
def get_time_steps(self, ddim_timesteps, batch_size=1, step=None):
|
||||
b = batch_size
|
||||
|
||||
# Get thhe full timestep range
|
||||
arange_steps = (1 + torch.arange(0, self.num_timesteps, ddim_timesteps))
|
||||
steps = arange_steps.clamp(0, self.num_timesteps - 1)
|
||||
timesteps = steps.flip(0).to(self.model.device)
|
||||
|
||||
if step is not None:
|
||||
# Get the current timestep during a sample loop
|
||||
timesteps = torch.full((b, ), timesteps[step], dtype=torch.long, device=self.model.device)
|
||||
|
||||
return timesteps
|
||||
|
||||
def add_noise(self, xt, noise, t):
|
||||
noisy_sample = self.sqrt_alphas_cumprod[t.cpu()].to(self.model.device) * \
|
||||
xt + noise * self.sqrt_one_minus_alphas_cumprod[t.cpu()].to(self.model.device)
|
||||
|
||||
return noisy_sample
|
||||
|
||||
def get_dim(self, y_out):
|
||||
is_fixed = self.var_type.startswith('fixed')
|
||||
return y_out.size(1) if is_fixed else y_out.size(1) // 2
|
||||
|
||||
def fixed_small_variance(self, xt, t):
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
return var, log_var
|
||||
|
||||
def mean_x0(self, xt, t, x_out):
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * x_out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
return x0, mu
|
||||
|
||||
def restrict_range_x0(self, percentile, x0, clamp=False):
|
||||
if not clamp:
|
||||
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
||||
s = torch.quantile(x0.flatten(1).abs(), percentile,dim=1)
|
||||
s.clamp_(1.0).view(-1, 1, 1, 1)
|
||||
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
else:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
|
||||
return x0
|
||||
|
||||
def is_unconditional(self, guide_scale):
|
||||
return guide_scale is None or guide_scale == 1
|
||||
|
||||
def do_classifier_guidance(self, y_out, u_out, guidance_scale):
|
||||
"""
|
||||
y_out: Condition
|
||||
u_out: Unconditional
|
||||
"""
|
||||
dim = self.get_dim(y_out)
|
||||
a = u_out[:, :dim]
|
||||
b = guidance_scale * (y_out[:, :dim] - u_out[:, :dim])
|
||||
c = y_out[:, dim:]
|
||||
out = torch.cat([a + b, c], dim=1)
|
||||
|
||||
return out
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
conditioning=None,
|
||||
unconditional_conditioning=None,
|
||||
only_x0=True,
|
||||
**kwargs):
|
||||
r"""Distribution of p(x_{t-1} | x_t)."""
|
||||
|
||||
# predict distribution
|
||||
if self.is_unconditional(guide_scale):
|
||||
out = self.model(xt, self._scale_timesteps(t), conditioning)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
if model_kwargs != {}:
|
||||
self.validate_model_kwargs(model_kwargs)
|
||||
conditioning = model_kwargs[0]
|
||||
unconditional_conditioning = model_kwargs[1]
|
||||
|
||||
y_out = self.model(xt, self._scale_timesteps(t), conditioning)
|
||||
u_out = self.model(xt, self._scale_timesteps(t), unconditional_conditioning)
|
||||
|
||||
out = self.do_classifier_guidance(y_out, u_out, guide_scale)
|
||||
|
||||
# compute variance
|
||||
if self.var_type == 'fixed_small':
|
||||
var, log_var = self.fixed_small_variance(xt, t)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'eps':
|
||||
x0, mu = self.mean_x0(xt, t, out)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
x0 = self.restrict_range_x0(percentile, x0)
|
||||
elif clamp is not None:
|
||||
x0 = self.restrict_range_x0(percentile, x0, clamp=True)
|
||||
|
||||
if only_x0:
|
||||
return x0
|
||||
else:
|
||||
return mu, var, log_var, x0
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||
"""
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
|
||||
def get_eps(self, xt, x0, t, alpha, condition_fn, model_kwargs={}):
|
||||
# x0 -> eps
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
|
||||
if condition_fn is not None:
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
return eps, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
x_T=None,
|
||||
S=5,
|
||||
shape=None,
|
||||
conditioning=None,
|
||||
unconditional_conditioning=None,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
unconditional_guidance_scale=None,
|
||||
eta=0.0,
|
||||
callback=None,
|
||||
mask=None,
|
||||
**kwargs):
|
||||
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
|
||||
# Shape must exist to sample
|
||||
if shape is None and x_T is None:
|
||||
assert "Shape must exists to sample from noise"
|
||||
|
||||
# Assign variables for sampling
|
||||
steps = S
|
||||
stride = self.num_timesteps // steps
|
||||
guide_scale = unconditional_guidance_scale
|
||||
original_latents = None
|
||||
|
||||
if x_T is None:
|
||||
xt = torch.randn(shape, device=self.model.device)
|
||||
else:
|
||||
xt = x_T.clone()
|
||||
original_latents = xt
|
||||
|
||||
timesteps = self.get_time_steps(stride, xt.shape[0])
|
||||
|
||||
for step in range(0, steps):
|
||||
c, uc = reconstruct_conds(conditioning, unconditional_conditioning, step)
|
||||
t = self.get_time_steps(stride, xt.shape[0], step=step)
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
x0 = self.p_mean_variance(
|
||||
xt,
|
||||
t,
|
||||
model_kwargs,
|
||||
clamp,
|
||||
percentile,
|
||||
guide_scale,
|
||||
conditioning=c,
|
||||
unconditional_conditioning=uc,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
|
||||
eps, x0 = self.get_eps(xt, x0, t, alphas, condition_fn)
|
||||
|
||||
a = (1 - alphas_prev) / (1 - alphas)
|
||||
b = (1 - alphas / alphas_prev)
|
||||
sigmas = eta * torch.sqrt(a * b)
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
xt = xt_1
|
||||
|
||||
if hasattr(self, 'inpaint_masking') and mask is not None:
|
||||
add_noise_args = {
|
||||
"xt":xt,
|
||||
"noise": torch.randn_like(xt),
|
||||
"t": timesteps[(step - 1) + 1]
|
||||
}
|
||||
self.inpaint_masking(xt, step, steps, mask, self.add_noise, add_noise_args)
|
||||
|
||||
if callback is not None:
|
||||
callback(step)
|
||||
|
||||
return xt
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from modules import shared
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
from t2v_helpers.general_utils import reconstruct_conds
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
B, C, F, H, W = shape
|
||||
size = (B, C, F, H, W)
|
||||
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule
|
||||
)
|
||||
return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
|
||||
device = shared.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
|
||||
#print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, disable=True)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
c, uc = reconstruct_conds(cond, unconditional_conditioning, step)
|
||||
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
#if mask is not None:
|
||||
# assert x0 is not None
|
||||
# img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
# img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs, _ = self.p_sample_ddim(img, c, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
|
||||
img = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  dynamic_threshold=None):
    """Perform one reverse DDIM step at schedule position `index`.

    Returns (x_prev, pred_x0): the latent at the previous timestep and the
    current estimate of the fully denoised sample x_0.
    """
    b, *_, device = *x.shape, x.device

    # Classifier-free guidance: skip the extra unconditional pass when it
    # cannot change the result (no uncond conditioning, or scale == 1).
    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
        model_output = self.model(x, t, c)
    else:
        noise = self.model(x, t, c)
        noise_uncond = self.model(x, t, unconditional_conditioning)

        model_output = noise_uncond + unconditional_guidance_scale * (noise - noise_uncond)

    # Convert the model output to an epsilon (noise) prediction.
    if self.model.parameterization == "v":
        e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
    else:
        e_t = model_output

    if score_corrector is not None:
        assert self.model.parameterization == "eps", 'not implemented'
        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

    # Pick the full DDPM schedule or the thinned DDIM schedule.
    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
    # select parameters corresponding to the currently considered timestep
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

    # current prediction for x_0
    if self.model.parameterization != "v":
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
    else:
        pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

    if dynamic_threshold is not None:
        raise NotImplementedError()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
           unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
    """Deterministically diffuse x0 forward for t_enc steps (DDIM inversion).

    Returns (x_next, out) where out maps 'x_encoded' to the final latent,
    'intermediate_steps' to the recorded step indices, and (when
    `return_intermediates` is set) 'intermediates' to latent snapshots.
    """
    num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

    assert t_enc <= num_reference_steps
    num_steps = t_enc

    if use_original_steps:
        # NOTE(review): reads self.alphas_cumprod / self.alphas_cumprod_prev —
        # confirm these buffers are registered on the sampler itself
        # (elsewhere in this project they are attached to the model).
        alphas_next = self.alphas_cumprod[:num_steps]
        alphas = self.alphas_cumprod_prev[:num_steps]
    else:
        alphas_next = self.ddim_alphas[:num_steps]
        alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

    x_next = x0
    intermediates = []
    inter_steps = []
    for i in tqdm(range(num_steps), desc='Encoding Image'):
        t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
        if unconditional_guidance_scale == 1.:
            noise_pred = self.model(x_next, t, c)
        else:
            assert unconditional_conditioning is not None
            # Batch the unconditional and conditional passes into one call.
            e_t_uncond, noise_pred = torch.chunk(
                self.model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                           torch.cat((unconditional_conditioning, c))), 2)
            noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

        # One deterministic DDIM step in the forward (noising) direction.
        xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
        weighted_noise_pred = alphas_next[i].sqrt() * (
            (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
        x_next = xt_weighted + weighted_noise_pred
        # Record evenly spaced snapshots, plus the last couple of steps.
        if return_intermediates and i % (
                num_steps // return_intermediates) == 0 and i < num_steps - 1:
            intermediates.append(x_next)
            inter_steps.append(i)
        elif return_intermediates and i >= num_steps - 2:
            intermediates.append(x_next)
            inter_steps.append(i)
        if callback: callback(i)

    out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
    if return_intermediates:
        out.update({'intermediates': intermediates})
    return x_next, out
|
||||
|
||||
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Diffuse x0 to timestep(s) t in a single jump by sampling q(x_t | x_0).

    Fast, but does not allow for exact reconstruction (the noise is random
    unless supplied). `t` serves as an index to gather the correct alphas.
    """
    if use_original_steps:
        sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
        sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
    else:
        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

    if noise is None:
        noise = torch.randn_like(x0)
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
            extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False, callback=None, *args, **kwargs):
    """Denoise a partially-noised latent from schedule step `t_start` back to x_0.

    Used for the vid2vid path after stochastic_encode(); iterates the DDIM
    reverse steps for the truncated schedule only.
    """
    timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
    timesteps = timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]

    iterator = tqdm(time_range, desc='Decoding image', total=total_steps, disable=True)
    x_dec = x_latent
    for i, step in enumerate(iterator):
        # Rebuild the per-step conditioning (helper defined elsewhere in the project).
        c, uc = reconstruct_conds(cond, unconditional_conditioning, step)

        index = total_steps - i - 1
        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        x_dec, _ = self.p_sample_ddim(x_dec, c, ts, index=index, use_original_steps=use_original_steps,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=uc)
        if callback: callback(i)
    return x_dec
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
import torch
|
||||
from samplers.ddim.sampler import DDIMSampler
|
||||
from samplers.ddim.gaussian_sampler import GaussianDiffusion
|
||||
from samplers.uni_pc.sampler import UniPCSampler
|
||||
from tqdm import tqdm
|
||||
from modules.shared import state
|
||||
from modules.sd_samplers_common import InterruptedException
|
||||
|
||||
def get_height_width(h, w, divisor):
    """Scale a (height, width) pair down by `divisor` using floor division."""
    scaled = (dim // divisor for dim in (h, w))
    return tuple(scaled)
|
||||
|
||||
def get_tensor_shape(batch_size, channels, frames, h, w, latents=None):
    """Return `latents.shape` when latents are given, else a (B, C, F, H, W) tuple."""
    if latents is not None:
        return latents.shape
    return (batch_size, channels, frames, h, w)
|
||||
|
||||
def inpaint_masking(xt, step, steps, mask, add_noise_cb, noise_cb_args):
    """Blend fresh noise into the unmasked region of `xt` on all but the final step.

    Args:
        xt: current latent.
        step / steps: current step index and total step count.
        mask: soft mask in [0, 1], or None to disable inpaint masking.
        add_noise_cb: callback producing the noised content to blend in.
        noise_cb_args: kwargs dict forwarded to `add_noise_cb`.

    Returns:
        The blended latent (or `xt` unchanged when masking does not apply).
        Bug fix: the original only reassigned the local `xt` and implicitly
        returned None, so the blend never reached the caller.
    """
    if mask is None or step >= steps - 1:
        return xt

    # Threshold the soft mask into {0, 1}: values at or below the schedule
    # value v take the freshly noised content, the rest keep xt.
    v = (steps - step - 1) / steps
    binary_mask = torch.where(mask <= v, torch.zeros_like(mask), torch.ones_like(mask))

    noise_to_add = add_noise_cb(**noise_cb_args)
    return noise_to_add * (1 - binary_mask) + xt * binary_mask
|
||||
|
||||
class SamplerStepCallback(object):
    """Per-step callback handed to sampler.sample().

    Mirrors progress into the webui global `state` and a tqdm bar, and aborts
    sampling (via InterruptedException) when the user interrupts or skips.
    """

    def __init__(self, sampler_name: str, total_steps: int):
        self.sampler_name = sampler_name
        self.total_steps = total_steps
        self.current_step = 0
        self.progress_bar = tqdm(desc=self.progress_msg(sampler_name, total_steps), total=total_steps)

    def progress_msg(self, name, total_steps=None):
        """Build the progress-bar description.

        NOTE(review): also writes state.sampling_steps as a side effect, so it
        is called on every update, not just at construction.
        """
        total_steps = total_steps if total_steps is not None else self.total_steps
        state.sampling_steps = total_steps
        return f"Sampling Using {name} for {total_steps} steps."

    def set_webui_step(self, step):
        # Publish the current step to the webui progress state.
        state.sampling_step = step

    def is_finished(self, step):
        # On the final step, close the bar and reset for possible reuse.
        if step >= self.total_steps:
            self.progress_bar.close()
            self.current_step = 0

    def interrupt(self):
        # True when the user pressed Interrupt or Skip in the webui.
        return state.interrupted or state.skipped

    def cancel(self):
        raise InterruptedException

    def update(self, step):
        self.set_webui_step(step)

        if self.interrupt():
            self.cancel()

        self.progress_bar.set_description(self.progress_msg(self.sampler_name))
        self.progress_bar.update(1)

        self.is_finished(step)

    def __call__(self, *args, **kwargs):
        # Samplers invoke this with varying signatures; only the call count matters.
        self.current_step += 1
        step = self.current_step

        self.update(step)
|
||||
|
||||
class SamplerBase(object):
    """Registry entry describing one selectable sampler implementation."""

    def __init__(self, name: str, Sampler, frame_inpaint_support=False):
        self.name = name
        self.Sampler = Sampler
        self.frame_inpaint_support = frame_inpaint_support

    def register_buffers_to_model(self, sd_model, betas, device):
        """Derive the alpha schedules from `betas` and attach them to the model."""
        self.alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        sd_model.device = device
        sd_model.betas = betas
        sd_model.alphas_cumprod = self.alphas_cumprod

    def init_sampler(self, sd_model, betas, device, **kwargs):
        """Attach schedule buffers to the model, then construct the sampler."""
        self.register_buffers_to_model(sd_model, betas, device)
        return self.Sampler(sd_model, betas=betas, **kwargs)
|
||||
|
||||
# Registry of samplers selectable by name through Txt2VideoSampler.get_sampler().
# Third argument marks frame-inpaint support (only DDIM_Gaussian enables it).
available_samplers = [
    SamplerBase("DDIM_Gaussian", GaussianDiffusion, True),
    SamplerBase("DDIM", DDIMSampler),
    SamplerBase("UniPC", UniPCSampler),
]
|
||||
|
||||
class Txt2VideoSampler(object):
    """Selects a sampler backend by name and drives the txt2vid / vid2vid loop.

    Seeded noise is always drawn from a CPU torch.Generator so that a given
    seed reproduces the same latents regardless of the compute device.
    """

    def __init__(self, sd_model, device, betas=None, sampler_name="UniPC"):
        self.sd_model = sd_model
        self.device = device
        self.noise_gen = torch.Generator(device='cpu')
        self.sampler_name = sampler_name
        self.betas = betas
        self.sampler = self.get_sampler(sampler_name, betas=self.betas)

    def get_noise(self, num_sample, channels, frames, height, width, latents=None, seed=1):
        """Build the latent shape and seeded noise; move input latents to the device.

        Returns (latents, noise, shape); `latents` stays None for pure txt2vid.
        """
        if latents is not None:
            # Bug fix: Tensor.to() returns a new tensor; the original discarded
            # the result, leaving the latents on their source device.
            latents = latents.to(self.device)

            print(f"Using input latents. Shape: {latents.shape}, Mean: {torch.mean(latents)}, Std: {torch.std(latents)}")
        else:
            print("Sampling random noise.")
            num_sample = 1

        # Bug fix: max_frames was only assigned inside the `else` branch above,
        # which raised NameError whenever input latents were supplied.
        max_frames = frames

        latent_h, latent_w = get_height_width(height, width, 8)
        shape = get_tensor_shape(num_sample, channels, max_frames, latent_h, latent_w, latents)

        self.noise_gen.manual_seed(seed)
        noise = torch.randn(shape, generator=self.noise_gen).to(self.device)

        return latents, noise, shape

    def encode_latent(self, latent, noise, strength, steps):
        """Partially noise `latent` (by `strength`) via the active sampler's encode API.

        Returns (encoded_latent, denoise_steps); denoise_steps stays None for
        samplers (UniPC) that handle the start step internally.
        """
        encoded_latent = None
        denoise_steps = None

        if hasattr(self.sampler, 'unipc_encode'):
            encoded_latent = self.sampler.unipc_encode(latent, self.device, strength, steps, noise=noise)

        if hasattr(self.sampler, 'stochastic_encode'):
            denoise_steps = int(strength * steps)
            timestep = torch.tensor([denoise_steps] * int(latent.shape[0])).to(self.device)
            self.sampler.make_schedule(steps)
            encoded_latent = self.sampler.stochastic_encode(latent, timestep, noise=noise).to(dtype=latent.dtype)
            # DDIM resumes mid-schedule through decode(); swap it in as the
            # sample entry point so sample_loop() can stay sampler-agnostic.
            self.sampler.sample = self.sampler.decode

        if hasattr(self.sampler, 'add_noise'):
            denoise_steps = int(strength * steps)
            timestep = self.sampler.get_time_steps(denoise_steps, latent.shape[0])
            encoded_latent = self.sampler.add_noise(latent, noise, timestep[0].cpu())

        if encoded_latent is None:
            # Bug fix: the original used `assert "<message>"`, which asserts a
            # non-empty (always truthy) string and therefore never fires.
            raise RuntimeError("Could not find the appropriate function to encode the input latents")

        return encoded_latent, denoise_steps

    def get_sampler(self, sampler_name: str, betas=None, return_sampler=True):
        """Instantiate the sampler registered under `sampler_name`.

        Returns the sampler when `return_sampler` is True, otherwise stores it
        on self.sampler. Raises ValueError for unknown names.
        """
        betas = betas if betas is not None else self.betas

        for Sampler in available_samplers:
            if sampler_name == Sampler.name:
                sampler = Sampler.init_sampler(self.sd_model, betas=betas, device=self.device)

                if Sampler.frame_inpaint_support:
                    setattr(sampler, 'inpaint_masking', inpaint_masking)

                if return_sampler:
                    return sampler
                else:
                    self.sampler = sampler
                    return

        # Typo fix: "Sample" -> "Sampler".
        raise ValueError(f"Sampler {sampler_name} does not exist.")

    def sample_loop(
        self,
        steps,
        strength,
        conditioning,
        unconditional_conditioning,
        batch_size,
        latents=None,
        shape=None,
        noise=None,
        is_vid2vid=False,
        guidance_scale=1,
        eta=0,
        mask=None,
        sampler_name="DDIM"
    ):
        """Run one full sampling pass and return the denoised latent x0.

        For vid2vid, input latents are first partially noised via
        encode_latent(). Arguments are passed uniformly to every backend;
        each sampler's sample() accepts **kwargs and ignores the rest.
        """
        denoise_steps = None
        # Assume that we are adding noise to existing latents (Image, Video, etc.)
        if latents is not None and is_vid2vid:
            latents, denoise_steps = self.encode_latent(latents, noise, strength, steps)

        # Create a callback that handles counting each step
        sampler_callback = SamplerStepCallback(sampler_name, steps)

        # Predict the noise sample
        x0 = self.sampler.sample(
            S=steps,
            conditioning=conditioning,
            strength=strength,
            unconditional_conditioning=unconditional_conditioning,
            batch_size=batch_size,
            x_T=latents if latents is not None else noise,
            x_latent=latents,
            t_start=denoise_steps,
            unconditional_guidance_scale=guidance_scale,
            shape=shape,
            callback=sampler_callback,
            cond=conditioning,
            eta=eta,
            mask=mask
        )

        return x0
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
|
||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||
|
||||
class UniPCSampler(object):
    """Adapter exposing the UniPC solver through the shared sampler `sample()` interface."""

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Keep a float32 copy of the model's cumulative-alpha schedule.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # NOTE(review): hard-codes CUDA as the target device; this will fail on
        # CPU-only / MPS setups — confirm whether model.device should be used.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def unipc_encode(self, latent, device, strength, steps, noise=None):
        """Forward-noise `latent` to the schedule position implied by `strength` (vid2vid)."""
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
        # No model_fn is needed for encoding, so pass None.
        uni_pc = UniPC(None, ns, predict_x0=True, thresholding=False, variant='bh1')
        t_0 = 1. / ns.total_N

        timesteps = uni_pc.get_time_steps("time_uniform", strength, t_0, steps, device)
        timesteps = timesteps[0].expand((latent.shape[0]))

        noisy_latent = uni_pc.unipc_encode(latent, timesteps, noise=noise)
        return noisy_latent

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               strength=None,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Run UniPC multistep sampling for S steps and return the denoised latent.

        Most keyword arguments exist only for signature compatibility with the
        other samplers and are unused here.
        """

        # sampling
        B, C, F, H, W = shape
        size = (B, C, F, H, W)

        # Start from supplied latents (vid2vid) or fresh Gaussian noise.
        if x_T is None:
            img = torch.randn(size, device=self.model.device)
        else:
            img = x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
        model_fn = model_wrapper(
            lambda x, t, c: self.model(x, t, c),
            ns,
            model_type="noise",
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant='bh1')
        x = uni_pc.sample(
            img,
            steps=S,
            t_start=strength,
            skip_type="time_uniform",
            method="multistep",
            order=3,
            lower_order_final=True,
            initial_corrector=True,
            callback=callback
        )

        return x.to(self.model.device)
|
||||
|
|
@ -0,0 +1,800 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from einops import rearrange,repeat
|
||||
from modules.shared import state
|
||||
from t2v_helpers.general_utils import reconstruct_conds
|
||||
|
||||
class NoiseScheduleVP:
    """Wrapper around the forward VP-SDE noise schedule used by UniPC / DPM-Solver."""

    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        The forward SDE ensures q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I).
        With lambda_t = log(alpha_t) - log(sigma_t) (the half-logSNR), this class
        provides:
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t     = self.marginal_std(t)
            lambda_t    = self.marginal_lambda(t)
            t           = self.inverse_lambda(lambda_t)

        Supports both discrete-time DPMs (schedule='discrete', recommended —
        implemented via piecewise-linear interpolation of log_alpha_t) and
        continuous-time DPMs (schedule='linear' or 'cosine').

        For discrete-time DPMs, pass exactly one of `betas` or `alphas_cumprod`
        (`alphas_cumprod` is the \hat{alpha_n} array of DDPM; note that
        alpha_{t_n} = sqrt(\hat{alpha_n}), hence the 0.5 factors below).
        Discrete steps n = 0..N-1 map to continuous times t_i = (i + 1) / N, so
        for N = 1000 the diffusion ODE is solved from T = 1 down to t_0 = 1e-3.

        For continuous-time DPMs, `continuous_beta_0` / `continuous_beta_1` are
        the linear-schedule endpoints; the cosine hyperparameters follow
        improved-DDPM.

        Example:
            >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # Interpolation table: t in (0, 1] -> log(alpha_t).
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
|
||||
|
||||
|
||||
def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs={},
|
||||
guidance_type="uncond",
|
||||
condition=None,
|
||||
unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs={},
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
||||
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
||||
We support four types of the diffusion model by setting `model_type`:
|
||||
1. "noise": noise prediction model. (Trained by predicting noise).
|
||||
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
||||
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
||||
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
||||
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
||||
arXiv preprint arXiv:2202.00512 (2022).
|
||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||
arXiv preprint arXiv:2210.02303 (2022).
|
||||
|
||||
4. "score": marginal score function. (Trained by denoising score matching).
|
||||
Note that the score function and the noise prediction model follows a simple relationship:
|
||||
```
|
||||
noise(x_t, t) = -sigma_t * score(x_t, t)
|
||||
```
|
||||
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
||||
1. "uncond": unconditional sampling by DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
The input `classifier_fn` has the following format:
|
||||
``
|
||||
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
||||
``
|
||||
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
||||
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
||||
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||
arXiv preprint arXiv:2207.12598 (2022).
|
||||
|
||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||
or continuous-time labels (i.e. epsilon to T).
|
||||
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
||||
``
|
||||
def model_fn(x, t_continuous) -> noise:
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
``
|
||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||
===============================================================
|
||||
Args:
|
||||
model: A diffusion model with the corresponding format described above.
|
||||
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
||||
model_type: A `str`. The parameterization type of the diffusion model.
|
||||
"noise" or "x_start" or "v" or "score".
|
||||
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
||||
guidance_type: A `str`. The type of the guidance for sampling.
|
||||
"uncond" or "classifier" or "classifier-free".
|
||||
condition: A pytorch tensor. The condition for the guided sampling.
|
||||
Only used for "classifier" or "classifier-free" guidance type.
|
||||
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
||||
Only used for "classifier-free" guidance type.
|
||||
guidance_scale: A `float`. The scale for the guided sampling.
|
||||
classifier_fn: A classifier function. Only used for the classifier guidance.
|
||||
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
||||
Returns:
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
||||
For continuous-time DPMs, we just use `t_continuous`.
|
||||
"""
|
||||
if noise_schedule.schedule == 'discrete':
|
||||
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
||||
else:
|
||||
return t_continuous
|
||||
|
||||
def noise_pred_fn(x, t_continuous, cond=None):
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
if cond is None:
|
||||
output = model(x, t_input, None, **model_kwargs)
|
||||
else:
|
||||
output = model(x, t_input, cond, **model_kwargs)
|
||||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
||||
elif model_type == "v":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
||||
elif model_type == "score":
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return -expand_dims(sigma_t, dims) * output
|
||||
|
||||
def cond_grad_fn(x, t_input):
|
||||
"""
|
||||
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
||||
"""
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
||||
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
||||
|
||||
def model_fn(x, t_continuous):
    """
    The noise predicition model function that is used for DPM-Solver.

    Applies the configured guidance ('uncond', 'classifier', or
    'classifier-free') on top of `noise_pred_fn`.
    """
    if t_continuous.reshape((-1,)).shape[0] == 1:
        # Broadcast a scalar time to one entry per batch element.
        t_continuous = t_continuous.expand((x.shape[0]))
    if guidance_type == "uncond":
        return noise_pred_fn(x, t_continuous)
    elif guidance_type == "classifier":
        # Classifier guidance: shift the predicted noise along the
        # classifier gradient, scaled by sigma_t and guidance_scale.
        assert classifier_fn is not None
        t_input = get_model_input_time(t_continuous)
        cond_grad = cond_grad_fn(x, t_input)
        sigma_t = noise_schedule.marginal_std(t_continuous)
        noise = noise_pred_fn(x, t_continuous)
        return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
    elif guidance_type == "classifier-free":
        # Rebuild (un)conditional embeddings for the current sampling
        # step, since prompt schedules may vary per step.
        c, uc = reconstruct_conds(condition, unconditional_condition, state.sampling_step)
        if guidance_scale == 1. or unconditional_condition is None:
            # No CFG needed; a single conditional pass suffices.
            return noise_pred_fn(x, t_continuous, cond=c)
        else:
            # Classifier-free guidance:
            # eps = eps_uncond + scale * (eps_cond - eps_uncond).
            noise = noise_pred_fn(x, t_continuous, cond=c)
            noise_uncond = noise_pred_fn(x, t_continuous, cond=uc)

            return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
assert model_type in ["noise", "x_start", "v"]
|
||||
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
||||
return model_fn
|
||||
|
||||
|
||||
class UniPC:
    """UniPC (unified predictor-corrector) solver for diffusion sampling.

    Wraps a noise-prediction model function and a noise schedule and
    exposes `sample`, which integrates the diffusion ODE with a
    multistep predictor-corrector scheme ('bh1'/'bh2' or 'vary_coeff').
    """

    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1'
    ):
        """Construct a UniPC.

        We support both data_prediction and noise_prediction.

        Args:
            model_fn: noise-prediction model function `model_fn(x, t)`.
            noise_schedule: object providing marginal alpha/std/lambda
                and their inverses for the diffusion process.
            predict_x0: if True, solve in data-prediction (x0) mode.
            thresholding: if True, apply Imagen-style static thresholding
                inside `data_prediction_fn`.
            max_val: lower bound for the thresholding scale `s`.
            variant: solver flavor, one of 'bh1', 'bh2', 'vary_coeff'.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.

        NOTE(review): reads `self.dynamic_thresholding_ratio` and
        `self.thresholding_max_val`, neither of which is assigned in
        `__init__`; calling this method as-is would raise
        AttributeError. The class uses the inline thresholding in
        `data_prediction_fn` instead -- confirm before using.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        # Recover x0 from predicted noise: x0 = (x - sigma_t * eps) / alpha_t.
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            # Clamp to [-s, s] and rescale; s is floored at self.max_val.
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def unipc_encode(self, x, t, noise=None):
        """
        Encodes a latent determined by noise input and a given timestep.

        Forward-diffuses `x` to time `t`: x_t = alpha_t * x + sigma_t * noise.
        A fresh Gaussian sample is drawn when `noise` is None.
        """
        noise = torch.randn_like(x) if noise is None else noise
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (expand_dims(sigma_t, dims) * noise) + expand_dims(alpha_t, dims) * x
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Returns a tensor of N + 1 times from t_T down to t_0, spaced
        uniformly in log-SNR, time, or quadratic time.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.

        Splits `steps` into K outer steps whose orders sum to `steps`,
        using the highest available order as often as possible.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        # Dispatch to the configured solver variant: 'bh1'/'bh2' use the
        # B(h) update, anything else must be 'vary_coeff'.
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        """One multistep UniPC update (vary_coeff variant).

        Advances `x` from t_prev_list[-1] to `t` using the `order` most
        recent model evaluations; optionally applies the corrector.
        Returns (x_t, model_t) where model_t is None when no corrector
        evaluation was made.
        """
        #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step size in log-SNR (lambda) space.
        h = lambda_t - lambda_prev_0

        # rks: relative lambda offsets of previous points; D1s: scaled
        # first differences of previous model outputs.
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            # Predictor coefficients from the leading (K-1)x(K-1) system.
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            #print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        # hh flips sign in x0-prediction mode; h_phi_ks[k] = h^k * phi_k(h).
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            # First-order base step in data-prediction mode.
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            # First-order base step in noise-prediction mode.
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        """One multistep UniPC update (B(h) variant, 'bh1' or 'bh2').

        When `x_t` is given, only the corrector is applied to it;
        otherwise the predictor produces `x_t` first. Returns
        (x_t, model_t) with model_t None when no corrector evaluation
        was made.
        """
        #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step size in log-SNR (lambda) space.
        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        # Vandermonde system R @ rho = b for the update coefficients.
        R = []
        b = []

        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            # Fold a frame axis into the batch for video-shaped tensors
            # (b k c f h w -> (b f) k c h w) so the einsum below applies.
            if len(D1s.shape) > 5:
                D1s = rearrange(D1s, 'b k c f h w -> (b f) k c h w')
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            #print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            # First-order base step in data-prediction mode.
            x_t_ = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                    # Restore the video layout folded above.
                    pred_res = repeat(pred_res, 'f c h w -> b c f h w', b=x.shape[0])

                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                    corr_res = repeat(corr_res, 'f c h w -> b c f h w', b=x.shape[0])
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)

                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            # First-order base step in noise-prediction mode.
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                    pred_res = repeat(pred_res, 'f c h w -> b c f h w', b=x.shape[0])
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                    corr_res = repeat(corr_res, 'f c h w -> b c f h w', b=x.shape[0])
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t

    def handle_callback(self, callback):
        # Invoke the per-step callback when one was supplied.
        if callback is not None:
            callback()

    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
        atol=0.0078, rtol=0.05, corrector=False, initial_corrector=True, callback=None
    ):
        """Sample from the diffusion model starting at latent `x`.

        Only method='multistep' is implemented; any other value raises
        NotImplementedError. `atol`/`rtol`/`solver_type`/`corrector` are
        accepted for interface compatibility but unused here.

        NOTE(review): indentation reconstructed from an unindented
        source; `handle_callback` is assumed to run once per step,
        inside the loops -- confirm against upstream.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device

        if method == 'multistep':
            assert steps >= order
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]

                # Init the first `order` values by lower order multistep DPM-Solver.
                for init_order in range(1, order):
                    vec_t = timesteps[init_order].expand(x.shape[0])

                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=initial_corrector)
                    if model_x is None:
                        model_x = self.model_fn(x, vec_t)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)

                    self.handle_callback(callback)

                for step in range(order, steps + 1):
                    vec_t = timesteps[step].expand(x.shape[0])

                    if lower_order_final:
                        # Drop to lower orders near the end for stability.
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    #print('this step order:', step_order)
                    if step == steps:
                        #print('do not run corrector at the last step')
                        use_corrector = False
                    else:
                        use_corrector = True
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)

                    # Slide the history window of times and model outputs.
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = vec_t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        if model_x is None:
                            model_x = self.model_fn(x, vec_t)
                        model_prev_list[-1] = model_x

                    self.handle_callback(callback)
        else:
            raise NotImplementedError()
        if denoise_to_zero:
            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)

        return x
|
||||
|
||||
|
||||
#############################################################
|
||||
# other utility functions
|
||||
#############################################################
|
||||
|
||||
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    # Insert x among the keypoints and sort, so its rank tells us which
    # segment it falls into.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_vals, sort_order = torch.sort(combined, dim=2)
    pos = torch.argmin(sort_order, dim=2)
    left_candidate = pos - 1
    # Clamp out-of-range queries to the outermost segment so the linear
    # piece extrapolates beyond the keypoint range.
    seg_start = torch.where(
        torch.eq(pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(pos, num_keys), torch.tensor(num_keys - 2, device=x.device), left_candidate,
        ),
    )
    seg_end = torch.where(torch.eq(seg_start, left_candidate), seg_start + 2, seg_start + 1)
    x0 = torch.gather(sorted_vals, dim=2, index=seg_start.unsqueeze(2)).squeeze(2)
    x1 = torch.gather(sorted_vals, dim=2, index=seg_end.unsqueeze(2)).squeeze(2)
    # Index of the segment's left keypoint within the ORIGINAL xp order.
    y_start_idx = torch.where(
        torch.eq(pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(pos, num_keys), torch.tensor(num_keys - 2, device=x.device), left_candidate,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    y0 = torch.gather(yp_expanded, dim=2, index=y_start_idx.unsqueeze(2)).squeeze(2)
    y1 = torch.gather(yp_expanded, dim=2, index=(y_start_idx + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    return y0 + (x - x0) * (y1 - y0) / (x1 - x0)
|
||||
|
||||
|
||||
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims` by appending singleton axes.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the target total number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Ellipsis keeps the existing axes; each None adds a trailing axis.
    index = (...,) + (None,) * (dims - 1)
    return v[index]
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
import gradio as gr
|
||||
from types import SimpleNamespace
|
||||
from t2v_helpers.video_audio_utils import find_ffmpeg_binary
|
||||
from samplers.samplers_common import available_samplers
|
||||
import os
|
||||
from modules.shared import opts
|
||||
|
||||
|
|
@ -32,27 +33,34 @@ ModelScope:
|
|||
|
||||
i1_store_t2v = f"<p style=\"text-align:center;font-weight:bold;margin-bottom:0em\">text2video extension for auto1111 — version 1.2b. The video will be shown below this label when ready</p>"
|
||||
|
||||
def enable_sampler_dropdown(model_type):
|
||||
is_visible = model_type == "ModelScope"
|
||||
return gr.update(visible=is_visible)
|
||||
|
||||
def setup_common_values(mode, d):
|
||||
with gr.Row(elem_id=f'{mode}_prompt_toprow'):
|
||||
prompt = gr.Textbox(label='Prompt', lines=3, interactive=True, elem_id=f"{mode}_prompt", placeholder="Enter your prompt here...")
|
||||
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
|
||||
n_prompt = gr.Textbox(label='Negative prompt', lines=2, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
|
||||
with gr.Row():
|
||||
sampler = gr.Dropdown(label="Sampling method (ModelScope)", choices=[x.name for x in available_samplers], value=available_samplers[0].name, elem_id="model-sampler", visible=True)
|
||||
steps = gr.Slider(label='Steps', minimum=1, maximum=100, step=1, value=d.steps)
|
||||
with gr.Row():
|
||||
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale)
|
||||
with gr.Row():
|
||||
width = gr.Slider(label='Width', minimum=64, maximum=1024, step=64, value=d.width)
|
||||
height = gr.Slider(label='Height', minimum=64, maximum=1024, step=64, value=d.height)
|
||||
with gr.Row():
|
||||
seed = gr.Number(label='Seed', value = d.seed, Interactive = True, precision=0)
|
||||
eta = gr.Number(label="ETA", value=d.eta, interactive=True)
|
||||
eta = gr.Number(label="ETA (DDIM Only)", value=d.eta, interactive=True)
|
||||
with gr.Row():
|
||||
gr.Markdown('256x256 Benchmarks: 24 frames peak at 5.7 GBs of VRAM and 125 frames peak at 11.5 GBs with Torch2 installed')
|
||||
with gr.Row():
|
||||
frames = gr.Slider(label="Frames", value=d.frames, minimum=2, maximum=250, step=1, interactive=True, precision=0)
|
||||
batch_count = gr.Slider(label="Batch count", value=d.batch_count, minimum=1, maximum=100, step=1, interactive=True)
|
||||
|
||||
return prompt, n_prompt, steps, seed, cfg_scale, width, height, eta, frames, batch_count
|
||||
return prompt, n_prompt, sampler, steps, seed, cfg_scale, width, height, eta, frames, batch_count
|
||||
|
||||
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
|
|
@ -82,11 +90,12 @@ def setup_text2video_settings_dictionary():
|
|||
do_vid2vid = gr.State(value=0)
|
||||
with gr.Tab('txt2vid') as tab_txt2vid:
|
||||
# TODO: make it how it's done in Deforum/WebUI, so we won't have to track individual vars
|
||||
prompt, n_prompt, steps, seed, cfg_scale, width, height, eta, frames, batch_count = setup_common_values('txt2vid', d)
|
||||
prompt, n_prompt, sampler, steps, seed, cfg_scale, width, height, eta, frames, batch_count = setup_common_values('txt2vid', d)
|
||||
model_type.change(fn=enable_sampler_dropdown, inputs=[model_type], outputs=[sampler])
|
||||
with gr.Accordion('img2vid', open=False):
|
||||
inpainting_image = gr.File(label="Inpainting image", interactive=True, file_count="single", file_types=["image"], elem_id="inpainting_chosen_file")
|
||||
# TODO: should be tied to the total frame count dynamically
|
||||
inpainting_frames=gr.Slider(label='inpainting frames',value=d.inpainting_frames,minimum=0, maximum=24, step=1)
|
||||
inpainting_frames=gr.Slider(label='inpainting frames',value=d.inpainting_frames,minimum=0, maximum=250, step=1)
|
||||
with gr.Row():
|
||||
gr.Markdown('''`inpainting frames` is the number of frames inpainting is applied to (counting from the beginning)
|
||||
|
||||
|
|
@ -99,11 +108,6 @@ To *loop it back*, set the weight to 0 for the first and for the last frame
|
|||
Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
|
||||
with gr.Row():
|
||||
inpainting_weights = gr.Textbox(label="Inpainting weights", value=d.inpainting_weights, interactive=True)
|
||||
|
||||
def update_max_inp_frames(f, i_frames): # Show video
|
||||
return gr.update(value=min(f, i_frames), maximum=f, visible=True)
|
||||
|
||||
frames.change(fn=update_max_inp_frames, inputs=[frames, inpainting_frames], outputs=[inpainting_frames])
|
||||
with gr.Tab('vid2vid') as tab_vid2vid:
|
||||
with gr.Row():
|
||||
gr.HTML('Put your video here')
|
||||
|
|
@ -114,15 +118,11 @@ Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
|
|||
with gr.Row():
|
||||
vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
|
||||
# TODO: here too
|
||||
prompt_v, n_prompt_v, steps_v, seed_v, cfg_scale_v, width_v, height_v, eta_v, frames_v, batch_count_v = setup_common_values('vid2vid', d)
|
||||
prompt_v, n_prompt_v, sampler_v, steps_v, seed_v, cfg_scale_v, width_v, height_v, eta_v, frames_v, batch_count_v = setup_common_values('vid2vid', d)
|
||||
model_type.change(fn=enable_sampler_dropdown, inputs=[model_type], outputs=[sampler_v])
|
||||
with gr.Row():
|
||||
strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
||||
vid2vid_startFrame=gr.Slider(label='vid2vid start frame',value=d.vid2vid_startFrame, minimum=0, maximum=23)
|
||||
|
||||
def update_max_vid_frames(v2v_frames, sFrame): # Show video
|
||||
return gr.update(value=min(sFrame, v2v_frames-1), maximum=v2v_frames-1, visible=True)
|
||||
|
||||
frames_v.change(fn=update_max_vid_frames, inputs=[frames_v, vid2vid_startFrame], outputs=[vid2vid_startFrame])
|
||||
vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
|
||||
|
||||
tab_txt2vid.select(fn=lambda: 0, inputs=[], outputs=[do_vid2vid])
|
||||
tab_vid2vid.select(fn=lambda: 1, inputs=[], outputs=[do_vid2vid])
|
||||
|
|
@ -148,7 +148,7 @@ Example: `0:(0), "max_i_f/4":(1), "3*max_i_f/4":(1), "max_i_f-1":(0)` ''')
|
|||
|
||||
t2v_video_args_names = str('skip_video_creation, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, fps, add_soundtrack, soundtrack_path').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
|
||||
|
||||
common_values_names = str('''prompt, n_prompt, steps, frames, seed, cfg_scale, width, height, eta, batch_count''').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
|
||||
common_values_names = str('''prompt, n_prompt, sampler, steps, frames, seed, cfg_scale, width, height, eta, batch_count''').replace("\n", "").replace("\r", "").replace(" ", "").split(',')
|
||||
|
||||
v2v_values_names = str('''
|
||||
do_vid2vid, vid2vid_frames, vid2vid_frames_path, strength,vid2vid_startFrame,
|
||||
|
|
@ -199,6 +199,7 @@ def T2VArgs():
|
|||
vid2vid_startFrame = 0
|
||||
inpainting_weights = '0:(t/max_i_f), "max_i_f":(1)' # linear growth weights (as they used to be in the original variant)
|
||||
inpainting_frames = 0
|
||||
sampler = "DDIM"
|
||||
return locals()
|
||||
|
||||
def T2VArgs_sanity_check(t2v_args):
|
||||
|
|
@ -219,6 +220,8 @@ def T2VArgs_sanity_check(t2v_args):
|
|||
raise ValueError('vid2vid start frame cannot be greater than the number of frames!')
|
||||
if t2v_args.inpainting_frames < 0 or t2v_args.inpainting_frames > t2v_args.frames:
|
||||
raise ValueError('inpainting frames count should lie between 0 and the frames number!')
|
||||
if not any([x.name == t2v_args.sampler for x in available_samplers]):
|
||||
raise ValueError("Sampler does not exist.")
|
||||
except Exception as e:
|
||||
print(t2v_args)
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (C) 2023 by Artem Khrapov (kabachuha)
|
||||
# Read LICENSE for usage terms.
|
||||
from modules.prompt_parser import reconstruct_cond_batch
|
||||
|
||||
def get_t2v_version():
|
||||
from modules import extensions as mext
|
||||
|
|
@ -9,4 +10,9 @@ def get_t2v_version():
|
|||
return ext.version
|
||||
return "Unknown"
|
||||
except:
|
||||
return "Unknown"
|
||||
return "Unknown"
|
||||
|
||||
def reconstruct_conds(cond, uncond, step):
    """Rebuild the conditional and unconditional batches for sampling `step`."""
    return (
        reconstruct_cond_batch(cond, step),
        reconstruct_cond_batch(uncond, step),
    )
|
||||
|
|
|
|||
|
|
@ -8,44 +8,36 @@ from videocrafter.process_videocrafter import process_videocrafter
|
|||
from modules.shared import opts
|
||||
from .error_hardcode import get_error
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
import logging
|
||||
import logging
|
||||
import gc
|
||||
import t2v_helpers.args as t2v_helpers_args
|
||||
|
||||
|
||||
def run(*args):
|
||||
dataurl = get_error()
|
||||
vids_pack = [dataurl]
|
||||
component_names = t2v_helpers_args.get_component_names()
|
||||
# api check
|
||||
affected_args = args[2:] if len(args) > 36 else args
|
||||
# TODO: change to i+2 when we will add the progress bar
|
||||
args_dict = {
|
||||
component_names[i]: affected_args[i] for i in range(0, len(component_names))
|
||||
}
|
||||
model_type = args_dict["model_type"]
|
||||
t2v_helpers_args.i1_store_t2v = f'<p style="font-weight:bold;margin-bottom:0em">text2video extension for auto1111 — version 1.2b </p><video controls loop><source src="{dataurl}" type="video/mp4"></video>'
|
||||
keep_pipe_in_vram = (
|
||||
opts.data.get("modelscope_deforum_keep_model_in_vram")
|
||||
if opts.data is not None
|
||||
and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None
|
||||
else "None"
|
||||
)
|
||||
args_dict = {component_names[i]: args[i+2] for i in range(0, len(component_names))}
|
||||
model_type = args_dict['model_type']
|
||||
t2v_helpers_args.i1_store_t2v = f'<p style=\"font-weight:bold;margin-bottom:0em\">text2video extension for auto1111 — version 1.2b </p><video controls loop><source src="{dataurl}" type="video/mp4"></video>'
|
||||
keep_pipe_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None'
|
||||
try:
|
||||
print("text2video — The model selected is: ", args_dict["model_type"])
|
||||
if model_type == "ModelScope":
|
||||
print('text2video — The model selected is: ', args_dict['model_type'])
|
||||
if model_type == 'ModelScope':
|
||||
vids_pack = process_modelscope(args_dict)
|
||||
elif model_type == "VideoCrafter (WIP)":
|
||||
elif model_type == 'VideoCrafter (WIP)':
|
||||
vids_pack = process_videocrafter(args_dict)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown model type: {model_type}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print("Exception occurred:", e)
|
||||
print('Exception occurred:', e)
|
||||
finally:
|
||||
# optionally store pipe in global between runs, if not, remove it
|
||||
if keep_pipe_in_vram == "None":
|
||||
#optionally store pipe in global between runs, if not, remove it
|
||||
if keep_pipe_in_vram == 'None':
|
||||
pm.pipe = None
|
||||
devices.torch_gc()
|
||||
gc.collect()
|
||||
return vids_pack
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue