# automatic/modules/postprocess/seedvr_model.py
import time
import random
import numpy as np
import torch
from PIL import Image
from modules import devices, images_sharpfin
from modules.shared import opts, log
from modules.upscaler import Upscaler, UpscalerData
# UI-facing scaler name -> checkpoint filename resolved by the SeedVR model manager
MODELS_MAP = {
    "SeedVR2 3B": "seedvr2_ema_3b_fp16.safetensors",
    "SeedVR2 7B": "seedvr2_ema_7b_fp16.safetensors",
    "SeedVR2 7B Sharp": "seedvr2_ema_7b_sharp_fp16.safetensors",
}
# tensor/array -> PIL.Image conversion helper re-exported from the images module
to_pil = images_sharpfin.to_pil
class UpscalerSeedVR(Upscaler):
    """Upscaler backed by the SeedVR2 restoration models.

    Wraps the project-local SeedVR runner: loads a DiT + VAE pair on demand,
    replaces the runner's VAE encode/decode entry points and the module-level
    generation step with offload-aware wrappers that swap the DiT and VAE
    between CPU and the active device, and exposes the standard
    ``do_upscale(img, selected_file)`` entry point.
    """

    def __init__(self, dirname=None):
        # self.name must be assigned before super().__init__() so base-class setup sees it
        self.name = "SeedVR2"
        super().__init__()
        # one UpscalerData entry per key of MODELS_MAP; effective scale comes from self.scale at runtime
        self.scalers = [
            UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1),
            UpscalerData(name="SeedVR2 7B", path=None, upscaler=self, model=None, scale=1),
            UpscalerData(name="SeedVR2 7B Sharp", path=None, upscaler=self, model=None, scale=1),
        ]
        self.model = None  # SeedVR runner (exposes .dit, .vae, .config, ...)
        self.model_loaded = None  # checkpoint filename currently loaded, or None

    def load_model(self, path: str):
        """Load (or reuse) the SeedVR runner for the scaler selected by ``path``.

        ``path`` is the UI-facing scaler name; it is mapped to a checkpoint
        filename via MODELS_MAP. The cached runner is reused when the same
        checkpoint is already loaded.
        """
        model_name = MODELS_MAP.get(path, None)
        if (self.model is None) or (self.model_loaded != model_name):
            log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"')
            t0 = time.time()
            # lazy imports: SeedVR stack is heavy and only needed when this upscaler is used
            from modules.seedvr.src.core.model_manager import configure_runner
            from modules.seedvr.src.core import generation
            self.model = configure_runner(
                model_name=model_name,
                cache_dir=opts.hfcache_dir,
                device=devices.device,
                dtype=devices.dtype,
            )
            self.model_loaded = model_name
            self.model.dit.device = devices.device
            self.model.dit.dtype = devices.dtype
            # route the runner's VAE entry points through the offload-aware wrappers below
            self.model.vae_encode = self.vae_encode
            self.model.vae_decode = self.vae_decode
            # monkey-patch the module-level generation step: keep the original on the
            # runner, install self.model_step (which handles dit/vae device swaps) globally
            # NOTE(review): when a *second* checkpoint is loaded in the same session,
            # generation.generation_step has already been replaced by self.model_step,
            # so the new runner's model_step would point back at the wrapper —
            # verify that reloading a different model does not recurse
            self.model.model_step = generation.generation_step
            generation.generation_step = self.model_step
            self.model._internal_dict = {
                'dit': self.model.dit,
                'vae': self.model.vae,
            }
            t1 = time.time()
            self.model.dit.config = self.model.config.dit
            # VAE tiling thresholds: sample-space and latent-space minimum tile sizes
            self.model.vae.tile_sample_min_size = 1024
            self.model.vae.tile_latent_min_size = 128
            from modules.model_quant import do_post_load_quant
            self.model = do_post_load_quant(self.model, allow=True)
            # from modules.sd_offload import set_diffuser_offload
            # set_diffuser_offload(self.model)
            # NOTE(review): t1 is captured before do_post_load_quant, so the logged
            # time excludes quantization
            log.info(f'Upscaler loaded: name="{self.name}" model="{model_name}" time={t1 - t0:.2f}')

    def vae_encode(self, samples):
        """Encode image tensors into VAE latents, one sample at a time.

        Moves the DiT to CPU and the VAE to the active device first to bound
        memory use. Returns a list of latents rearranged channels-last
        ("b ... c") and normalized by the configured scaling/shifting factors.
        """
        log.debug(f'Upscaler encode: samples={samples[0].shape if len(samples) > 0 else None} tile={self.model.vae.tile_sample_min_size} overlap={self.model.vae.tile_overlap_factor}')
        latents = []
        if len(samples) == 0:
            return latents
        # swap: DiT off-device, VAE on-device
        self.model.dit = self.model.dit.to(device="cpu")
        self.model.vae = self.model.vae.to(device=self.device)
        devices.torch_gc()
        from einops import rearrange
        from modules.seedvr.src.optimization import memory_manager
        memory_manager.clear_rope_cache(self.model)
        scale = self.model.config.vae.scaling_factor
        shift = self.model.config.vae.get("shifting_factor", 0.0)
        # process each sample as its own batch of 1 to keep peak VRAM low
        batches = [sample.unsqueeze(0) for sample in samples]
        # NOTE(review): unlike vae_decode, this loop is not wrapped in
        # devices.inference_context() — confirm gradients are disabled upstream
        for sample in batches:
            sample = sample.to(self.device, self.model.vae.dtype)
            sample = self.model.vae.preprocess(sample)
            latent = self.model.vae.encode(sample).latent
            latent = latent.unsqueeze(2) if latent.ndim == 4 else latent  # ensure a temporal axis exists
            latent = rearrange(latent, "b c ... -> b ... c")  # channels-last, as the runner expects
            latent = (latent - shift) * scale
            latents.append(latent)
        latents = [latent.squeeze(0) for latent in latents]
        # done with the VAE: return it to CPU and reclaim memory
        self.model.vae = self.model.vae.to(device="cpu")
        devices.torch_gc()
        return latents

    def vae_decode(self, latents, target_dtype: torch.dtype = None):
        """Decode VAE latents back into image tensors, one latent at a time.

        Inverse of vae_encode: moves the DiT to CPU and the VAE to the active
        device, undoes the scale/shift normalization and the channels-last
        layout, then decodes under an inference context.
        NOTE(review): target_dtype is accepted but never used in this body —
        presumably kept for signature compatibility with the runner; confirm.
        """
        log.debug(f'Upscaler decode: latents={latents[0].shape if len(latents) > 0 else None} tile={self.model.vae.tile_latent_min_size} overlap={self.model.vae.tile_overlap_factor}')
        samples = []
        if len(latents) == 0:
            return samples
        from einops import rearrange
        from modules.seedvr.src.optimization import memory_manager
        memory_manager.clear_rope_cache(self.model)
        # swap: DiT off-device, VAE on-device
        self.model.dit = self.model.dit.to(device="cpu")
        self.model.vae = self.model.vae.to(device=self.device)
        devices.torch_gc()
        scale = self.model.config.vae.scaling_factor
        shift = self.model.config.vae.get("shifting_factor", 0.0)
        latents = [latent.unsqueeze(0) for latent in latents]
        with devices.inference_context():
            for _i, latent in enumerate(latents):
                latent = latent.to(self.device, self.model.vae.dtype)
                latent = latent / scale + shift  # undo (x - shift) * scale from vae_encode
                latent = rearrange(latent, "b ... c -> b c ...")  # back to channels-first
                latent = latent.squeeze(2)  # drop the temporal axis added during encode
                sample = self.model.vae.decode(latent).sample
                sample = self.model.vae.postprocess(sample)
                samples.append(sample)
        samples = [sample.squeeze(0) for sample in samples]
        # done with the VAE: return it to CPU and reclaim memory
        self.model.vae = self.model.vae.to(device="cpu")
        devices.torch_gc()
        return samples

    def model_step(self, *args, **kwargs):
        """Offload-aware replacement for generation.generation_step.

        Moves the VAE to CPU and the DiT onto the active device, runs the
        original generation step (stashed on the runner by load_model) under an
        inference context, then returns the DiT to CPU.
        """
        from modules.seedvr.src.optimization import memory_manager
        # swap: VAE off-device, DiT on-device
        self.model.vae = self.model.vae.to(device="cpu")
        self.model.dit = self.model.dit.to(device=self.device)
        devices.torch_gc()
        log.debug(f'Upscaler inference: args={len(args)} kwargs={list(kwargs.keys())}')
        memory_manager.preinitialize_rope_cache(self.model)
        with devices.inference_context():
            result = self.model.model_step(*args, **kwargs)
        self.model.dit = self.model.dit.to(device="cpu")
        devices.torch_gc()
        return result

    def do_upscale(self, img: Image.Image, selected_file):
        """Upscale a single PIL image with the selected SeedVR2 model.

        Loads the model on demand, runs the SeedVR generation loop with a
        fresh random seed, converts the result back to PIL, and optionally
        unloads the model afterwards (opts.upscaler_unload).
        """
        self.load_model(selected_file)
        if self.model is None:
            return img  # unknown model name or load failure: return the input unchanged
        from modules.seedvr.src.core import generation
        # target width from the configured scale, snapped down to a multiple of 8
        width = int(self.scale * img.width) // 8 * 8
        image_tensor = np.array(img)
        # HWC uint8 -> 1xHWC float tensor in [0, 1] on the active device
        image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0
        random.seed()  # re-seed from OS entropy so every run draws a fresh seed
        seed = int(random.randrange(4294967294))
        t0 = time.time()
        with devices.inference_context():
            result_tensor = generation.generation_loop(
                runner=self.model,
                images=image_tensor,
                cfg_scale=opts.seedvt_cfg_scale,  # NOTE(review): option key is spelled 'seedvt' (sic) — confirm it matches the settings definition
                seed=seed,
                res_w=width,
                batch_size=1,
                temporal_overlap=0,
                device=devices.device,
            )
        t1 = time.time()
        log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
        img = to_pil(result_tensor.squeeze())
        if opts.upscaler_unload:
            # drop all runner references so memory can actually be reclaimed
            self.model.dit = None
            self.model.vae = None
            self.model.cache = None
            self.model = None
            log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"')
            devices.torch_gc(force=True)
        return img