automatic/modules/interposer.py

165 lines
6.6 KiB
Python

# converted from <https://github.com/city96/SD-Latent-Interposer>
import os
import time
import torch
import torch.nn as nn
from safetensors.torch import load_file
# v1 = Stable Diffusion 1.x
# xl = Stable Diffusion Extra Large (SDXL)
# v3 = Stable Diffusion Version Three (SD3)
# fx = Black Forest Labs Flux dot One
# cc = Stable Cascade (Stage C) [not used]
# ca = Stable Cascade (Stage A/B)
config = {
"v1-to-xl": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v1-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"xl-to-v1": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"xl-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v3-to-v1": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v3-to-xl": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-v1": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-xl": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-v3": {"ch_in":16, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"ca-to-v1": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 0.5, "blocks": 12},
"ca-to-xl": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 0.5, "blocks": 12},
"ca-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 0.5, "blocks": 12},
}
class ResBlock(nn.Module):
"""Block with residuals"""
def __init__(self, ch):
super().__init__()
self.join = nn.ReLU()
self.norm = nn.BatchNorm2d(ch)
self.long = nn.Sequential(
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.Dropout(0.1)
)
def forward(self, x):
x = self.norm(x)
return self.join(self.long(x) + x)
class ExtractBlock(nn.Module):
"""Increase no. of channels by [out/in]"""
def __init__(self, ch_in, ch_out):
super().__init__()
self.join = nn.ReLU()
self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.long = nn.Sequential(
nn.Conv2d( ch_in, ch_out, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.Dropout(0.1)
)
def forward(self, x):
return self.join(self.long(x) + self.short(x))
class InterposerModel(nn.Module):
"""
NN layout, ported from:
https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
"""
def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.ch_mid = ch_mid
self.blocks = blocks
self.scale = scale
self.head = ExtractBlock(self.ch_in, self.ch_mid)
self.core = nn.Sequential(
nn.Upsample(scale_factor=self.scale, mode="nearest"),
*[ResBlock(self.ch_mid) for _ in range(blocks)],
nn.BatchNorm2d(self.ch_mid),
nn.SiLU(),
)
self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)
def forward(self, x):
y = self.head(x)
z = self.core(y)
return self.tail(z)
def map_model_name(name: str):
if name == 'sd':
return 'v1'
if name == 'sdxl':
return 'xl'
if name == 'sd3':
return 'v3'
if name in ['f1', 'chroma']:
return 'fx'
return name
class Interposer:
def __init__(self):
self.version = 4.0 # network revision
self.loaded = None # current model name
self.model = None # current model
self.vae = None # current VAE
def convert(self, src: str, dst: str, latents: torch.Tensor):
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from modules import shared, devices
src = map_model_name(src)
dst = map_model_name(dst)
if src == dst:
return None
model_name = f"{src}-to-{dst}"
if model_name not in config:
shared.log.error(f'Interposer: model="{model_name}" unknown')
return None
if (self.loaded != model_name) or (self.model is None):
model_fn = hf_hub_download(
repo_id="city96/SD-Latent-Interposer",
subfolder=f"v{self.version}",
filename=f"{model_name}_interposer-v{self.version}.safetensors",
cache_dir=shared.opts.hfcache_dir,
)
self.model = InterposerModel(**config[model_name])
self.model = self.model.to(device=devices.cpu, dtype=torch.float32)
self.model.eval()
self.model.load_state_dict(load_file(model_fn))
self.loaded = model_name
if dst == 'v1':
vae_repo = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'xl':
vae_repo = 'madebyollin/sdxl-vae-fp16-fix'
self.vae = AutoencoderKL.from_pretrained(vae_repo, cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'v3':
vae_repo = 'stabilityai/stable-diffusion-3.5-large'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'fx':
vae_repo = 'black-forest-labs/FLUX.1-dev'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
t0 = time.time()
if self.model is None or self.vae is None:
return None
with torch.no_grad():
latent = latents.clone().cpu().float() # force fp32, always run on CPU
output = self.model(latent)
output = output.to(device=latents.device, dtype=latents.dtype)
t1 = time.time()
shared.log.debug(f'Interposer: src={src}/{list(latents.shape)} dst={dst}/{list(output.shape)} model="{os.path.basename(model_fn)}" vae="{vae_repo}" time={t1-t0:.2f}')
# shared.log.debug(f'Interposer: src={latents.aminmax()} dst={output.aminmax()}')
return output