# Part of the implementation is borrowed and modified from stable-diffusion,
# publicly avaialbe at https://github.com/Stability-AI/stablediffusion.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
# https://github.com/modelscope/modelscope/tree/master/modelscope/pipelines/multi_modal
# Alibaba's code used under Apache 2.0 license
# StabilityAI's Stable Diffusion code used under MIT license
# Automatic1111's WebUI's code used under AGPL v3.0

# All the licenses of the code and its modifications are incorporated into the compatible AGPL v3.0 license

# SD-webui text2video:

# Copyright (C) 2023 by Artem Khrapov (kabachuha)

# See LICENSE for usage terms.

from ldm.util import instantiate_from_config
import importlib
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from os import path as osp
from modules.shared import opts

from functools import partial
from tqdm import tqdm
from modules.prompt_parser import reconstruct_cond_batch
from modules.shared import state
from modules.sd_samplers_common import InterruptedException

from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
from ldm.modules.diffusionmodules.util import make_beta_schedule

__all__ = ['UNetSD']

try:
    import gc
    import torch
    import torch.cuda

    def torch_gc():
        """Performs garbage collection for both Python and PyTorch CUDA tensors.

        This function collects Python garbage and, when CUDA is available,
        clears the PyTorch CUDA cache and IPC (Inter-Process Communication)
        resources.
        """
        gc.collect()  # Collect Python garbage
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear PyTorch CUDA cache
            torch.cuda.ipc_collect()  # Clear PyTorch CUDA IPC resources

except ImportError:
    # Only an import failure is expected here; any other error should surface
    # rather than being swallowed by a bare ``except:``.
    import gc

    def torch_gc():
        """Fallback used when the torch imports above fail.

        Only Python's garbage collector is run; no CUDA state is touched.
        """
        gc.collect()

import modules.shared as shared
from modules.shared import cmd_opts

# not everyone has torch 2.x to use sdp
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(
    getattr(torch.nn.functional, "scaled_dot_product_attention"))

from ldm.modules.diffusionmodules.model import Decoder, Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

DEFAULT_MODEL_REVISION = None


class Invoke(object):
    """String constants naming how a model/component was invoked."""
    KEY = 'invoked_by'
    PRETRAINED = 'from_pretrained'
    PIPELINE = 'pipeline'
    TRAINER = 'trainer'
    LOCAL_TRAINER = 'local_trainer'
    PREPROCESSOR = 'preprocessor'


def exists(x):
    """Return True when ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not None, else ``d`` (calling it if callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d


class UNetSD(nn.Module):
    """Spatio-temporal UNet.

    ``forward`` takes and returns 5-D tensors laid out as (b, c, f, h, w).
    Internally, frames are folded into the batch, (b*f, c, h, w), for every
    layer except the temporal ones (TemporalTransformer and the temporal
    convolution inside ResBlock), which unfold them again.
    """

    def __init__(self,
                 in_dim=7,
                 dim=512,
                 y_dim=512,
                 context_dim=512,
                 out_dim=6,
                 dim_mult=[1, 2, 3, 4],
                 num_heads=None,
                 head_dim=64,
                 num_res_blocks=3,
                 attn_scales=[1 / 2, 1 / 4, 1 / 8],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=2,
                 temporal_attention=True,
                 use_checkpoint=False,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 parameterization="eps"):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        # Copy into a list: the list concatenations below reject tuples, and
        # the shared mutable default must never be aliased.
        dim_mult = list(dim_mult)
        super(UNetSD, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        self.num_heads = num_heads
        # parameters for spatial/temporal attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        self.parameterization = parameterization
        self.v_posterior = 0
        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False
        # params
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))

        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            # zero-init so the fps branch starts as a no-op added to the time embedding
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
        if temporal_attention:
            init_block.append(
                TemporalTransformer(
                    dim,
                    num_heads,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset))
        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)
        # NOTE: in_dim/out_dim are deliberately rebound by these loops; the
        # middle block and output head below use the final loop values.
        for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                # residual (+attention) blocks
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                    # temporal attention is paired with spatial attention
                    if self.temporal_attention:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                in_dim = out_dim
                self.input_blocks.append(block)
                shortcut_dims.append(out_dim)

                # downsample after the last res block of every level but the deepest
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        # middle
        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True)
        ])

        if self.temporal_attention:
            self.middle_block.append(
                TemporalTransformer(
                    out_dim,
                    out_dim // head_dim,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset,
                ))

        self.middle_block.append(
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                # residual (+attention) blocks; input is concatenated with the
                # matching encoder skip, hence the extra shortcut channels
                block = nn.ModuleList([
                    ResBlock(
                        in_dim + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            # every SpatialTransformer receives the same
                            # `context`, so the decoder must use the same
                            # context_dim as the encoder/middle (was 1024)
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                    if self.temporal_attention:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                in_dim = out_dim

                # upsample after the last res block of every level but the shallowest
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    upsample = Upsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

        # zero out the last layer params
        nn.init.zeros_(self.out[-1].weight)

    # Taken from DDPM
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Register the diffusion-schedule tensors as buffers on this module.

        Uses ``given_betas`` when provided, otherwise builds them with
        ``make_beta_schedule``. Derived quantities (cumulative alphas,
        posterior coefficients and ``lvlb_weights``) follow the DDPM formulas.

        Raises:
            NotImplementedError: if ``self.parameterization`` is not one of
                "eps", "x0" or "v".
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                       linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): `2. * 1 - a` evaluates as `2 - a`; kept as in
            # upstream ldm -- confirm whether `2. * (1 - a)` was intended.
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # step 0 has zero posterior variance (division by zero above); reuse step 1
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    def forward(
            self,
            x,
            t,
            y,
            fps=None,
            video_mask=None,
            focus_present_mask=None,
            prob_focus_present=0.0,
            mask_last_frame_num=0  # mask last frame num
    ):
        """Run the UNet on a batch of videos.

        Args:
            x: input laid out as (b, c, f, h, w).
            t: 1-D timesteps, one per batch element.
            y: conditioning context, repeated f times along dim 0 so each
                folded frame gets its own copy.
            fps: optional frame-rate values; only used when
                ``use_fps_condition`` is set.
            video_mask: forwarded to the layers; when ``mask_last_frame_num``
                is positive its last entries are set to False in place.
            prob_focus_present: probability at which a given batch sample will
                focus on the present (0. is all off, 1. is completely arrested
                attention across time).
            mask_last_frame_num: if > 0, ``prob_focus_present`` is ignored.

        Returns:
            Tensor laid out as (b, out_dim, f, h, w).
        """
        batch, device = x.shape[0], x.device
        # _forward_single needs the batch size to unfold (b f) for temporal layers
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            if video_mask is not None:
                video_mask[-mask_last_frame_num:] = False
        else:
            # NOTE(review): prob_mask_like is not defined in the visible part
            # of this file -- confirm it exists further down.
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        time_rel_pos_bias = None
        # embeddings
        if self.use_fps_condition and fps is not None:
            e = self.time_embed(sinusoidal_embedding(
                t, self.dim)) + self.fps_embedding(
                    sinusoidal_embedding(fps, self.dim))
        else:
            e = self.time_embed(sinusoidal_embedding(t, self.dim))
        context = y

        # repeat f times for spatial e and context
        f = x.shape[2]
        e = e.repeat_interleave(repeats=f, dim=0)
        context = context.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder
        for block in self.output_blocks:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self, module, x, e, context, time_rel_pos_bias,
                        focus_present_mask, video_mask, reference=None):
        """Dispatch ``x`` through one (possibly nested) submodule.

        Each layer type is called with the arguments its ``forward`` accepts.
        ``nn.ModuleList`` containers are walked recursively. TemporalTransformer
        layers get ``x`` unfolded to (b, c, f, h, w) and folded back afterwards.
        """
        if isinstance(module, ResidualBlock):
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            x = x.contiguous()
            x = module(x, e, self.batch)
        elif isinstance(module, SpatialTransformer):
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            x = module(x, context)
        elif isinstance(module, FeedForward):
            # FeedForward.forward takes only the input tensor
            x = module(x)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference)
        else:
            x = module(x)
        return x


def sinusoidal_embedding(timesteps, dim):
    """Return a (len(timesteps), dim) sinusoidal embedding: [cos | sin (| 0 pad)].

    ``timesteps`` must be 1-D (it is passed to ``torch.outer``). For odd
    ``dim`` a zero column is appended.
    """
    # check input
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    sinusoid = torch.outer(
        timesteps, torch.pow(10000,
                             -torch.arange(half).to(timesteps).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x


class CrossAttention(nn.Module):
    """Multi-head attention from ``x`` (b, n, query_dim) to ``context``.

    Falls back to self-attention when ``context`` is None. The backend is
    picked from ``cmd_opts``: xformers, SDP without the mem-efficient kernel,
    SDP, or an explicit einsum/softmax implementation.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        ``mask``, when given, is boolean with True marking positions that may
        be attended to. It is flattened per batch element and repeated per head.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # split heads into the batch dimension: (b h) n d
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(x.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)

        if getattr(cmd_opts, "force_enable_xformers", False) or (
                getattr(cmd_opts, "xformers", False) and shared.xformers_available and torch.version.cuda
                and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
            # `import xformers` alone does not guarantee the `ops` submodule is loaded
            import xformers.ops
            out = xformers.ops.memory_efficient_attention(
                q, k, v, op=get_xformers_flash_attention_op(q, k, v), attn_bias=mask,
            )
        elif getattr(cmd_opts, "opt_sdp_no_mem_attention", False) and can_use_sdp:
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0.0, attn_mask=mask
                )
        elif getattr(cmd_opts, "opt_sdp_attention", True) and can_use_sdp:
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, attn_mask=mask
            )
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
            del q, k

            if exists(mask):
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            sim = sim.softmax(dim=-1)

            out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None,
                 disable_self_attn=False, use_linear=False, use_checkpoint=True):
        super().__init__()
        # one context_dim per transformer block; a scalar (or None) is shared
        # by all `depth` blocks
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # maps transformer width back to the input channels
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply the blocks over the h*w tokens of ``x`` (b, c, h, w) and add a residual."""
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data in temporal axis.
    First, reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None,
                 disable_self_attn=False, use_linear=False, use_checkpoint=True,
                 only_self_att=True, multiply_zero=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        if self.only_self_att:
            context_dim = None
        # one context_dim per transformer block; a scalar (or None) is shared
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        # NOTE: disable_self_attn is accepted but not forwarded to the blocks.
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # maps transformer width back to the input channels
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Attend across the f frames at every spatial position of ``x`` (b, c, f, h, w).

        Returns a tensor with the same layout. The residual is added back; with
        ``multiply_zero`` the transformer output is multiplied by zero.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        # Project channels and lay tokens out as ((b h w), f, c): one frame
        # sequence per spatial position.
        if not self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
        else:
            x = rearrange(x, 'b c f h w -> (b h w) f c').contiguous()
            x = self.proj_in(x)

        if self.only_self_att:
            for block in self.transformer_blocks:
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # context arrives with frames folded into the batch: (b f) l con
                ctx = rearrange(
                    context[i], '(b f) l con -> b f l con', f=f).contiguous()
                # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_i_j = repeat(
                        ctx[j],
                        'f l con -> (f r) l con',
                        r=(h * w) // f,
                        f=f).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b c f h w', h=h, w=w).contiguous()
        else:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x


class BasicTransformerBlock(nn.Module):
    """Pre-norm block: attention (attn1), attention (attn2), then feed-forward.

    Each sub-layer is wrapped in a residual. attn1 is self-attention unless
    ``disable_self_attn`` is set, in which case it also attends to ``context``.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True,
                 checkpoint=True, disable_self_attn=False):
        super().__init__()
        attn_cls = CrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # stored only; forward does not use gradient checkpointing
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


# feedforward
class GEGLU(nn.Module):
    """Gated GELU projection: split a 2*dim_out projection into value * gelu(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential(nn.Linear( dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: self.conv = nn.Conv2d( self.channels, self.out_channels, 3, padding=padding) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') else: x = F.interpolate(x, scale_factor=2, mode='nearest') if self.use_conv: x = self.conv(x) return x class ResBlock(nn.Module): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. 
:param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. :param use_temporal_conv: if True, use the temporal convolution. :param use_image_dataset: if True, the temporal parameters will not be optimized. """ def __init__( self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, up=False, down=False, use_temporal_conv=True, use_image_dataset=False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_scale_shift_norm = use_scale_shift_norm self.use_temporal_conv = use_temporal_conv self.in_layers = nn.Sequential( nn.GroupNorm(32, channels), nn.SiLU(), nn.Conv2d(channels, self.out_channels, 3, padding=1), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( nn.SiLU(), nn.Linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) if self.use_temporal_conv: self.temopral_conv = TemporalConvBlock_v2( self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) def forward(self, x, emb, batch_size): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. 
:param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ return self._forward(x, emb, batch_size) def _forward(self, x, emb, batch_size): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = torch.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: h = h + emb_out h = self.out_layers(h) h = self.skip_connection(x) + h if self.use_temporal_conv: h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) h = self.temopral_conv(h) h = rearrange(h, 'b c f h w -> (b f) c h w') return h class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. 
""" def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if self.use_conv: self.op = nn.Conv2d( self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class Resample(nn.Module): def __init__(self, in_dim, out_dim, mode): assert mode in ['none', 'upsample', 'downsample'] super(Resample, self).__init__() self.in_dim = in_dim self.out_dim = out_dim self.mode = mode def forward(self, x, reference=None): if self.mode == 'upsample': assert reference is not None x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') elif self.mode == 'downsample': x = F.adaptive_avg_pool2d( x, output_size=tuple(u // 2 for u in x.shape[-2:])) return x class ResidualBlock(nn.Module): def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True, mode='none', dropout=0.0): super(ResidualBlock, self).__init__() self.in_dim = in_dim self.embed_dim = embed_dim self.out_dim = out_dim self.use_scale_shift_norm = use_scale_shift_norm self.mode = mode # layers self.layer1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv2d(in_dim, out_dim, 3, padding=1)) self.resample = Resample(in_dim, in_dim, mode) self.embedding = nn.Sequential( nn.SiLU(), nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim)) self.layer2 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv2d(out_dim, out_dim, 3, padding=1)) self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( in_dim, out_dim, 1) # zero out the last layer params nn.init.zeros_(self.layer2[-1].weight) def forward(self, x, e, reference=None): identity = self.resample(x, reference) x = 
self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) if self.use_scale_shift_norm: scale, shift = e.chunk(2, dim=1) x = self.layer2[0](x) * (1 + scale) + shift x = self.layer2[1:](x) else: x = x + e x = self.layer2(x) x = x + self.shortcut(identity) return x class AttentionBlock(nn.Module): def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): # consider head_dim first, then num_heads num_heads = dim // head_dim if head_dim else num_heads head_dim = dim // num_heads assert num_heads * head_dim == dim super(AttentionBlock, self).__init__() self.dim = dim self.context_dim = context_dim self.num_heads = num_heads self.head_dim = head_dim self.scale = math.pow(head_dim, -0.25) # layers self.norm = nn.GroupNorm(32, dim) self.to_qkv = nn.Conv2d(dim, dim * 3, 1) if context_dim is not None: self.context_kv = nn.Linear(context_dim, dim * 2) self.proj = nn.Conv2d(dim, dim, 1) # zero out the last layer params nn.init.zeros_(self.proj.weight) def forward(self, x, context=None): r"""x: [B, C, H, W]. context: [B, L, C] or None. 
""" identity = x b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim # compute query, key, value x = self.norm(x) q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) if context is not None: ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk( 2, dim=1) k = torch.cat([ck, k], dim=-1) v = torch.cat([cv, v], dim=-1) # compute attention if getattr(cmd_opts, "force_enable_xformers", False) or (getattr(cmd_opts, "xformers", False) and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): import xformers x = xformers.ops.memory_efficient_attention( q, k, v, op=get_xformers_flash_attention_op(q,k,v), ) elif getattr(cmd_opts, "opt_sdp_no_mem_attention", False) and can_use_sdp: with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): x = F.scaled_dot_product_attention( q, k, v, dropout_p=0.0, ) elif getattr(cmd_opts, "opt_sdp_attention", True) and can_use_sdp: x = F.scaled_dot_product_attention( q, k, v, dropout_p=0.0, ) else: attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) attn = F.softmax(attn, dim=-1) # gather context x = torch.matmul(v, attn.transpose(-1, -2)) x = x.reshape(b, c, h, w) # output x = self.proj(x) return x + identity class TemporalConvBlock_v2(nn.Module): def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): super(TemporalConvBlock_v2, self).__init__() if out_dim is None: out_dim = in_dim # int(1.5*in_dim) self.in_dim = in_dim self.out_dim = out_dim self.use_image_dataset = use_image_dataset # conv layers self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) self.conv2 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) self.conv3 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), 
nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) self.conv4 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) # zero out the last layer params,so the conv block is identity nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) def forward(self, x): identity = x x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) if self.use_image_dataset: x = identity + 0.0 * x else: x = identity + x return x def _i(tensor, t, x): r"""Index tensor using t and format the output according to x. """ tensor = tensor.to(x.device) shape = (x.size(0), ) + (1, ) * (x.ndim - 1) return tensor[t].view(shape).to(x) def beta_schedule(schedule, num_timesteps=1000, init_beta=None, last_beta=None): if schedule == 'linear_sd': return torch.linspace( init_beta**0.5, last_beta**0.5, num_timesteps, dtype=torch.float64)**2 else: raise ValueError(f'Unsupported schedule: {schedule}') class GaussianDiffusion(object): r""" Diffusion Model for DDIM. "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon. 
See https://arxiv.org/abs/2010.02502 """ def __init__(self, betas, mean_type='eps', var_type='learned_range', loss_type='mse', epsilon=1e-12, rescale_timesteps=False): # check input if not isinstance(betas, torch.DoubleTensor): betas = torch.tensor(betas, dtype=torch.float64) assert min(betas) > 0 and max(betas) <= 1 assert mean_type in ['x0', 'x_{t-1}', 'eps'] assert var_type in [ 'learned', 'learned_range', 'fixed_large', 'fixed_small' ] assert loss_type in [ 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1', 'charbonnier' ] self.betas = betas self.num_timesteps = len(betas) self.mean_type = mean_type self.var_type = var_type self.loss_type = loss_type self.epsilon = epsilon self.rescale_timesteps = rescale_timesteps # alphas alphas = 1 - self.betas self.alphas_cumprod = torch.cumprod(alphas, dim=0) self.alphas_cumprod_prev = torch.cat( [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) self.alphas_cumprod_next = torch.cat( [self.alphas_cumprod[1:], alphas.new_zeros([1])]) # q(x_t | x_{t-1}) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) # q(x_{t-1} | x_t, x_0) self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( 1.0 - self.alphas_cumprod) self.posterior_log_variance_clipped = torch.log( self.posterior_variance.clamp(1e-20)) self.posterior_mean_coef1 = betas * torch.sqrt( self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) self.posterior_mean_coef2 = ( 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( 1.0 - self.alphas_cumprod) def add_noise(self, xt, noise, t): #print("adding noise", t, # self.sqrt_alphas_cumprod[t], self.sqrt_one_minus_alphas_cumprod[t]) noisy_sample = self.sqrt_alphas_cumprod[t] * \ 
xt+noise*self.sqrt_one_minus_alphas_cumprod[t] return noisy_sample def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None): r"""Distribution of p(x_{t-1} | x_t). """ # predict distribution if guide_scale is None or guide_scale == 1: out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) else: # classifier-free guidance # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) dim = y_out.size(1) if self.var_type.startswith( 'fixed') else y_out.size(1) // 2 a = u_out[:, :dim] b = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) c = y_out[:, dim:] out = torch.cat([a + b, c], dim=1) # compute variance if self.var_type == 'fixed_small': var = _i(self.posterior_variance, t, xt) log_var = _i(self.posterior_log_variance_clipped, t, xt) # compute mean and x0 if self.mean_type == 'eps': x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( self.sqrt_recipm1_alphas_cumprod, t, xt) * out mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) # restrict the range of x0 if percentile is not None: assert percentile > 0 and percentile <= 1 # e.g., 0.995 s = torch.quantile( x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) x0 = torch.min(s, torch.max(-s, x0)) / s elif clamp is not None: x0 = x0.clamp(-clamp, clamp) return mu, var, log_var, x0 def q_posterior_mean_variance(self, x0, xt, t): r"""Distribution of q(x_{t-1} | x_t, x_0). 
""" mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( self.posterior_mean_coef2, t, xt) * xt var = _i(self.posterior_variance, t, xt) log_var = _i(self.posterior_log_variance_clipped, t, xt) return mu, var, log_var @torch.no_grad() def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): r"""Sample from p(x_{t-1} | x_t) using DDIM. - condition_fn: for classifier-based guidance (guided-diffusion). - guide_scale: for classifier-free guidance (glide/dalle-2). """ stride = self.num_timesteps // ddim_timesteps # predict distribution of p(x_{t-1} | x_t) _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) if condition_fn is not None: # x0 -> eps alpha = _i(self.alphas_cumprod, t, xt) eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( self.sqrt_recipm1_alphas_cumprod, t, xt) eps = eps - (1 - alpha).sqrt() * condition_fn( xt, self._scale_timesteps(t), **model_kwargs) # eps -> x0 x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # derive variables eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( self.sqrt_recipm1_alphas_cumprod, t, xt) alphas = _i(self.alphas_cumprod, t, xt) alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) a = (1 - alphas_prev) / (1 - alphas) b = (1 - alphas / alphas_prev) sigmas = eta * torch.sqrt(a * b) # random sample noise = torch.randn_like(xt) direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise noise.cpu() direction.cpu() mask.cpu() alphas.cpu() alphas_prev.cpu() sigmas.cpu() a.cpu() b.cpu() eps.cpu() x0.cpu() noise = None direction = None mask = None alphas = None alphas_prev = None sigmas = None a = None b = None eps = None x0 = None del noise del direction del mask del alphas del 
alphas_prev del sigmas del a del b del eps del x0 return xt_1 @torch.no_grad() def ddim_sample_loop(self, noise, model, c=None, uc=None, num_sample=1, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, skip_steps=0, mask=None, ): # prepare input b = noise.size(0) xt = noise # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp( 0, self.num_timesteps - 1).flip(0) state.sampling_steps = ddim_timesteps if skip_steps > 0: step0 = steps[skip_steps-1] steps = steps[skip_steps:] noise_to_add = torch.randn_like(xt) t = torch.full((b, ), step0, dtype=torch.long, device=xt.device) print("huh", step0, t) xt = self.add_noise(xt, noise_to_add, step0) state.sampling_steps = state.sampling_steps - skip_steps if mask is not None: pass step0 = steps[0] original_latents=xt noise_to_add = torch.randn_like(xt) xt = self.add_noise(xt, noise_to_add, step0) #convert mask to 0,1 valued based on step v=0 binary_mask = torch.where(mask <= v, torch.zeros_like(mask), torch.ones_like(mask)) #print("about to die",xt,original_latents,mask,binary_mask) pbar = tqdm(steps, desc="DDIM sampling") #print(c) #print(uc) i = 0 for step in pbar: state.sampling_step = i if state.interrupted: raise InterruptedException c_i = reconstruct_cond_batch(c, i) uc_i = reconstruct_cond_batch(uc, i) # for DDIM, shapes must match, we can't just process cond and uncond independently; # filling unconditional_conditioning with repeats of the last vector to match length is # not 100% correct but should work well enough if uc_i.shape[1] < c_i.shape[1]: last_vector = uc_i[:, -1:] last_vector_repeated = last_vector.repeat([1, c_i.shape[1] - uc.shape[1], 1]) uc_i = torch.hstack([uc_i, last_vector_repeated]) elif uc_i.shape[1] > c_i.shape[1]: uc_i = uc_i[:, :c_i.shape[1]] #print(c_i.shape, uc_i.shape) t = torch.full((b, ), step, 
dtype=torch.long, device=xt.device) uc_i = uc_i.type(torch.float16) c_i = c_i.type(torch.float16) #print(uc_i) #print(c_i) model_kwargs=[{ 'y': c_i, }, { 'y': uc_i, }] xt = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta) #inpainting if mask is not None and i= 0: k_new = k.split('first_stage_model.')[-1] sd_new[k_new] = sd[k] self.load_state_dict(sd_new, strict=True) del sd del sd_new torch_gc() def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self) def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z): z = self.post_quant_conv(z) dec = self.decoder(z) return dec def forward(self, input, sample_posterior=True): posterior = self.encode(input) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) return dec, posterior def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() return x def get_last_layer(self): return self.decoder.conv_out.weight @torch.no_grad() def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): log = dict() x = self.get_input(batch, self.image_key) x = x.to(self.device) if not only_inputs: xrec, posterior = self(x) if x.shape[1] > 3: # colorize with random projection assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) log['samples'] = self.decode(torch.randn_like(posterior.sample())) log['reconstructions'] = xrec if log_ema or self.use_ema: with self.ema_scope(): xrec_ema, posterior_ema = self(x) if x.shape[1] > 3: # colorize with random projection assert xrec_ema.shape[1] > 3 xrec_ema = self.to_rgb(xrec_ema) log['samples_ema'] = self.decode( torch.randn_like(posterior_ema.sample())) log['reconstructions_ema'] = xrec_ema log['inputs'] = x return log def to_rgb(self, x): assert 
self.image_key == 'segmentation' if not hasattr(self, 'colorize'): self.register_buffer('colorize', torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob # aviod mask all, which will cause find_unused_parameters error if mask.all(): mask[0] = False return mask def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}') def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. """ if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}')