sd-webui-animatediff/scripts/animatediff_cn.py

616 lines
31 KiB
Python

from pathlib import Path
import os
import shutil
import cv2
import numpy as np
import torch
from modules import processing, shared
from modules.paths import data_path
from modules.processing import (StableDiffusionProcessing,
StableDiffusionProcessingImg2Img,
StableDiffusionProcessingTxt2Img)
from scripts.animatediff_logger import logger_animatediff as logger
from scripts.animatediff_ui import AnimateDiffProcess
class AnimateDiffControl:
def __init__(self, p: StableDiffusionProcessing):
self.original_processing_process_images_hijack = None
self.original_controlnet_main_entry = None
self.original_postprocess_batch = None
try:
from scripts.external_code import find_cn_script
self.cn_script = find_cn_script(p.scripts)
except:
self.cn_script = None
def hack_batchhijack(self, params: AnimateDiffProcess):
cn_script = self.cn_script
def get_input_frames():
if params.video_source is not None and params.video_source != '':
cap = cv2.VideoCapture(params.video_source)
frame_count = 0
tmp_frame_dir = Path(f'{data_path}/tmp/animatediff-frames/')
tmp_frame_dir.mkdir(parents=True, exist_ok=True)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
cv2.imwrite(f"{tmp_frame_dir}/{frame_count}.png", frame)
frame_count += 1
cap.release()
return str(tmp_frame_dir)
elif params.video_path is not None and params.video_path != '':
return params.video_path
return ''
from scripts import external_code
from scripts.batch_hijack import InputMode, BatchHijack, instance
def hacked_processing_process_images_hijack(self, p, *args, **kwargs):
if self.is_batch:
# we are in img2img batch tab, do a single batch iteration
return self.process_images_cn_batch(p, *args, **kwargs)
units = external_code.get_all_units_in_processing(p)
units = [unit for unit in units if getattr(unit, 'enabled', False)]
if len(units) == 0:
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
global_input_frames = get_input_frames()
for idx, unit in enumerate(units):
# if no input given for this unit, use global input
if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
if not (isinstance(unit.batch_images, str) and unit.batch_images != ''):
assert global_input_frames != '', 'No input images found for ControlNet module'
unit.batch_images = global_input_frames
elif unit.image is None:
try:
cn_script.choose_input_image(p, unit, idx)
except:
assert global_input_frames != '', 'No input images found for ControlNet module'
unit.batch_images = global_input_frames
unit.input_mode = InputMode.BATCH
if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
if 'inpaint' in unit.module:
images = shared.listfiles(f'{unit.batch_images}/image')
masks = shared.listfiles(f'{unit.batch_images}/mask')
assert len(images) == len(masks), 'Inpainting image mask count mismatch'
unit.batch_images = [{'image': images[i], 'mask': masks[i]} for i in range(len(images))]
else:
unit.batch_images = shared.listfiles(unit.batch_images)
unit_batch_list = [len(unit.batch_images) for unit in units
if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH]
if len(unit_batch_list) == 0:
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
video_length = min(unit_batch_list)
# ensure that params.video_length <= video_length and params.batch_size <= video_length
if params.video_length > video_length:
params.video_length = video_length
if params.batch_size > video_length:
params.batch_size = video_length
if params.video_default:
params.video_length = video_length
p.batch_size = video_length
for unit in units:
if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
unit.batch_images = unit.batch_images[:params.video_length]
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
self.original_processing_process_images_hijack = BatchHijack.processing_process_images_hijack
BatchHijack.processing_process_images_hijack = hacked_processing_process_images_hijack
processing.process_images_inner = instance.processing_process_images_hijack
def restore_batchhijack(self):
from scripts.batch_hijack import BatchHijack, instance
BatchHijack.processing_process_images_hijack = self.original_processing_process_images_hijack
self.original_processing_process_images_hijack = None
processing.process_images_inner = instance.processing_process_images_hijack
def hack_cn(self):
cn_script = self.cn_script
from types import MethodType
from typing import Optional
from modules import images, masking
from PIL import Image, ImageFilter, ImageOps
from scripts import external_code, global_state, hook
# from scripts.controlnet_lora import bind_control_lora # do not support control lora for sdxl
from scripts.adapter import Adapter, Adapter_light, StyleAdapter
from scripts.batch_hijack import InputMode
# from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite # do not support controlllite for sdxl
from scripts.controlmodel_ipadapter import (PlugableIPAdapter,
clear_all_ip_adapter)
from scripts.hook import ControlModelType, ControlParams, UnetHook
from scripts.logging import logger
from scripts.processor import model_free_preprocessors
def hacked_main_entry(self, p: StableDiffusionProcessing):
def image_has_mask(input_image: np.ndarray) -> bool:
return (
input_image.ndim == 3 and
input_image.shape[2] == 4 and
np.max(input_image[:, :, 3]) > 127
)
def prepare_mask(
mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
mask = mask.convert("L")
if getattr(p, "inpainting_mask_invert", False):
mask = ImageOps.invert(mask)
if hasattr(p, 'mask_blur_x'):
if getattr(p, "mask_blur_x", 0) > 0:
np_mask = np.array(mask)
kernel_size = 2 * int(2.5 * p.mask_blur_x + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), p.mask_blur_x)
mask = Image.fromarray(np_mask)
if getattr(p, "mask_blur_y", 0) > 0:
np_mask = np.array(mask)
kernel_size = 2 * int(2.5 * p.mask_blur_y + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), p.mask_blur_y)
mask = Image.fromarray(np_mask)
else:
if getattr(p, "mask_blur", 0) > 0:
mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))
return mask
def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
try:
tmp_seed = int(p.all_seeds[0] if p.seed == -1 else max(int(p.seed), 0))
tmp_subseed = int(p.all_seeds[0] if p.subseed == -1 else max(int(p.subseed), 0))
seed = (tmp_seed + tmp_subseed) & 0xFFFFFFFF
np.random.seed(seed)
return seed
except Exception as e:
logger.warning(e)
logger.warning('Warning: Failed to use consistent random seed.')
return None
sd_ldm = p.sd_model
unet = sd_ldm.model.diffusion_model
self.noise_modifier = None
# setattr(p, 'controlnet_control_loras', []) # do not support control lora for sdxl
if self.latest_network is not None:
# always restore (~0.05s)
self.latest_network.restore()
# always clear (~0.05s)
# clear_all_lllite() # do not support controlllite for sdxl
clear_all_ip_adapter()
self.enabled_units = cn_script.get_enabled_units(p)
if len(self.enabled_units) == 0:
self.latest_network = None
return
detected_maps = []
forward_params = []
post_processors = []
# cache stuff
if self.latest_model_hash != p.sd_model.sd_model_hash:
cn_script.clear_control_model_cache()
for idx, unit in enumerate(self.enabled_units):
unit.module = global_state.get_module_basename(unit.module)
# unload unused preproc
module_list = [unit.module for unit in self.enabled_units]
for key in self.unloadable:
if key not in module_list:
self.unloadable.get(key, lambda:None)()
self.latest_model_hash = p.sd_model.sd_model_hash
for idx, unit in enumerate(self.enabled_units):
cn_script.bound_check_params(unit)
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
control_mode = external_code.control_mode_from_value(unit.control_mode)
if unit.module in model_free_preprocessors:
model_net = None
else:
model_net = cn_script.load_control_model(p, unet, unit.model)
model_net.reset()
# if getattr(model_net, 'is_control_lora', False): # do not support control lora for sdxl
# control_lora = model_net.control_model
# bind_control_lora(unet, control_lora)
# p.controlnet_control_loras.append(control_lora)
if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
input_images = []
for img in unit.batch_images:
unit.image = img # TODO: SAM extension should use new API
input_image, _ = cn_script.choose_input_image(p, unit, idx)
input_images.append(input_image)
else:
input_image, image_from_a1111 = cn_script.choose_input_image(p, unit, idx)
input_images = [input_image]
if image_from_a1111:
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
if a1111_i2i_resize_mode is not None:
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
for idx, input_image in enumerate(input_images):
a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None:
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
if a1111_mask.ndim == 2:
if a1111_mask.shape[0] == input_image.shape[0]:
if a1111_mask.shape[1] == input_image.shape[1]:
input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
if a1111_i2i_resize_mode is not None:
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
if 'reference' not in unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) \
and p.inpaint_full_res and a1111_mask_image is not None:
logger.debug("A1111 inpaint mask START")
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
input_image = [Image.fromarray(x) for x in input_image]
mask = prepare_mask(a1111_mask_image, p)
crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)
input_image = [
images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
for i in input_image
]
input_image = [x.crop(crop_region) for x in input_image]
input_image = [
images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
for x in input_image
]
input_image = [np.asarray(x)[:, :, 0] for x in input_image]
input_image = np.stack(input_image, axis=2)
logger.debug("A1111 inpaint mask END")
# safe numpy
logger.debug("Safe numpy convertion START")
input_image = np.ascontiguousarray(input_image.copy()).copy()
logger.debug("Safe numpy convertion END")
input_images[idx] = input_image
if 'inpaint_only' == unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) and p.image_mask is not None:
logger.warning('A1111 inpaint and ControlNet inpaint duplicated. ControlNet support enabled.')
unit.module = 'inpaint'
logger.info(f"Loading preprocessor: {unit.module}")
preprocessor = self.preprocessor[unit.module]
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
h = (p.height // 8) * 8
w = (p.width // 8) * 8
if high_res_fix:
if p.hr_resize_x == 0 and p.hr_resize_y == 0:
hr_y = int(p.height * p.hr_scale)
hr_x = int(p.width * p.hr_scale)
else:
hr_y, hr_x = p.hr_resize_y, p.hr_resize_x
hr_y = (hr_y // 8) * 8
hr_x = (hr_x // 8) * 8
else:
hr_y = h
hr_x = w
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
# inpaint_only+lama is special and required outpaint fix
for idx, input_image in enumerate(input_images):
_, input_image = cn_script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
input_images[idx] = input_image
control_model_type = ControlModelType.ControlNet
global_average_pooling = False
if 'reference' in unit.module:
control_model_type = ControlModelType.AttentionInjection
elif 'revision' in unit.module:
control_model_type = ControlModelType.ReVision
elif hasattr(model_net, 'control_model') and (isinstance(model_net.control_model, Adapter) or isinstance(model_net.control_model, Adapter_light)):
control_model_type = ControlModelType.T2I_Adapter
elif hasattr(model_net, 'control_model') and isinstance(model_net.control_model, StyleAdapter):
control_model_type = ControlModelType.T2I_StyleAdapter
elif isinstance(model_net, PlugableIPAdapter):
control_model_type = ControlModelType.IPAdapter
# elif isinstance(model_net, PlugableControlLLLite): # do not support controlllite for sdxl
# control_model_type = ControlModelType.Controlllite
if control_model_type is ControlModelType.ControlNet:
global_average_pooling = model_net.control_model.global_average_pooling
preprocessor_resolution = unit.processor_res
if unit.pixel_perfect:
preprocessor_resolution = external_code.pixel_perfect_resolution(
input_images[0],
target_H=h,
target_W=w,
resize_mode=resize_mode
)
logger.info(f'preprocessor resolution = {preprocessor_resolution}')
# Preprocessor result may depend on numpy random operations, use the
# random seed in `StableDiffusionProcessing` to make the
# preprocessor result reproducable.
# Currently following preprocessors use numpy random:
# - shuffle
seed = set_numpy_seed(p)
logger.debug(f"Use numpy seed {seed}.")
controls = []
hr_controls = []
controls_ipadapter = {'hidden_states': [], 'image_embeds': []}
hr_controls_ipadapter = {'hidden_states': [], 'image_embeds': []}
for idx, input_image in enumerate(input_images):
detected_map, is_image = preprocessor(
input_image,
res=preprocessor_resolution,
thr_a=unit.threshold_a,
thr_b=unit.threshold_b,
)
if high_res_fix:
if is_image:
hr_control, hr_detected_map = cn_script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
detected_maps.append((hr_detected_map, unit.module))
else:
hr_control = detected_map
else:
hr_control = None
if is_image:
control, detected_map = cn_script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
detected_maps.append((detected_map, unit.module))
else:
control = detected_map
detected_maps.append((input_image, unit.module))
if control_model_type == ControlModelType.T2I_StyleAdapter:
control = control['last_hidden_state']
if control_model_type == ControlModelType.ReVision:
control = control['image_embeds']
if control_model_type == ControlModelType.IPAdapter:
if model_net.is_plus:
controls_ipadapter['hidden_states'].append(control['hidden_states'][-2])
else:
controls_ipadapter['image_embeds'].append(control['image_embeds'])
if hr_control is not None:
if model_net.is_plus:
hr_controls_ipadapter['hidden_states'].append(hr_control['hidden_states'][-2])
else:
hr_controls_ipadapter['image_embeds'].append(hr_control['image_embeds'])
else:
hr_controls_ipadapter = None
else:
controls.append(control)
if hr_control is not None:
hr_controls.append(hr_control)
else:
hr_controls = None
if control_model_type == ControlModelType.IPAdapter:
ipadapter_key = 'hidden_states' if model_net.is_plus else 'image_embeds'
controls = {ipadapter_key: torch.cat(controls_ipadapter[ipadapter_key], dim=0)}
if controls[ipadapter_key].shape[0] > 1:
controls[ipadapter_key] = torch.cat([controls[ipadapter_key], controls[ipadapter_key]], dim=0)
if model_net.is_plus:
controls[ipadapter_key] = [controls[ipadapter_key], None]
if hr_controls_ipadapter is not None:
hr_controls = {ipadapter_key: torch.cat(hr_controls_ipadapter[ipadapter_key], dim=0)}
if hr_controls[ipadapter_key].shape[0] > 1:
hr_controls[ipadapter_key] = torch.cat([hr_controls[ipadapter_key], hr_controls[ipadapter_key]], dim=0)
if model_net.is_plus:
hr_controls[ipadapter_key] = [hr_controls[ipadapter_key], None]
else:
controls = torch.cat(controls, dim=0)
if controls.shape[0] > 1:
controls = torch.cat([controls, controls], dim=0)
if hr_controls is not None:
hr_controls = torch.cat(hr_controls, dim=0)
if hr_controls.shape[0] > 1:
hr_controls = torch.cat([hr_controls, hr_controls], dim=0)
preprocessor_dict = dict(
name=unit.module,
preprocessor_resolution=preprocessor_resolution,
threshold_a=unit.threshold_a,
threshold_b=unit.threshold_b
)
forward_param = ControlParams(
control_model=model_net,
preprocessor=preprocessor_dict,
hint_cond=controls,
weight=unit.weight,
guidance_stopped=False,
start_guidance_percent=unit.guidance_start,
stop_guidance_percent=unit.guidance_end,
advanced_weighting=None,
control_model_type=control_model_type,
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_controls,
soft_injection=control_mode != external_code.ControlMode.BALANCED,
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
)
forward_params.append(forward_param)
unit_is_batch = getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH
if 'inpaint_only' in unit.module:
final_inpaint_raws = []
final_inpaint_masks = []
for i in range(len(controls)):
final_inpaint_feed = hr_controls[i] if hr_controls is not None else controls[i]
final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
_, Hmask, Wmask = final_inpaint_mask.shape
final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
final_inpaint_raws.append(final_inpaint_raw)
final_inpaint_masks.append(final_inpaint_mask)
def inpaint_only_post_processing(x, i):
_, H, W = x.shape
if Hmask != H or Wmask != W:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
idx = i if unit_is_batch else 0
r = final_inpaint_raw[idx].to(x.dtype).to(x.device)
m = final_inpaint_mask[idx].to(x.dtype).to(x.device)
y = m * x.clip(0, 1) + (1 - m) * r
y = y.clip(0, 1)
return y
post_processors.append(inpaint_only_post_processing)
if 'recolor' in unit.module:
final_feeds = []
for i in range(len(controls)):
final_feed = hr_control if hr_control is not None else control
final_feed = final_feed.detach().cpu().numpy()
final_feed = np.ascontiguousarray(final_feed).copy()
final_feed = final_feed[0, 0, :, :].astype(np.float32)
final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
Hfeed, Wfeed = final_feed.shape
final_feeds.append(final_feed)
if 'luminance' in unit.module:
def recolor_luminance_post_processing(x, i):
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
h[:, :, 0] = final_feed[i if unit_is_batch else 0]
h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
return y
post_processors.append(recolor_luminance_post_processing)
if 'intensity' in unit.module:
def recolor_intensity_post_processing(x, i):
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
h[:, :, 2] = final_feed[i if unit_is_batch else 0]
h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
return y
post_processors.append(recolor_intensity_post_processing)
if '+lama' in unit.module:
forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
self.noise_modifier = forward_param.used_hint_cond_latent
del model_net
is_low_vram = any(unit.low_vram for unit in self.enabled_units)
self.latest_network = UnetHook(lowvram=is_low_vram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
for param in forward_params:
if param.control_model_type == ControlModelType.IPAdapter:
param.control_model.hook(
model=unet,
clip_vision_output=param.hint_cond,
weight=param.weight,
dtype=torch.float32,
start=param.start_guidance_percent,
end=param.stop_guidance_percent
)
if param.control_model_type == ControlModelType.Controlllite:
param.control_model.hook(
model=unet,
cond=param.hint_cond,
weight=param.weight,
start=param.start_guidance_percent,
end=param.stop_guidance_percent
)
self.detected_map = detected_maps
self.post_processors = post_processors
if os.path.exists(f'{data_path}/tmp/animatediff-frames/'):
shutil.rmtree(f'{data_path}/tmp/animatediff-frames/')
def hacked_postprocess_batch(self, p, *args, **kwargs):
images = kwargs.get('images', [])
for post_processor in self.post_processors:
for i in range(len(images)):
images[i] = post_processor(images[i], i)
return
self.original_controlnet_main_entry = self.cn_script.controlnet_main_entry
self.original_postprocess_batch = self.cn_script.postprocess_batch
self.cn_script.controlnet_main_entry = MethodType(hacked_main_entry, self.cn_script)
self.cn_script.postprocess_batch = MethodType(hacked_postprocess_batch, self.cn_script)
def restore_cn(self):
self.cn_script.controlnet_main_entry = self.original_controlnet_main_entry
self.original_controlnet_main_entry = None
self.cn_script.postprocess_batch = self.original_postprocess_batch
self.original_postprocess_batch = None
def hack(self, params: AnimateDiffProcess):
if self.cn_script is not None:
logger.info(f"Hacking ControlNet.")
self.hack_batchhijack(params)
self.hack_cn()
def restore(self):
if self.cn_script is not None:
logger.info(f"Restoring ControlNet.")
self.restore_batchhijack()
self.restore_cn()