1745 lines
62 KiB
Python
1745 lines
62 KiB
Python
# Part of the implementation is borrowed and modified from stable-diffusion,
|
|
# publicly available at https://github.com/Stability-AI/stablediffusion.
|
|
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
|
|
|
# https://github.com/modelscope/modelscope/tree/master/modelscope/pipelines/multi_modal
|
|
|
|
# Alibaba's code used under Apache 2.0 license
|
|
# StabilityAI's Stable Diffusion code used under MIT license
|
|
# Automatic1111's WebUI's code used under AGPL v3.0
|
|
|
|
# All the licenses of the code and its modifications are incorporated into the compatible AGPL v3.0 license
|
|
|
|
# SD-webui text2video:
|
|
|
|
# Copyright (C) 2023 by Artem Khrapov (kabachuha)
|
|
# See LICENSE for usage terms.
|
|
|
|
from ldm.util import instantiate_from_config
|
|
import importlib
|
|
import math
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from einops import rearrange, repeat
|
|
from os import path as osp
|
|
from modules.shared import opts
|
|
|
|
from functools import partial
|
|
from tqdm import tqdm
|
|
from modules.prompt_parser import reconstruct_cond_batch
|
|
from modules.shared import state
|
|
from modules.sd_samplers_common import InterruptedException
|
|
|
|
from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
|
|
from ldm.modules.diffusionmodules.util import make_beta_schedule
|
|
|
|
__all__ = ['UNetSD']
|
|
|
|
import gc

try:
    import torch
    import torch.cuda

    def torch_gc():
        """Run Python garbage collection and release cached CUDA memory.

        Always calls ``gc.collect()``. When a CUDA device is available it also
        calls ``torch.cuda.empty_cache()`` (return cached allocator blocks) and
        ``torch.cuda.ipc_collect()`` (reclaim memory released through CUDA IPC).
        """
        gc.collect()  # Collect Python garbage
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear PyTorch CUDA cache
            torch.cuda.ipc_collect()  # Clear PyTorch CUDA IPC resources

except ImportError:
    # Only a failed torch import should select the fallback; any other error
    # raised while importing is a real problem and must not be swallowed.

    def torch_gc():
        """Fallback used when torch cannot be imported: only run Python's GC."""
        gc.collect()
|
|
|
|
import modules.shared as shared
|
|
from modules.shared import cmd_opts
|
|
# True when this torch build exposes F.scaled_dot_product_attention (torch 2.x);
# CrossAttention.forward only takes the SDP paths when this is set.
can_use_sdp = callable(getattr(torch.nn.functional, "scaled_dot_product_attention", None))
|
|
|
|
from ldm.modules.diffusionmodules.model import Decoder, Encoder
|
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
|
|
|
# No pinned model revision by default.
DEFAULT_MODEL_REVISION = None


class Invoke:
    """Namespace of string constants describing how something was invoked.

    ``KEY`` holds the string ``'invoked_by'``; presumably it names the field
    that carries one of the other constants (usage is not visible in this
    file -- confirm against callers).
    """

    KEY = 'invoked_by'

    # recognised invoker values
    PRETRAINED = 'from_pretrained'
    PIPELINE = 'pipeline'
    TRAINER = 'trainer'
    LOCAL_TRAINER = 'local_trainer'
    PREPROCESSOR = 'preprocessor'
|
|
|
|
|
|
def exists(x):
    """Report whether *x* holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0`` or ``[]`` still count as existing.
    """
    if x is None:
        return False
    return True
|
|
|
|
|
|
def default(val, d):
    """Return *val*, or the fallback *d* when *val* is ``None``.

    When *d* is callable it is only called if the fallback is actually
    needed, and its return value is used; otherwise *d* itself is returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
|
|
|
|
|
|
class UNetSD(nn.Module):
    """Video diffusion UNet with spatial and temporal attention layers.

    ``forward`` takes ``x`` laid out as ``b c f h w`` and returns the same
    layout. Internally frames are folded into the batch axis
    (``(b f) c h w``) for the 2D layers and unfolded back to ``b c f h w``
    around every ``TemporalTransformer`` (see ``_forward_single``).
    """

    def __init__(self,
                 in_dim=7,
                 dim=512,
                 y_dim=512,
                 context_dim=512,
                 out_dim=6,
                 dim_mult=[1, 2, 3, 4],
                 num_heads=None,
                 head_dim=64,
                 num_res_blocks=3,
                 attn_scales=[1 / 2, 1 / 4, 1 / 8],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=2,
                 temporal_attention=True,
                 use_checkpoint=False,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 parameterization="eps"):
        # NOTE: the list defaults (dim_mult, attn_scales) are only read here,
        # never mutated.
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(UNetSD, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        self.num_heads = num_heads
        # parameters for spatial/temporal attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        self.parameterization = parameterization
        self.v_posterior = 0
        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False
        # params
        # channel width of each resolution level; the decoder mirrors the encoder
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # channel count of every encoder output, popped by the decoder skips
        shortcut_dims = []
        # current resolution relative to the input; spatial attention is
        # inserted at the levels listed in attn_scales
        scale = 1.0

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))

        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            # start with a zero contribution from the fps condition
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])

        if temporal_attention:
            init_block.append(
                TemporalTransformer(
                    dim,
                    num_heads,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset))

        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)
        # NOTE: the loop variables shadow the in_dim/out_dim arguments (already
        # stored on self); the middle block and the head read the last values
        # these loops leave behind.
        for i, (in_dim,
                out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                # residual (+attention) blocks
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                if self.temporal_attention:
                    block.append(
                        TemporalTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            disable_self_attn=disabled_sa,
                            use_linear=use_linear_in_temporal,
                            multiply_zero=use_image_dataset))

                in_dim = out_dim
                self.input_blocks.append(block)
                shortcut_dims.append(out_dim)

                # downsample after the last res block of every level but the deepest
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        # middle
        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True)
        ])

        if self.temporal_attention:
            self.middle_block.append(
                TemporalTransformer(
                    out_dim,
                    out_dim // head_dim,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset,
                ))

        self.middle_block.append(
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (in_dim,
                out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                # residual (+attention) blocks; input is concatenated with the
                # matching encoder skip connection
                block = nn.ModuleList([
                    ResBlock(
                        in_dim + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            # Was hard-coded to 1024. It must match the
                            # encoder's cross-attention width, because the
                            # same context tensor is fed to every block.
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))

                if self.temporal_attention:
                    block.append(
                        TemporalTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            disable_self_attn=disabled_sa,
                            use_linear=use_linear_in_temporal,
                            multiply_zero=use_image_dataset))
                in_dim = out_dim

                # upsample after the last res block of every level but the shallowest
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    upsample = Upsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

        # zero out the last layer params (weight only; the bias keeps its init)
        nn.init.zeros_(self.out[-1].weight)

    # Taken from DDPM
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the diffusion noise schedule and register its derived buffers.

        :param given_betas: explicit beta array; when ``None`` one is built
            with ``make_beta_schedule(beta_schedule, timesteps, ...)``.
        :raises NotImplementedError: for an unknown ``self.parameterization``.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # the t=0 weight is degenerate (posterior variance 0); reuse t=1
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    def forward(
            self,
            x,
            t,
            y,
            fps=None,
            video_mask=None,
            focus_present_mask=None,
            prob_focus_present=0.0,
            mask_last_frame_num=0  # mask last frame num
    ):
        """Run the UNet on a batch of videos.

        :param x: input of shape ``b c f h w``.
        :param t: per-sample timesteps, embedded with ``sinusoidal_embedding``.
        :param y: cross-attention context; it is repeated once per frame.
        :param fps: optional fps values, only used when ``use_fps_condition``.
        :param video_mask: must be a tensor when ``mask_last_frame_num > 0``
            (its last entries are set to False in place).
        :param prob_focus_present: probability at which a given batch sample
            will focus on the present (0. is all off, 1. is completely
            arrested attention across time).
        :return: output of shape ``b c f h w``.
        """
        batch, device = x.shape[0], x.device
        # needed by _forward_single to unfold (b f) back into b, f
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        time_rel_pos_bias = None
        # embeddings
        if self.use_fps_condition and fps is not None:
            e = self.time_embed(sinusoidal_embedding(
                t, self.dim)) + self.fps_embedding(
                    sinusoidal_embedding(fps, self.dim))
        else:
            e = self.time_embed(sinusoidal_embedding(t, self.dim))
        context = y

        # repeat f times for spatial e and context
        f = x.shape[2]
        e = e.repeat_interleave(repeats=f, dim=0)
        context = context.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder: concatenate the matching encoder output on the channel axis
        for block in self.output_blocks:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)
        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None):
        """Apply one (possibly nested) module with the arguments it expects.

        ``x`` is ``(b f) c h w``. Only ``TemporalTransformer`` sees the
        unfolded ``b c f h w`` layout, using ``self.batch`` set in ``forward``.
        ``nn.ModuleList`` containers are applied element by element.
        """
        if isinstance(module, ResidualBlock):
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            x = x.contiguous()
            x = module(x, e, self.batch)
        elif isinstance(module, SpatialTransformer):
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference)
        else:
            x = module(x)
        return x
|
|
|
|
|
|
def sinusoidal_embedding(timesteps, dim):
    """Embed a 1-D tensor of timesteps into ``dim`` sinusoidal features.

    Row ``k`` is ``[cos(t_k * w_0..w_{half-1}), sin(t_k * w_0..w_{half-1})]``
    with ``w_i = 10000 ** (-i / half)``. When ``dim`` is odd, one zero column
    is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    steps = timesteps.float()
    exponents = torch.arange(half).to(steps).div(half).neg()
    angles = torch.outer(steps, torch.pow(10000, exponents))
    embedding = torch.cat([angles.cos(), angles.sin()], dim=1)
    if dim % 2:
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=1)
    return embedding
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head attention whose keys/values come from ``context``.

    Queries are projected from ``x``; keys and values from ``context``, or
    from ``x`` itself when no context is passed (self-attention). The kernel
    is chosen per call from the WebUI command-line options: xformers, PyTorch
    SDP (optionally with the memory-efficient kernel disabled), or a plain
    einsum + softmax fallback.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        # 1/sqrt(d); SDP applies the same default scale internally
        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (``x`` itself when None).

        Assumes ``x``/``context`` are ``(batch, tokens, features)``, matching
        the ``'b n (h d)'`` rearrange below. Returns a tensor shaped like
        ``x`` with ``query_dim`` features.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # split heads into the batch axis
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(x.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            # NOTE(review): the xformers path receives this boolean mask as
            # attn_bias unchanged; confirm xformers accepts it before relying
            # on masks. No call site visible in this file passes a mask.

        use_xformers = getattr(cmd_opts, "force_enable_xformers", False) or (
            getattr(cmd_opts, "xformers", False)
            and shared.xformers_available
            and torch.version.cuda
            and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

        if use_xformers:
            # Import the submodule explicitly: importing the ``xformers``
            # package alone does not guarantee ``xformers.ops`` is loaded.
            import xformers.ops
            out = xformers.ops.memory_efficient_attention(
                q, k, v, op=get_xformers_flash_attention_op(q, k, v), attn_bias=mask,
            )
        elif getattr(cmd_opts, "opt_sdp_no_mem_attention", False) and can_use_sdp:
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0.0, attn_mask=mask
                )
        elif getattr(cmd_opts, "opt_sdp_attention", True) and can_use_sdp:
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, attn_mask=mask
            )
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
            del q, k

            if exists(mask):
                # True in mask means "may attend"; blank out everything else
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            sim = sim.softmax(dim=-1)

            out = torch.einsum('b i j, b j d -> b i d', sim, v)

        # merge heads back into the feature axis
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
|
|
|
|
|
|
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs

    Input and output are ``b c h w``; the result is added to the input
    (residual connection).
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.0,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # One context width per transformer block. A scalar (or None, meaning
        # self-attention only) is shared by all `depth` blocks. Previously
        # None and depth > 1 both failed on `context_dim[d]`.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Maps inner_dim back to in_channels. The arguments used to be
            # swapped, which only worked while inner_dim == in_channels (same
            # weight shape then, so existing checkpoints still load).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply the transformer to ``x`` (``b c h w``) and add the input back."""
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            # the same context is used by every block
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # flatten the spatial grid into a token sequence
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
|
|
|
|
|
|
class TemporalTransformer(nn.Module):
    """
    Transformer block for video features along the frame (temporal) axis.

    Input and output are ``b c f h w``. Every spatial position becomes an
    independent sequence of ``f`` frame tokens. The transformer blocks run
    over those sequences, the result is reshaped back, and the input is
    added (residual). With ``multiply_zero`` the block output is multiplied
    by 0, so the layer returns its input.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.0,
                 context_dim=None,
                 disable_self_attn=False,  # accepted for API parity; not forwarded
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        if self.only_self_att:
            context_dim = None
        # one context width per transformer block; a scalar/None is shared by all
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # maps inner_dim back to in_channels (arguments were swapped before)
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Attend across frames of ``x`` (``b c f h w``); returns the same shape."""
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        # Project channels and lay features out as (b h w) f c: one sequence
        # of f frame tokens per spatial position.
        if not self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
        else:
            # the input is 5-D here; the old '(b f) c h w' pattern could not match it
            x = rearrange(x, 'b c f h w -> (b h w) f c').contiguous()
            x = self.proj_in(x)

        if self.only_self_att:
            for block in self.transformer_blocks:
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # Use the frame count from the input (self.frames was never
                # set). Build a new tensor instead of writing into the
                # caller's context list.
                # NOTE(review): assumes context[i] is ((b f), l, con) and that
                # h*w is divisible by f -- confirm with callers.
                frame_context = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=f).contiguous()
                # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_i_j = repeat(
                        frame_context[j],
                        'f l con -> (f r) l con',
                        r=(h * w) // f,
                        f=f).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        # x is now b (h w) f c; project back and restore b c f h w
        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b c f h w', h=h, w=w).contiguous()
        else:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention, cross-attention, feed-forward.

    Each sub-layer reads a LayerNorm-ed copy of the running features and its
    output is added back residually. ``attn1`` is self-attention unless
    ``disable_self_attn`` is set, in which case it also attends to
    ``context``. ``attn2`` attends to ``context`` (self-attention when the
    context is None).
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.0,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        first_context_dim = context_dim if disable_self_attn else None
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=first_context_dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(dim) for _ in range(3))
        # stored only; no checkpointing is performed in forward
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        first_context = context if self.disable_self_attn else None
        x = x + self.attn1(self.norm1(x), context=first_context)
        x = x + self.attn2(self.norm2(x), context=context)
        return x + self.ff(self.norm3(x))
|
|
|
|
|
|
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single linear layer produces ``2 * dim_out`` features. The first half
    is the value, and the second half, after GELU, is the gate multiplied
    into it.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
|
|
def zero_module(module):
    """Set every parameter of *module* to zero in place and return *module*.

    The write is done without autograd tracking, so ``requires_grad`` on the
    parameters is left untouched.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``dim * mult``, dropout, project to ``dim_out``.

    With ``glu`` the expansion is a ``GEGLU``; otherwise it is ``Linear`` +
    ``GELU``. ``dim_out`` defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        if glu:
            expand = GEGLU(dim, hidden)
        else:
            expand = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            expand,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param channels: channel count expected on the input.
    :param use_conv: when True, apply a ``Conv2d`` after interpolation.
    :param dims: 3 means a 5-D input where only the last two axes are doubled;
        any other value doubles every spatial axis via ``scale_factor=2``.
    :param out_channels: output channels of the convolution (default: ``channels``).
    :param padding: padding of the convolution.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode='nearest')
        return self.conv(x) if self.use_conv else x
|
|
|
|
|
|
class ResBlock(nn.Module):
    """
    Residual block conditioned on a timestep embedding, with optional channel
    change, up/down sampling and a trailing temporal convolution.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, use the temporal convolution.
    :param use_image_dataset: if True, the temporal parameters will not be optimized.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv
        n_out = self.out_channels

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, n_out, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # scale-shift mode needs two vectors (scale and shift) per channel
        emb_width = 2 * n_out if use_scale_shift_norm else n_out
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, emb_width),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, n_out),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(n_out, n_out, 3, padding=1)),
        )

        if n_out == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, n_out, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, n_out, 1)

        if self.use_temporal_conv:
            # attribute name (typo included) is kept for checkpoint compatibility
            self.temopral_conv = TemporalConvBlock_v2(
                n_out,
                n_out,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # resample between the norm/activation and the convolution
            norm_act, conv_in = self.in_layers[:-1], self.in_layers[-1]
            h = conv_in(self.h_upd(norm_act(x)))
            x = self.x_upd(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # append singleton axes so the embedding broadcasts over spatial dims
        missing = len(h.shape) - len(emb_out.shape)
        if missing > 0:
            emb_out = emb_out[(...,) + (None,) * missing]

        if not self.use_scale_shift_norm:
            h = self.out_layers(h + emb_out)
        else:
            norm, rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = rest(norm(h) * (1 + scale) + shift)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            # temporal conv operates on b c f h w
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
|
|
|
|
|
|
class Downsample(nn.Module):
    """
    2x downsampling, either by a strided 3x3 convolution or by average pooling.

    :param channels: channel count expected on the input.
    :param use_conv: True for a strided ``Conv2d``; False for ``avg_pool_nd``
        (which then requires ``out_channels == channels``).
    :param dims: with 3, the stride is ``(1, 2, 2)``, so only the last two axes
        shrink; otherwise the stride is 2.
    :param out_channels: output channels (default: ``channels``).
    :param padding: padding of the convolution.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not self.use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = nn.Conv2d(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
|
class Resample(nn.Module):
    """Parameter-free spatial resampling.

    ``mode`` is one of ``'none'`` (identity), ``'upsample'`` (nearest
    interpolation to the last two dims of ``reference``) or ``'downsample'``
    (adaptive average pooling to half the last two dims).
    """

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        mode = self.mode
        if mode == 'downsample':
            half = tuple(size // 2 for size in x.shape[-2:])
            return F.adaptive_avg_pool2d(x, output_size=half)
        if mode == 'upsample':
            assert reference is not None
            return F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        return x
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual 2D conv block conditioned on an embedding vector.

    :param in_dim: input channels.
    :param embed_dim: size of the conditioning embedding ``e``.
    :param out_dim: output channels.
    :param use_scale_shift_norm: if True, the embedding produces a
        (scale, shift) pair applied after the second GroupNorm; otherwise it
        is added to the features.
    :param mode: ``Resample`` mode applied to both the main and skip paths.
    :param dropout: dropout probability inside ``layer2``.
    """

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # Layers. Attribute names, Sequential indices and creation order are
        # kept as-is (state-dict layout and RNG-driven init order).
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1),
        )
        self.resample = Resample(in_dim, in_dim, mode)
        embedding_out = out_dim * 2 if use_scale_shift_norm else out_dim
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, embedding_out),
        )
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1),
        )
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_dim, out_dim, 1)

        # Zero the final conv weight so the residual branch initially
        # contributes only that conv's bias.
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        """Apply the block to ``x`` conditioned on embedding ``e``.

        ``reference`` is forwarded to ``Resample`` (needed for ``'upsample'``).
        """
        skip = self.resample(x, reference)
        pre_conv, first_conv = self.layer1[:-1], self.layer1[-1]
        x = first_conv(self.resample(pre_conv(x), reference))

        # Project the embedding and add two trailing axes to broadcast over H, W.
        e = self.embedding(e)[..., None, None].type(x.dtype)

        if not self.use_scale_shift_norm:
            x = self.layer2(x + e)
        else:
            scale, shift = e.chunk(2, dim=1)
            x = self.layer2[1:](self.layer2[0](x) * (1 + scale) + shift)
        return x + self.shortcut(skip)
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map,
    optionally also attending to an external context sequence.

    :param dim: channels of the input feature map.
    :param context_dim: channel size of ``context``; when given, a linear
        layer projects the context into extra keys/values.
    :param num_heads: number of heads (used only if ``head_dim`` is not given).
    :param head_dim: channels per head; takes precedence over ``num_heads``.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Applied to both q and k, so q^T k is scaled by head_dim ** -0.5 overall.
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value: each is [b, n, d, h*w]
        # (per-head channels first, sequence last)
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(
                0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        # compute attention; every branch yields x as [b, n, d, h*w]
        if getattr(cmd_opts, "force_enable_xformers", False) or (getattr(cmd_opts, "xformers", False) and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
            import xformers.ops
            # xformers expects [batch, seq_len, heads, head_dim]; our tensors
            # are [batch, heads, head_dim, seq_len].
            q_x, k_x, v_x = (
                tensor.permute(0, 3, 1, 2).contiguous() for tensor in (q, k, v))
            x = xformers.ops.memory_efficient_attention(
                q_x, k_x, v_x, op=get_xformers_flash_attention_op(q_x, k_x, v_x),
            )
            x = x.permute(0, 2, 3, 1)
        elif getattr(cmd_opts, "opt_sdp_no_mem_attention", False) and can_use_sdp:
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
                x = self._sdp_attention(q, k, v)
        elif getattr(cmd_opts, "opt_sdp_attention", True) and can_use_sdp:
            x = self._sdp_attention(q, k, v)
        else:
            attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
            attn = F.softmax(attn, dim=-1)
            # gather context
            x = torch.matmul(v, attn.transpose(-1, -2))

        x = x.reshape(b, c, h, w)
        # output
        x = self.proj(x)
        return x + identity

    @staticmethod
    def _sdp_attention(q, k, v):
        """Attention via ``F.scaled_dot_product_attention`` for [b, n, d, L] tensors.

        SDPA expects ``[..., seq_len, head_dim]`` and scales by
        ``head_dim ** -0.5``. Transposing in and out therefore reproduces the
        explicit softmax path above, including its scale.
        """
        out = F.scaled_dot_product_attention(
            q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2),
            dropout_p=0.0,
        )
        return out.transpose(-1, -2)
|
|
|
|
|
|
class TemporalConvBlock_v2(nn.Module):
    """Residual stack of four 3D convolutions with ``(3, 1, 1)`` kernels.

    The kernels only span dim 2 of the 5-D input (presumably the frame axis
    of a ``[b, c, f, h, w]`` tensor — confirm against callers). The last conv
    is zero-initialised, so a freshly constructed block is the identity.

    :param in_dim: channels of the input (and of the output).
    :param out_dim: hidden channels (defaults to ``in_dim``).
    :param dropout: dropout probability inside conv2..conv4.
    :param use_image_dataset: if True, the conv branch is multiplied by 0.0
        and the block returns its input.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim  # int(1.5*in_dim)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # Channel flow: in_dim -> out_dim -> out_dim -> out_dim -> in_dim.
        # Previously conv2/conv3 projected back to in_dim while the following
        # GroupNorm expected out_dim, so any out_dim != in_dim crashed in
        # forward. With out_dim == in_dim the parameter shapes (and init order)
        # are unchanged, so existing checkpoints load as before.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        if self.use_image_dataset:
            # Multiply by 0.0 instead of dropping the branch, presumably so its
            # parameters stay in the autograd graph.
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x
|
|
|
|
|
|
def _i(tensor, t, x):
    r"""Gather ``tensor[t]`` and shape it to broadcast against ``x``.

    The gathered values are viewed as ``(x.size(0), 1, ..., 1)`` (one trailing
    singleton per remaining axis of ``x``) and converted to ``x``'s dtype and
    device.
    """
    broadcast_shape = [x.size(0)] + [1] * (x.ndim - 1)
    picked = tensor.to(x.device)[t]
    return picked.view(broadcast_shape).to(x)
|
|
|
|
|
|
def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    """Build a float64 beta schedule with ``num_timesteps`` entries.

    Only ``'linear_sd'`` is supported: the square roots of the betas are
    evenly spaced from ``sqrt(init_beta)`` to ``sqrt(last_beta)``.

    :raises ValueError: for any other ``schedule`` name.
    """
    if schedule != 'linear_sd':
        raise ValueError(f'Unsupported schedule: {schedule}')
    sqrt_betas = torch.linspace(init_beta ** 0.5,
                                last_beta ** 0.5,
                                num_timesteps,
                                dtype=torch.float64)
    return sqrt_betas ** 2
|
|
|
|
|
|
class GaussianDiffusion(object):
    r""" Diffusion Model for DDIM.

    "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
    See https://arxiv.org/abs/2010.02502
    """

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='learned_range',
                 loss_type='mse',
                 epsilon=1e-12,
                 rescale_timesteps=False):
        """Precompute the per-timestep schedule coefficients.

        :param betas: 1-D sequence of per-step noise variances in (0, 1];
            converted to a float64 tensor.
        :param mean_type: what the model output represents: ``'eps'`` (noise),
            ``'x0'`` (clean sample) or ``'x_{t-1}'`` (posterior mean).
        :param var_type: how the reverse-step variance is obtained: ``'learned'``,
            ``'learned_range'``, ``'fixed_large'`` or ``'fixed_small'``.
        :param loss_type: validated and stored; not used by the sampling code here.
        :param epsilon: stored; not used by the sampling code here.
        :param rescale_timesteps: if True, timesteps passed to the model are
            mapped onto a 1000-step scale (see ``_scale_timesteps``).
        """
        # check input (tensors are converted without the torch.tensor(tensor) copy warning)
        if isinstance(betas, torch.Tensor):
            betas = betas.to(torch.float64)
        else:
            betas = torch.tensor(betas, dtype=torch.float64)
        assert betas.min() > 0 and betas.max() <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
            'charbonnier'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.epsilon = epsilon
        self.rescale_timesteps = rescale_timesteps

        # alphas
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                        - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                      - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                      - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)

    def add_noise(self, xt, noise, t):
        """Diffuse ``xt`` to timestep ``t``.

        Returns ``sqrt(a_t) * xt + sqrt(1 - a_t) * noise`` with
        ``a_t = alphas_cumprod[t]``; ``t`` is a single index (int or 0-d tensor).
        """
        return self.sqrt_alphas_cumprod[t] * xt + noise * self.sqrt_one_minus_alphas_cumprod[t]

    @staticmethod
    def _cond_kwargs(model_kwargs):
        """Return the conditional model kwargs.

        Accepts ``None``, a single dict, or a ``[conditional, unconditional]``
        sequence, whose first element is returned.
        """
        if model_kwargs is None:
            return {}
        if isinstance(model_kwargs, dict):
            return model_kwargs
        return model_kwargs[0]

    @staticmethod
    def _match_uncond_length(c_i, uc_i):
        """Make ``uc_i`` match ``c_i`` along dim 1 (token axis).

        For DDIM, shapes must match; we can't just process cond and uncond
        independently. Filling the unconditional conditioning with repeats of
        its last vector is not 100% correct but should work well enough.
        Longer unconditional conditioning is truncated.
        """
        if uc_i.shape[1] < c_i.shape[1]:
            last_vector = uc_i[:, -1:]
            last_vector_repeated = last_vector.repeat(
                [1, c_i.shape[1] - uc_i.shape[1], 1])
            return torch.hstack([uc_i, last_vector_repeated])
        if uc_i.shape[1] > c_i.shape[1]:
            return uc_i[:, :c_i.shape[1]]
        return uc_i

    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs=None,
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).

        :param model_kwargs: keyword arguments for ``model``: a dict, or a
            ``[conditional, unconditional]`` pair of dicts. The pair is required
            for classifier-free guidance (``guide_scale`` not None and not 1);
            otherwise only the conditional kwargs are used.
        :param clamp: if set (and ``percentile`` is None), clamp x0 to [-clamp, clamp].
        :param percentile: if set, per-sample dynamic thresholding of x0 at this
            quantile of ``|x0|`` (must be in (0, 1]).
        :return: ``(mu, var, log_var, x0)``. ``mu`` is computed from the
            unclamped x0; the returned ``x0`` has the clamp/threshold applied.
        """
        # predict distribution
        if guide_scale is None or guide_scale == 1:
            out = model(xt, self._scale_timesteps(t), **self._cond_kwargs(model_kwargs))
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
            u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
            # Guidance applies only to the mean channels; learned-variance
            # channels (second half) are taken from the conditional output.
            dim = y_out.size(1) if self.var_type.startswith(
                'fixed') else y_out.size(1) // 2
            a = u_out[:, :dim]
            b = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
            c = y_out[:, dim:]
            out = torch.cat([a + b, c], dim=1)

        # compute variance (previously only 'fixed_small' was handled, so every
        # other accepted var_type failed with UnboundLocalError)
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            # Model output in [-1, 1] interpolates between the clipped
            # posterior log-variance and log(beta_t).
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        else:  # 'fixed_small'
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
                xt) * xt
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        else:  # 'eps'
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt) * out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            # One threshold per sample, shaped to broadcast over any rank
            # (a fixed view(-1, 1, 1, 1) misaligns with 5-D video latents).
            s = torch.quantile(
                x0.flatten(1).abs(), percentile,
                dim=1).clamp_(1.0).view(-1, *([1] * (x0.ndim - 1)))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0

    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Distribution of q(x_{t-1} | x_t, x_0).
        """
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var

    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs=None,
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.
        - condition_fn: for classifier-based guidance (guided-diffusion).
          It receives the conditional model kwargs.
        - guide_scale: for classifier-free guidance (glide/dalle-2).
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt)
            # model_kwargs may be a [cond, uncond] list; unpacking the list
            # itself with ** raised TypeError.
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **self._cond_kwargs(model_kwargs))

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
            self.sqrt_recipm1_alphas_cumprod, t, xt)
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
        a = (1 - alphas_prev) / (1 - alphas)
        b = (1 - alphas / alphas_prev)
        sigmas = eta * torch.sqrt(a * b)

        # random sample (no noise is added at t == 0)
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        # Temporaries are released on return; the former `.cpu()` calls made
        # discarded host copies (a device sync + transfer) without freeing anything.
        return torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise

    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         c=None,
                         uc=None,
                         num_sample=1,
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0,
                         skip_steps=0,
                         mask=None,
                         ):
        """Run the DDIM sampling loop starting from ``noise``.

        :param c: conditional prompt conditioning; resolved per step with
            ``reconstruct_cond_batch``.
        :param uc: unconditional conditioning, resolved the same way.
        :param num_sample: unused; kept for interface compatibility.
        :param skip_steps: number of leading DDIM steps to skip. The input is
            first noised to the timestep of the last skipped step.
        :param mask: optional inpainting mask. After each step (except the
            last), positions where ``mask <= v`` are replaced by the original
            latents re-noised to the next timestep; ``v`` shrinks each step.
        :raises InterruptedException: if ``state.interrupted`` is set.
        :return: the final latents.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)

        state.sampling_steps = ddim_timesteps

        if skip_steps > 0:
            step0 = steps[skip_steps - 1]
            steps = steps[skip_steps:]
            noise_to_add = torch.randn_like(xt)
            xt = self.add_noise(xt, noise_to_add, step0)
            state.sampling_steps = state.sampling_steps - skip_steps

        if mask is not None:
            step0 = steps[0]
            original_latents = xt
            noise_to_add = torch.randn_like(xt)
            xt = self.add_noise(xt, noise_to_add, step0)

        # `with` guarantees the progress bar is closed on interrupt/skip too.
        with tqdm(steps, desc="DDIM sampling") as pbar:
            for i, step in enumerate(pbar):
                state.sampling_step = i

                if state.interrupted:
                    raise InterruptedException

                c_i = reconstruct_cond_batch(c, i)
                uc_i = self._match_uncond_length(c_i, reconstruct_cond_batch(uc, i))

                t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
                model_kwargs = [{
                    'y': c_i.type(torch.float16),
                }, {
                    'y': uc_i.type(torch.float16),
                }]
                xt = self.ddim_sample(xt, t, model, model_kwargs, clamp,
                                      percentile, condition_fn, guide_scale,
                                      ddim_timesteps, eta)

                # inpainting
                if mask is not None and i < len(steps) - 1:
                    v = (ddim_timesteps - i - 1) / ddim_timesteps
                    binary_mask = torch.where(mask <= v, torch.zeros_like(mask), torch.ones_like(mask))
                    noise_to_add = torch.randn_like(xt)
                    to_inpaint = self.add_noise(original_latents, noise_to_add, steps[i + 1])
                    xt = to_inpaint * (1 - binary_mask) + xt * binary_mask

                pbar.set_description(f"DDIM sampling {str(step)}")

                if state.skipped:
                    break
        return xt

    def _scale_timesteps(self, t):
        """Map ``t`` onto a 1000-step scale if ``rescale_timesteps`` is set."""
        if self.rescale_timesteps:
            return t.float() * 1000.0 / self.num_timesteps
        return t
|
|
|
|
|
|
class AutoencoderKL(nn.Module):
    """KL-regularised autoencoder used as the first-stage (latent) model.

    ``encode`` returns a ``DiagonalGaussianDistribution`` over latents and
    ``decode`` maps latents back to images. ``Encoder``, ``Decoder`` and
    ``DiagonalGaussianDistribution`` are referenced here but defined outside
    this class.

    NOTE(review): ``use_ema`` is derived from ``ema_decay``, but this class never
    creates ``self.model_ema`` or ``self.ema_scope``. ``on_train_batch_end`` and
    ``log_images`` with EMA enabled would therefore raise AttributeError unless
    those are attached elsewhere; confirm before using the training helpers.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ckpt_path=None,
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # The encoder must emit mean and log-variance (2 * z_channels).
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def init_from_ckpt(self, path):
        """Load the ``first_stage_model.*`` weights from a full checkpoint at ``path``.

        Keys containing ``'first_stage_model'`` are kept, with everything up to
        ``'first_stage_model.'`` stripped; loading is strict.
        """
        import collections

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location='cpu')['state_dict']
        sd_new = collections.OrderedDict(
            (k.split('first_stage_model.')[-1], v)
            for k, v in sd.items()
            if 'first_stage_model' in k)

        self.load_state_dict(sd_new, strict=True)
        del sd
        del sd_new
        torch_gc()

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images to a posterior distribution over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latents ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input``; returns ``(reconstruction, posterior)``.

        Latents are sampled from the posterior, or its mode if
        ``sample_posterior`` is False.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and convert channels-last to contiguous float NCHW.

        A 3-D input gets a trailing channel axis before the permute.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect inputs, reconstructions and random-latent samples for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        # nn.Module has no ``device`` attribute (that is a LightningModule
        # feature), so take the device from one of this module's parameters.
        x = x.to(self.quant_conv.weight.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels scaled to [-1, 1].

        A random ``colorize`` projection buffer is created on first use if absent.
        """
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
|
|
|
|
|
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` give all-True / all-False masks exactly.
    Otherwise, if every entry happens to come out True, ``mask[0]`` is set to
    False so the mask is not all True. An empty mask is returned as-is.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
    # Avoid masking everything, which causes a find_unused_parameters error.
    # An empty mask is vacuously "all" True but has no element 0 to clear.
    if mask.numel() > 0 and mask.all():
        mask[0] = False
    return mask
|
|
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module.

    Remaining positional/keyword arguments are forwarded unchanged to
    ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')
|
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average pooling module.

    Remaining positional/keyword arguments are forwarded unchanged to
    ``nn.AvgPool1d`` / ``nn.AvgPool2d`` / ``nn.AvgPool3d``.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')
|