# sd-webui-deforum/scripts/deforum_helpers/deforum_controlnet_hardcode.py
#
# Listing metadata from the source viewer: 194 lines, 7.6 KiB, Python

# TODO HACK FIXME HARDCODE — as using the scripts doesn't seem to work for some reason
#
# Module-level ControlNet state shared between calls:
#   deforum_latest_network - the currently hooked ControlNet network, or None.
#   deforum_input_image    - the most recent control input image, or None.
#   deforum_latest_params  - (preprocessor module, model name) of the last
#                            request; the initial model entry is a placeholder
#                            string so the first comparison against a requested
#                            model reports a change and triggers loading.
deforum_latest_network = deforum_input_image = None
deforum_latest_params = (None, 'placeholder to trigger the model loading')
from scripts.processor import unload_hed, unload_mlsd, unload_midas, unload_leres, unload_pidinet, unload_openpose, unload_uniformer, HWC3
import modules.shared as shared
import modules.devices as devices
import modules.processing as processing
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img
import numpy as np
from scripts.controlnet import update_cn_models, cn_models, cn_models_names
import os
import modules.scripts as scrpts
import torch
from scripts.cldm import PlugableControlModel
from scripts.adapter import PlugableAdapter
from scripts.utils import load_state_dict
from torchvision.transforms import Resize, InterpolationMode, CenterCrop, Compose
from einops import rearrange
# ControlNet model directory (resolved relative to `scrpts.basedir()`) and the
# fallback YAML configs used when the corresponding shared option is not set.
cn_models_dir = os.path.join(scrpts.basedir(), "models")
default_conf = os.path.join(cn_models_dir, "cldm_v15.yaml")
default_conf_adapter = os.path.join(cn_models_dir, "sketch_adapter_v14.yaml")

# Preprocessor module name -> zero-argument function that unloads that
# preprocessor's annotator. Modules without an entry here have nothing to
# unload. Several modules share one unloader (e.g. "hed" / "fake_scribble").
unloadable = {
    "hed": unload_hed,
    "fake_scribble": unload_hed,
    "mlsd": unload_mlsd,
    "depth": unload_midas,
    "depth_leres": unload_leres,
    "normal_map": unload_midas,
    "pidinet": unload_pidinet,
    "openpose": unload_openpose,
    "openpose_hand": unload_openpose,
    "segmentation": unload_uniformer,
}

# Hash of the Stable Diffusion checkpoint the current network was hooked into;
# empty until the first request has been processed.
deforum_latest_model_hash = ""
def restore_networks(unet):
    """Detach the active ControlNet from ``unet`` and unload its preprocessor.

    If a network is currently hooked, it is restored (un-hooked) from ``unet``
    and the module-level network / input-image state is cleared. Independently
    of that, the unloader registered in ``unloadable`` for the last used
    preprocessor module (``deforum_latest_params[0]``) is invoked, if any.

    Args:
        unet: The diffusion model object passed through to the network's
            ``restore`` method.
    """
    global deforum_latest_network
    # Without this declaration the assignment below would only bind a local
    # name and the cached module-level input image would never be cleared.
    global deforum_input_image

    if deforum_latest_network is not None:
        print("restoring last networks")
        deforum_input_image = None
        deforum_latest_network.restore(unet)
        deforum_latest_network = None

    last_module = deforum_latest_params[0]
    if last_module is not None:
        # Modules without a registered unloader simply have nothing to free.
        unload = unloadable.get(last_module)
        if unload is not None:
            unload()
def process(p, *args):
    """Hook the requested ControlNet into ``p``'s UNet and hand it a control image.

    The network is (re)loaded only when the requested model, the checkpoint
    hash, or the ``lowvram`` flag differs from the previous call, or when no
    network is currently hooked. The control image is taken from ``image`` if
    given, otherwise from ``p.init_images[0]``; it is run through the selected
    preprocessor, converted to a ``c h w`` float tensor scaled to [0, 1],
    resized to ``(p.height, p.width)`` and passed to the network's ``notify``.

    Args:
        p: Processing object. Reads ``sd_model``, ``height``, ``width`` and,
            when ``image`` is None, ``init_images``.
        *args: Exactly 13 positional values, in this order:
            enabled, module, model, weight, image, scribble_mode, resize_mode,
            rgbbgr_mode, lowvram, pres, pthr_a, pthr_b, guidance_strength.
            ``module`` selects the preprocessor, ``model`` is a key of
            ``cn_models``. ``image`` is None or a mapping with an ``'image'``
            entry and an optional ``'mask'`` entry (indexed as H x W x C
            arrays). ``pres`` is forwarded to the preprocessor as ``res``
            (together with ``pthr_a`` / ``pthr_b``) only when it exceeds 64.

    Raises:
        RuntimeError: ``model`` is not a known ControlNet model.
        ValueError: The model file does not exist, or no input image is
            available.
        KeyError: ``module`` is not a supported preprocessor.
    """
    global deforum_latest_network
    global deforum_latest_params
    global deforum_input_image
    global deforum_latest_model_hash

    unet = p.sd_model.model.diffusion_model
    enabled, module, model, weight, image, scribble_mode, \
        resize_mode, rgbbgr_mode, lowvram, pres, pthr_a, pthr_b, guidance_strength = args

    if not enabled:
        restore_networks(unet)
        return

    models_changed = (
        deforum_latest_params[1] != model
        or deforum_latest_model_hash != p.sd_model.sd_model_hash
        or deforum_latest_network is None
        or deforum_latest_network.lowvram != lowvram
    )

    if models_changed:
        # Restore *before* recording the new parameters: restore_networks()
        # unloads the preprocessor named by deforum_latest_params[0], which
        # must still be the previously used module at this point.
        restore_networks(unet)

    deforum_latest_params = (module, model)
    deforum_latest_model_hash = p.sd_model.sd_model_hash

    if models_changed:
        model_path = cn_models.get(model, None)
        if model_path is None:
            raise RuntimeError(f"model not found: {model}")

        # trim '"' at start/end
        if model_path.startswith("\"") and model_path.endswith("\""):
            model_path = model_path[1:-1]

        if not os.path.exists(model_path):
            raise ValueError(f"file not found: {model_path}")

        print(f"Loading preprocessor: {module}, model: {model}")
        state_dict = load_state_dict(model_path)
        network_module = PlugableControlModel
        network_config = shared.opts.data.get("control_net_model_config", default_conf)
        if any(key.startswith("body.") for key in state_dict):
            # adapter model
            network_module = PlugableAdapter
            network_config = shared.opts.data.get("control_net_model_adapter_config", default_conf_adapter)

        network = network_module(
            state_dict=state_dict,
            config_path=network_config,
            weight=weight,
            lowvram=lowvram,
            base_model=unet,
        )
        network.to(p.sd_model.device, dtype=p.sd_model.dtype)
        network.hook(unet, p.sd_model)

        print(f"ControlNet model {model} loaded.")
        deforum_latest_network = network

    if image is not None:
        deforum_input_image = HWC3(image['image'])
        mask = image['mask'] if 'mask' in image else None
        if mask is not None:
            mask_channel = mask[:, :, 0]
            # A mask that is entirely 0 or entirely 255 carries no drawing.
            if not ((mask_channel == 0).all() or (mask_channel == 255).all()):
                print("using mask as input")
                deforum_input_image = HWC3(mask_channel)
                scribble_mode = True
    else:
        # use img2img init_image as default; a missing, None or empty
        # init_images must end in the ValueError below, not an index error
        init_images = getattr(p, "init_images", None)
        if init_images is not None and len(init_images) > 0:
            deforum_input_image = init_images[0]
        else:
            deforum_input_image = None
        if deforum_input_image is None:
            raise ValueError('controlnet is enabled but no input image is given')
        deforum_input_image = HWC3(np.asarray(deforum_input_image))

    if scribble_mode:
        # Pixels whose darkest channel is below 127 become white, rest black.
        detected_map = np.zeros_like(deforum_input_image, dtype=np.uint8)
        detected_map[np.min(deforum_input_image, axis=2) < 127] = 255
        deforum_input_image = detected_map

    from scripts.processor import canny, midas, midas_normal, leres, hed, mlsd, openpose, pidinet, simple_scribble, fake_scribble, uniformer
    preprocessors = {
        "none": lambda x, *args, **kwargs: x,
        "canny": canny,
        "depth": midas,
        "depth_leres": leres,
        "hed": hed,
        "mlsd": mlsd,
        "normal_map": midas_normal,
        "openpose": openpose,
        # "openpose_hand": openpose_hand,
        "pidinet": pidinet,
        "scribble": simple_scribble,
        "fake_scribble": fake_scribble,
        "segmentation": uniformer,
    }
    try:
        preprocessor = preprocessors[module]
    except KeyError:
        raise KeyError(f"unknown preprocessor module: {module}") from None

    h, w = p.height, p.width
    if pres > 64:
        detected_map = preprocessor(deforum_input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b)
    else:
        detected_map = preprocessor(deforum_input_image)
    detected_map = HWC3(detected_map)

    if module == "normal_map" or rgbbgr_mode:
        # Reverse the channel order for these modes.
        detected_map = detected_map[:, :, ::-1]
    # .copy() yields a contiguous array, which torch.from_numpy requires
    # after the negative-stride channel flip above.
    control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0
    control = rearrange(control, 'h w c -> c h w')

    # Only the control tensor is resized; the detected map itself is not
    # consumed anywhere after this point, so no resized copy of it is built.
    if resize_mode == "Scale to Fit (Inner Fit)":
        transform = Compose([
            Resize(min(h, w), interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size=(h, w)),
        ])
    elif resize_mode == "Envelope (Outer Fit)":
        transform = Compose([
            Resize(max(h, w), interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size=(h, w)),
        ])
    else:
        transform = Resize((h, w), interpolation=InterpolationMode.BICUBIC)
    control = transform(control)

    deforum_latest_network.notify(control, weight, guidance_strength)

    if shared.opts.data.get("control_net_skip_img2img_processing") and hasattr(p, "init_images"):
        swap_img2img_pipeline(p)
def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img):
    """Turn ``p`` into a txt2img processing object in place.

    ``p``'s class is reassigned to ``StableDiffusionProcessingTxt2Img``, then
    every instance attribute of a default-constructed txt2img object that
    ``p`` does not already expose is copied onto ``p``. Attributes ``p``
    already has are left untouched.
    """
    p.__class__ = processing.StableDiffusionProcessingTxt2Img
    template = processing.StableDiffusionProcessingTxt2Img()
    for attr_name, default_value in vars(template).items():
        if not hasattr(p, attr_name):
            setattr(p, attr_name, default_value)