import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .modules.norm import RMSNorm, DyT
from .modules.transformer import TransformerBlock
from .modules.patch import PatchEmbed, UnPatch
from .modules.axial_rope import make_axial_pos
from .modules.time_emb import TimestepEmbedding
from .utils import isiterable


class TBackBone(nn.Module):
    """
    Basic transformer backbone: a stack of TransformerBlock layers with
    optional gradient checkpointing.
    """

    def __init__(
        self,
        dim=1024,
        ctx_dim=1024,
        heads=16,
        dim_head=64,
        mlp_dim=3072,
        pos_dim=2,
        depth=8,
        use_adaln=False,
        use_shared_adaln=False,
        use_dyt=False,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim,
                    ctx_dim,
                    heads,
                    dim_head,
                    mlp_dim,
                    pos_dim,
                    use_adaln,
                    use_shared_adaln,
                    norm_layer=DyT if use_dyt else RMSNorm,
                )
                for _ in range(depth)
            ]
        )
        self.grad_ckpt = False

    def init_weight(self):
        for param in self.parameters():
            if param.ndim == 1:
                nn.init.normal_(param, mean=0.0, std=(1 / param.size(0)) ** 0.5)
            elif param.ndim == 2:
                fan_in = param.size(1)
                nn.init.normal_(param, mean=0.0, std=(1 / fan_in) ** 0.5)
            elif param.ndim >= 3:
                fan_out, *fan_ins = param.shape
                # fan-in is the product of all dims after the output dim
                fan_in = 1
                for f in fan_ins:
                    fan_in *= f
                nn.init.normal_(param, mean=0.0, std=(1 / fan_in) ** 0.5)

    def forward(
        self,
        x,
        ctx=None,
        x_mask=None,
        ctx_mask=None,
        pos_map=None,
        y=None,
        shared_adaln=None,
    ):
        if pos_map is not None:
            assert pos_map.size(1) == x.size(1)

        for block in self.blocks:
            if self.grad_ckpt:
                x = checkpoint(
                    block,
                    x,
                    ctx,
                    pos_map,
                    None,
                    y,
                    x_mask,
                    ctx_mask,
                    shared_adaln,
                    use_reentrant=False,
                )
            else:
                x = block(x, ctx, pos_map, None, y, x_mask, ctx_mask, shared_adaln)

        return x
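

# Minimal usage sketch for TBackBone (hypothetical shapes; assumes TransformerBlock
# follows the positional call order used in TBackBone.forward above):
#   backbone = TBackBone(dim=1024, ctx_dim=1024, depth=8, use_adaln=True)
#   x = torch.randn(2, 256, 1024)    # [batch, tokens, dim]
#   ctx = torch.randn(2, 77, 1024)   # cross-attention context tokens
#   y = torch.randn(2, 1, 1024)      # AdaLN conditioning, e.g. a timestep embedding
#   out = backbone(x, ctx=ctx, y=y)  # -> [2, 256, 1024]
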
class XUTBackBone(nn.Module):
    """
    Basic backbone of the cross-U-transformer (XUT): per-stage encoder blocks
    followed by decoder stages whose first block cross-attends to the encoder output.
    """

    def __init__(
        self,
        dim=1024,
        ctx_dim=None,
        heads=16,
        dim_head=64,
        mlp_dim=3072,
        pos_dim=2,
        depth=8,
        enc_blocks=1,
        dec_blocks=2,
        dec_ctx=False,
        use_adaln=False,
        use_shared_adaln=False,
        use_dyt=False,
    ):
        super().__init__()
        if isiterable(enc_blocks):
            enc_blocks = list(enc_blocks)
            assert len(enc_blocks) == depth
        else:
            enc_blocks = [int(enc_blocks)] * depth
        if isiterable(dec_blocks):
            dec_blocks = list(dec_blocks)
            assert len(dec_blocks) == depth
        else:
            dec_blocks = [int(dec_blocks)] * depth

        self.enc_blocks = nn.ModuleList()
        for i in range(depth):
            blocks = [
                TransformerBlock(
                    dim,
                    ctx_dim,
                    heads,
                    dim_head,
                    mlp_dim,
                    pos_dim,
                    use_adaln,
                    use_shared_adaln,
                    norm_layer=DyT if use_dyt else RMSNorm,
                )
                for _ in range(enc_blocks[i])
            ]
            self.enc_blocks.append(nn.ModuleList(blocks))

        self.dec_ctx = dec_ctx
        self.dec_blocks = nn.ModuleList()
        for i in range(depth):
            blocks = [
                TransformerBlock(
                    dim,
                    dim if bid == 0 else ctx_dim if dec_ctx else None,
                    heads,
                    dim_head,
                    mlp_dim,
                    pos_dim,
                    use_adaln,
                    use_shared_adaln,
                    ctx_from_self=bid == 0,
                    norm_layer=DyT if use_dyt else RMSNorm,
                )
                for bid in range(dec_blocks[i])
            ]
            self.dec_blocks.append(nn.ModuleList(blocks))

        self.grad_ckpt = False

    def init_weight(self):
        for param in self.parameters():
            if param.ndim == 1:
                nn.init.normal_(param, mean=0.0, std=(1 / param.size(0)) ** 0.5)
            elif param.ndim == 2:
                fan_in = param.size(1)
                nn.init.normal_(param, mean=0.0, std=(1 / fan_in) ** 0.5)
            elif param.ndim >= 3:
                fan_out, *fan_ins = param.shape
                # fan-in is the product of all dims after the output dim
                fan_in = 1
                for f in fan_ins:
                    fan_in *= f
                nn.init.normal_(param, mean=0.0, std=(1 / fan_in) ** 0.5)

    def forward(
        self,
        x,
        ctx=None,
        x_mask=None,
        ctx_mask=None,
        pos_map=None,
        y=None,
        shared_adaln=None,
        return_enc_out=False,
    ):
        if pos_map is not None:
            assert pos_map.size(1) == x.size(1)

        self_ctx = []
        for blocks in self.enc_blocks:
            for block in blocks:
                if self.grad_ckpt:
                    x = checkpoint(
                        block,
                        x,
                        ctx,
                        pos_map,
                        None,
                        y,
                        x_mask,
                        ctx_mask,
                        shared_adaln,
                        use_reentrant=False,
                    )
                else:
                    x = block(x, ctx, pos_map, None, y, x_mask, ctx_mask, shared_adaln)
            self_ctx.append(x)
        enc_out = x

        for blocks in self.dec_blocks:
            # the first block of each decoder stage cross-attends to the final encoder output
            first_block = blocks[0]
            if self.grad_ckpt:
                x = checkpoint(
                    first_block,
                    x,
                    self_ctx[-1],
                    pos_map,
                    pos_map,
                    y,
                    x_mask,
                    ctx_mask,
                    shared_adaln,
                    use_reentrant=False,
                )
            else:
                x = first_block(
                    x, self_ctx[-1], pos_map, pos_map, y, x_mask, ctx_mask, shared_adaln
                )

            for block in blocks[1:]:
                if self.grad_ckpt:
                    x = checkpoint(
                        block,
                        x,
                        ctx if self.dec_ctx else None,
                        pos_map,
                        None,
                        y,
                        x_mask,
                        ctx_mask,
                        shared_adaln,
                        use_reentrant=False,
                    )
                else:
                    x = block(
                        x,
                        ctx if self.dec_ctx else None,
                        pos_map,
                        None,
                        y,
                        x_mask,
                        ctx_mask,
                        shared_adaln,
                    )

        if return_enc_out:
            return x, enc_out
        return x
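

# Minimal usage sketch for XUTBackBone (hypothetical shapes and settings):
#   xut = XUTBackBone(dim=1024, depth=8, enc_blocks=1, dec_blocks=2, use_adaln=True)
#   x = torch.randn(2, 256, 1024)                # [batch, tokens, dim]
#   y = torch.randn(2, 1, 1024)                  # AdaLN conditioning
#   out = xut(x, y=y)                            # -> [2, 256, 1024]
#   out, enc = xut(x, y=y, return_enc_out=True)  # also return the encoder output
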
class XUDiT(nn.Module):
    """
    Xross-U-Transformer (cross-U-transformer) for image generation (XUDiT).
    """

    def __init__(
        self,
        patch_size=2,
        input_dim=4,
        dim=1024,
        ctx_dim=1024,
        ctx_size=256,
        heads=16,
        dim_head=64,
        mlp_dim=3072,
        depth=8,
        enc_blocks=1,
        dec_blocks=2,
        dec_ctx=False,
        class_cond=0,
        shared_adaln=True,
        concat_ctx=True,
        use_dyt=False,
        double_t=False,
        addon_info_embs_dim=None,
        tread_config=None,
    ):
        super().__init__()
        self.backbone = XUTBackBone(
            dim,
            None if concat_ctx else ctx_dim,
            heads,
            dim_head,
            mlp_dim,
            2,
            depth,
            enc_blocks,
            dec_blocks,
            use_adaln=True,
            use_shared_adaln=shared_adaln,
            dec_ctx=dec_ctx,
            use_dyt=use_dyt,
        )

        self.use_tread = False
        if tread_config is not None:
            self.use_tread = True
            self.dropout_ratio = tread_config["dropout_ratio"]
            self.prev_tread_trns = TBackBone(
                dim,
                None if concat_ctx else ctx_dim,
                heads,
                dim_head,
                mlp_dim,
                2,
                tread_config["prev_trns_depth"],
                use_adaln=True,
                use_shared_adaln=shared_adaln,
                use_dyt=use_dyt,
            )
            self.post_tread_trns = TBackBone(
                dim,
                None if concat_ctx else ctx_dim,
                heads,
                dim_head,
                mlp_dim,
                2,
                tread_config["post_trns_depth"],
                use_adaln=True,
                use_shared_adaln=shared_adaln,
                use_dyt=use_dyt,
            )

        self.patch_size = patch_size
        self.in_patch = PatchEmbed(patch_size, input_dim, dim)
        self.out_patch = UnPatch(patch_size, dim, input_dim)
        self.time_emb = TimestepEmbedding(dim)
        if double_t:
            self.r_emb = TimestepEmbedding(dim)
        if shared_adaln:
            # shared AdaLN modulation heads; the final projections are zero-initialized
            # so the modulation signal starts at zero
            self.shared_adaln_attn = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 4),
                nn.Mish(),
                nn.Linear(dim * 4, dim * 3),
            )
            nn.init.constant_(self.shared_adaln_attn[-1].bias, 0)
            nn.init.constant_(self.shared_adaln_attn[-1].weight, 0)
            self.shared_adaln_xattn = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 4),
                nn.Mish(),
                nn.Linear(dim * 4, dim * 3),
            )
            nn.init.constant_(self.shared_adaln_xattn[-1].bias, 0)
            nn.init.constant_(self.shared_adaln_xattn[-1].weight, 0)
            self.shared_adaln_ffw = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 4),
                nn.Mish(),
                nn.Linear(dim * 4, dim * 3),
            )
            nn.init.constant_(self.shared_adaln_ffw[-1].bias, 0)
            nn.init.constant_(self.shared_adaln_ffw[-1].weight, 0)
        if class_cond > 0:
            self.class_token = nn.Embedding(class_cond, dim)
        else:
            self.class_token = None
        if concat_ctx and ctx_dim is not None:
            self.ctx_proj = nn.Linear(ctx_dim, dim)
        else:
            self.ctx_proj = None
        if addon_info_embs_dim is not None:
            self.addon_info_embs_proj = nn.Sequential(
                nn.Linear(addon_info_embs_dim, dim), nn.Mish(), nn.Linear(dim, dim)
            )
            nn.init.constant_(self.addon_info_embs_proj[-1].bias, 0)
            nn.init.constant_(self.addon_info_embs_proj[-1].weight, 0)

        self.concat_ctx = concat_ctx
        self.shared_adaln = shared_adaln
        self.need_ctx = ctx_dim is not None
        self.ctx_dim = ctx_dim
        self.ctx_size = ctx_size
        self.grad_ckpt = False
        self.init_weight()
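
    # Example tread_config (keys as read in __init__ above; values are illustrative):
    #   tread_config = {
    #       "dropout_ratio": 0.5,   # fraction of image tokens routed around the backbone
    #       "prev_trns_depth": 2,   # TBackBone depth applied before token routing
    #       "post_trns_depth": 2,   # TBackBone depth applied after tokens are scattered back
    #   }
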
    def init_weight(self):
        if isinstance(self.out_patch.proj, nn.Linear):
            nn.init.normal_(
                self.out_patch.proj.weight,
                mean=0.0,
                std=1 / self.out_patch.proj.in_features**2,
            )

    def set_grad_ckpt(self, grad_ckpt):
        self.backbone.grad_ckpt = grad_ckpt
        self.grad_ckpt = grad_ckpt
        if self.use_tread:
            self.prev_tread_trns.grad_ckpt = grad_ckpt
            self.post_tread_trns.grad_ckpt = grad_ckpt

    def forward(
        self,
        x,
        t,
        ctx=None,
        pos_map=None,
        r=None,
        addon_info=None,
        tread_rate=None,
        return_enc_out=False,
    ):
        n, c, h, w = x.size()
        t = t.reshape(n, -1)
        x, pos_map = self.in_patch(x, pos_map)
        x = x.contiguous()
        if pos_map is None:
            pos_map = (
                make_axial_pos(
                    h // self.patch_size,
                    w // self.patch_size,
                    dtype=x.dtype,
                    device=x.device,
                )
                .unsqueeze(0)
                .expand(n, -1, -1)
            )
        t_emb = self.time_emb(t)
        if r is not None:
            t_emb = t_emb + self.r_emb(t - r.reshape(n, -1))
        if self.class_token is not None and ctx is not None:
            if ctx.ndim == 1:
                ctx = ctx[:, None]
            t_emb = t_emb + self.class_token(ctx)
            ctx = None
        if addon_info is not None:
            if addon_info.ndim == 1:
                # [B] -> [B, 1] for single-value info
                addon_info = addon_info[:, None]
            # [B, D] -> [B, 1, D] to match t_emb's shape
            addon_embs = self.addon_info_embs_proj(addon_info)[:, None]
            t_emb = t_emb + addon_embs
        if ctx is None and self.need_ctx:
            ctx = torch.zeros(n, self.ctx_size, self.ctx_dim, device=x.device)

        if self.shared_adaln:
            shared_adaln_state = [
                self.shared_adaln_attn(t_emb).chunk(3, dim=-1),
                self.shared_adaln_xattn(t_emb).chunk(3, dim=-1),
                self.shared_adaln_ffw(t_emb).chunk(3, dim=-1),
            ]
        else:
            shared_adaln_state = None

        length = x.size(1)
        if self.ctx_proj is not None:
            # concat-context mode: project ctx tokens and append them to the image tokens
            ctx = self.ctx_proj(ctx)
            x = torch.cat([x, ctx], dim=1)
            if pos_map is not None:
                pos_map = torch.cat(
                    [
                        pos_map,
                        torch.zeros(n, ctx.size(1), pos_map.size(2), device=x.device),
                    ],
                    dim=1,
                )
            ctx = None

        if self.use_tread:
            x = self.prev_tread_trns(
                x,
                ctx=ctx,
                pos_map=pos_map,
                y=t_emb,
                shared_adaln=shared_adaln_state,
            )
            if self.training or tread_rate is not None:
                # TREAD-style token routing: randomly keep a subset of image tokens
                # for the backbone and route the remainder around it
                xt_selection_length = selection_length = length - int(
                    length * (tread_rate or self.dropout_ratio)
                )
                selection = torch.stack(
                    [
                        torch.randperm(length, device=x.device) < selection_length
                        for _ in range(n)
                    ]
                )
                if self.ctx_proj is not None:
                    # concatenated context tokens are always kept
                    ctx_length = x.size(1) - length
                    selection = torch.concat(
                        [
                            selection,
                            torch.ones(
                                n, ctx_length, device=x.device, dtype=torch.bool
                            ),
                        ],
                        dim=1,
                    )
                    selection_length += ctx_length
                full_length = x.size(1)
                not_masked_part = x[~selection, :]
                masked_part = x[selection, :].unflatten(0, (n, selection_length))
                x = masked_part
                raw_pos_map = pos_map
                pos_map = pos_map[selection, :].unflatten(0, (n, selection_length))
        backbone_out = self.backbone(
            x,
            ctx=ctx,
            pos_map=pos_map,
            y=t_emb,
            shared_adaln=shared_adaln_state,
            return_enc_out=return_enc_out,
        )
        if return_enc_out:
            backbone_out, enc_out = backbone_out
        if self.use_tread:
            if self.training or tread_rate is not None:
                # scatter the processed tokens back into the full sequence
                out = torch.empty(
                    n, full_length, x.size(2), device=x.device, dtype=x.dtype
                )
                out[~selection, :] = not_masked_part
                out[selection, :] = backbone_out.flatten(0, 1)
                pos_map = raw_pos_map
            else:
                out = backbone_out
            out = self.post_tread_trns(
                out,
                ctx=ctx,
                pos_map=pos_map,
                y=t_emb,
                shared_adaln=shared_adaln_state,
            )
        else:
            out = backbone_out
        out = out[:, :length]
        out = self.out_patch(out, h, w)

        if return_enc_out:
            enc_length = (
                xt_selection_length
                if self.use_tread and self.training
                else enc_out.size(1)  # full sequence when no token routing was applied
            )
            return out, enc_out[:, :enc_length]
        return out
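

# Minimal end-to-end sketch for XUDiT (hypothetical shapes and settings):
#   model = XUDiT(patch_size=2, input_dim=4, dim=1024, ctx_dim=1024, ctx_size=77, depth=8)
#   x = torch.randn(2, 4, 64, 64)    # latent image batch [B, C, H, W]
#   t = torch.rand(2)                # diffusion timesteps
#   ctx = torch.randn(2, 77, 1024)   # text/conditioning tokens
#   pred = model(x, t, ctx=ctx)      # -> [2, 4, 64, 64]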