AnimateDiff 2.0.0-a ControlNet part (#2661)
* ad * ad * sparsectrl and keyframe, excluding ip-adapter emb interpolation * fix ruff * ipadapter prompt travel. now implementation has been finished, next step test every feature * sparsectrl works, ip-adapter still not working * everything but i2i batch is working properly * everything but i2i batch is working properly, fix ruffpull/2678/head
parent
75ac3803fe
commit
461f8c5f16
|
|
@ -19,6 +19,16 @@ class BatchHijack:
|
|||
self.postprocess_batch_callbacks = [self.on_postprocess_batch]
|
||||
|
||||
def img2img_process_batch_hijack(self, p, *args, **kwargs):
|
||||
try:
|
||||
from scripts.animatediff_utils import get_animatediff_arg
|
||||
ad_params = get_animatediff_arg(p)
|
||||
if ad_params and ad_params.enable:
|
||||
ad_params.is_i2i_batch = True
|
||||
from scripts.animatediff_i2ibatch import animatediff_i2i_batch
|
||||
return animatediff_i2i_batch(p, *args, **kwargs)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
|
||||
if not cn_is_batch:
|
||||
return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
|
||||
|
|
@ -31,6 +41,14 @@ class BatchHijack:
|
|||
self.dispatch_callbacks(self.postprocess_batch_callbacks, p)
|
||||
|
||||
def processing_process_images_hijack(self, p, *args, **kwargs):
|
||||
try:
|
||||
from scripts.animatediff_utils import get_animatediff_arg
|
||||
ad_params = get_animatediff_arg(p)
|
||||
if ad_params and ad_params.enable:
|
||||
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if self.is_batch:
|
||||
# we are in img2img batch tab, do a single batch iteration
|
||||
return self.process_images_cn_batch(p, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ class ImageEmbed(NamedTuple):
|
|||
"""Image embed for a single image."""
|
||||
cond_emb: torch.Tensor
|
||||
uncond_emb: torch.Tensor
|
||||
bypass_average: bool = False
|
||||
|
||||
def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
|
||||
assert cond_mark.ndim == 4
|
||||
assert self.cond_emb.ndim == self.uncond_emb.ndim == 3
|
||||
assert self.cond_emb.shape[0] == self.uncond_emb.shape[0] == 1
|
||||
assert self.uncond_emb.shape[0] == 1 or self.cond_emb.shape[0] == self.uncond_emb.shape[0]
|
||||
assert self.cond_emb.shape[0] == 1 or self.cond_emb.shape[0] == cond_mark.shape[0]
|
||||
cond_mark = cond_mark[:, :, :, 0].to(self.cond_emb)
|
||||
device = cond_mark.device
|
||||
dtype = cond_mark.dtype
|
||||
|
|
@ -26,7 +28,7 @@ class ImageEmbed(NamedTuple):
|
|||
)
|
||||
|
||||
def average_of(*args: Tuple[torch.Tensor, torch.Tensor, bool]) -> "ImageEmbed":
    """Average several image embeds into a single ImageEmbed.

    Each positional argument is an ImageEmbed triple
    ``(cond_emb, uncond_emb, bypass_average)``. The cond/uncond tensors are
    averaged element-wise; the ``bypass_average`` flag of the inputs is
    discarded and the result uses the default ``bypass_average=False``.
    """
    # ImageEmbed is a 3-field NamedTuple (cond_emb, uncond_emb, bypass_average),
    # so unpack three components and ignore the flag. Keeping the older
    # 2-tuple unpack alongside this one (diff residue) would raise at runtime.
    conds, unconds, _ = zip(*args)

    def average_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
        # Element-wise mean across all stacked tensors.
        return torch.sum(torch.stack(tensors), dim=0) / len(tensors)

    return ImageEmbed(average_tensors(conds), average_tensors(unconds))
|
||||
|
|
@ -603,11 +605,14 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
self.dtype = dtype
|
||||
|
||||
self.ipadapter.to(device, dtype=self.dtype)
|
||||
if isinstance(preprocessor_outputs, (list, tuple)):
|
||||
preprocessor_outputs = preprocessor_outputs
|
||||
if getattr(preprocessor_outputs, "bypass_average", False):
|
||||
self.image_emb = preprocessor_outputs
|
||||
else:
|
||||
preprocessor_outputs = [preprocessor_outputs]
|
||||
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
|
||||
if isinstance(preprocessor_outputs, (list, tuple)):
|
||||
preprocessor_outputs = preprocessor_outputs
|
||||
else:
|
||||
preprocessor_outputs = [preprocessor_outputs]
|
||||
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
|
||||
# From https://github.com/laksjdjf/IPAdapter-ComfyUI
|
||||
if not self.sdxl:
|
||||
number = 0 # index of to_kvs
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import tracemalloc
|
|||
import os
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
from typing import Dict, Optional, Tuple, List, Union
|
||||
import modules.scripts as scripts
|
||||
from modules import shared, devices, script_callbacks, processing, masking, images
|
||||
|
|
@ -16,7 +16,7 @@ from scripts import global_state, hook, external_code, batch_hijack, controlnet_
|
|||
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
|
||||
from scripts.processor import HWC3, preprocessor_sliders_config
|
||||
from scripts.controlnet_lllite import clear_all_lllite
|
||||
from scripts.controlmodel_ipadapter import clear_all_ip_adapter
|
||||
from scripts.controlmodel_ipadapter import ImageEmbed, clear_all_ip_adapter
|
||||
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
|
||||
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
|
||||
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
|
||||
|
|
@ -116,7 +116,7 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
|
|||
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
|
||||
elif isinstance(image['mask'], str):
|
||||
if os.path.exists(image['mask']):
|
||||
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
|
||||
image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8')
|
||||
elif image['mask']:
|
||||
image['mask'] = external_code.to_base64_nparray(image['mask'])
|
||||
else:
|
||||
|
|
@ -614,6 +614,40 @@ class Script(scripts.Script, metaclass=(
|
|||
# 4 input image sources.
|
||||
p_image_control = getattr(p, "image_control", None)
|
||||
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
|
||||
|
||||
# AnimateDiff + ControlNet batch processing.
|
||||
unit_is_ad_batch = getattr(unit, "animatediff_batch", False)
|
||||
if unit_is_ad_batch:
|
||||
batch_parameters = unit.batch_images.split("\n")
|
||||
batch_image_dir = batch_parameters[0]
|
||||
logger.info(f"AnimateDiff + ControlNet {unit.module} receive the following parameters:")
|
||||
logger.info(f"\tbatch control images: {batch_image_dir}")
|
||||
for ad_cn_batch_parameter in batch_parameters[1:]:
|
||||
if ad_cn_batch_parameter.startswith("mask:"):
|
||||
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip()
|
||||
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
|
||||
elif ad_cn_batch_parameter.startswith("keyframe:"):
|
||||
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:"):].strip()
|
||||
unit.batch_keyframe_idx = [int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(',')]
|
||||
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
|
||||
batch_image_files = shared.listfiles(batch_image_dir)
|
||||
for batch_modifier in getattr(unit, 'batch_modifiers', []):
|
||||
batch_image_files = batch_modifier(batch_image_files, p)
|
||||
unit.batch_image_files = batch_image_files
|
||||
unit.image = []
|
||||
for idx, image_path in enumerate(batch_image_files):
|
||||
mask_path = None
|
||||
if getattr(unit, "batch_mask_dir", None) is not None:
|
||||
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
|
||||
if len(batch_mask_files) >= len(batch_image_files):
|
||||
mask_path = batch_mask_files[idx]
|
||||
else:
|
||||
mask_path = batch_mask_files[0]
|
||||
unit.image.append({
|
||||
"image": image_path,
|
||||
"mask": mask_path,
|
||||
})
|
||||
|
||||
image = parse_unit_image(unit)
|
||||
a1111_image = getattr(p, "init_images", [None])[0]
|
||||
|
||||
|
|
@ -635,6 +669,14 @@ class Script(scripts.Script, metaclass=(
|
|||
# Add mask logic if later there is a processor that accepts mask
|
||||
# on multiple inputs.
|
||||
input_image = [HWC3(decode_image(img['image'])) for img in image]
|
||||
if unit_is_ad_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
|
||||
for idx in range(len(input_image)):
|
||||
while len(image[idx]['mask'].shape) < 3:
|
||||
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
|
||||
if 'inpaint' in unit.module:
|
||||
color = HWC3(image[idx]["image"])
|
||||
alpha = image[idx]['mask'][:, :, 0:1]
|
||||
input_image[idx] = np.concatenate([color, alpha], axis=2)
|
||||
else:
|
||||
input_image = HWC3(decode_image(image['image']))
|
||||
if 'mask' in image and image['mask'] is not None:
|
||||
|
|
@ -894,14 +936,21 @@ class Script(scripts.Script, metaclass=(
|
|||
model_net, control_model_type = Script.load_control_model(p, unet, unit.model)
|
||||
model_net.reset()
|
||||
|
||||
if model_net is not None and getattr(devices, "fp8", False) and control_model_type == ControlModelType.ControlNet:
|
||||
for _module in model_net.modules(): # FIXME: let's only apply fp8 to ControlNet for now
|
||||
if isinstance(_module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
_module.to(torch.float8_e4m3fn)
|
||||
|
||||
if control_model_type == ControlModelType.ControlLoRA:
|
||||
control_lora = model_net.control_model
|
||||
bind_control_lora(unet, control_lora)
|
||||
p.controlnet_control_loras.append(control_lora)
|
||||
|
||||
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
|
||||
is_cn_ad_batch = getattr(unit, "animatediff_batch", False)
|
||||
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
|
||||
if isinstance(input_image, list):
|
||||
assert unit.accepts_multiple_inputs()
|
||||
assert unit.accepts_multiple_inputs() or is_cn_ad_batch
|
||||
input_images = input_image
|
||||
else: # Following operations are only for single input image.
|
||||
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
|
||||
|
|
@ -909,14 +958,15 @@ class Script(scripts.Script, metaclass=(
|
|||
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
# inpaint_only+lama is special and required outpaint fix
|
||||
_, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
input_image,
|
||||
target_H=h,
|
||||
target_W=w,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
input_images = [input_image]
|
||||
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
input_images[0],
|
||||
target_H=h,
|
||||
target_W=w,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
# Preprocessor result may depend on numpy random operations, use the
|
||||
# random seed in `StableDiffusionProcessing` to make the
|
||||
# preprocessor result reproducable.
|
||||
|
|
@ -964,12 +1014,88 @@ class Script(scripts.Script, metaclass=(
|
|||
|
||||
if control_model_type == ControlModelType.ReVision:
|
||||
control = control['image_embeds']
|
||||
|
||||
if is_image and is_cn_ad_batch: # AnimateDiff save VRAM
|
||||
control = control.cpu()
|
||||
if hr_control is not None:
|
||||
hr_control = hr_control.cpu()
|
||||
|
||||
return control, hr_control
|
||||
|
||||
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images]))
|
||||
if len(controls) == len(hr_controls) == 1:
|
||||
def optional_tqdm(iterable, use_tqdm=is_cn_ad_batch):
|
||||
from tqdm import tqdm
|
||||
return tqdm(iterable) if use_tqdm else iterable
|
||||
|
||||
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)]))
|
||||
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
|
||||
control = controls[0]
|
||||
hr_control = hr_controls[0]
|
||||
elif is_cn_ad_batch or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
|
||||
if unit.accepts_multiple_inputs():
|
||||
ip_adapter_image_emb_cond = []
|
||||
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
|
||||
for c in cc:
|
||||
c = model_net.get_image_emb(c) # noqa
|
||||
ip_adapter_image_emb_cond.append(c.cond_emb)
|
||||
c_cond = torch.cat(ip_adapter_image_emb_cond, dim=0)
|
||||
c = ImageEmbed(c_cond, c.uncond_emb, True)
|
||||
else:
|
||||
c = torch.cat(cc, dim=0)
|
||||
# SparseCtrl keyframe need to encode control image with VAE
|
||||
if control_model_type == ControlModelType.SparseCtrl and \
|
||||
model_net.control_model.use_simplified_condition_embedding: # noqa
|
||||
c = UnetHook.call_vae_using_process(p, c)
|
||||
# handle key frame control for different control methods
|
||||
if cn_ad_keyframe_idx is not None or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
if control_model_type == ControlModelType.SparseCtrl:
|
||||
# sparsectrl has its own embed generator
|
||||
from scripts.controlnet_sparsectrl import SparseCtrl
|
||||
if cn_ad_keyframe_idx is None:
|
||||
cn_ad_keyframe_idx = [0]
|
||||
logger.info(f"SparseCtrl: control images will be applied to frames: {cn_ad_keyframe_idx}")
|
||||
else:
|
||||
logger.info(f"SparseCtrl: control images will be applied to frames: {cn_ad_keyframe_idx}")
|
||||
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
|
||||
logger.info(f"\t{frame_idx}: {frame_path}")
|
||||
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
|
||||
elif unit.accepts_multiple_inputs():
|
||||
# ip-adapter should do prompt travel
|
||||
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
|
||||
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
|
||||
logger.info(f"\t{frame_idx}: {frame_path}")
|
||||
from scripts.animatediff_utils import get_animatediff_arg
|
||||
ip_adapter_emb = c
|
||||
c = c.cond_emb
|
||||
c_full = torch.zeros((p.batch_size, *c.shape[1:]), dtype=c.dtype, device=c.device)
|
||||
for i, idx in enumerate(cn_ad_keyframe_idx[:-1]):
|
||||
c_full[idx:cn_ad_keyframe_idx[i + 1]] = c[i]
|
||||
c_full[cn_ad_keyframe_idx[-1]:] = c[-1]
|
||||
ad_params = get_animatediff_arg(p)
|
||||
prompt_scheduler = deepcopy(ad_params.prompt_scheduler)
|
||||
prompt_scheduler.prompt_map = {i: "" for i in cn_ad_keyframe_idx}
|
||||
prompt_closed_loop = (ad_params.video_length > ad_params.batch_size) and (ad_params.closed_loop in ['R+P', 'A'])
|
||||
c_full = prompt_scheduler.multi_cond(c_full, prompt_closed_loop)
|
||||
if shared.opts.batch_cond_uncond:
|
||||
c_full = torch.cat([c_full, c_full], dim=0)
|
||||
c = ImageEmbed(c_full, ip_adapter_emb.uncond_emb, True)
|
||||
else:
|
||||
# normal CN should insert empty frames
|
||||
logger.info(f"ControlNet: control images will be applied to frames: {cn_ad_keyframe_idx} where")
|
||||
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
|
||||
logger.info(f"\t{frame_idx}: {frame_path}")
|
||||
c_full = torch.zeros((p.batch_size, *c.shape[1:]), dtype=c.dtype, device=c.device)
|
||||
c_full[cn_ad_keyframe_idx] = c
|
||||
c = c_full
|
||||
# handle batch condition and unconditional
|
||||
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs():
|
||||
c = torch.cat([c, c], dim=0)
|
||||
return c
|
||||
|
||||
control = ad_process_control(controls)
|
||||
hr_control = ad_process_control(hr_controls) if hr_controls[0] is not None else None
|
||||
if control_model_type == ControlModelType.SparseCtrl:
|
||||
control_model_type = ControlModelType.ControlNet
|
||||
else:
|
||||
control = controls
|
||||
hr_control = hr_controls
|
||||
|
|
@ -1008,22 +1134,24 @@ class Script(scripts.Script, metaclass=(
|
|||
final_inpaint_feed = hr_control if hr_control is not None else control
|
||||
final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
|
||||
final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
|
||||
final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
|
||||
final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
|
||||
final_inpaint_mask = final_inpaint_feed[:, 3, :, :].astype(np.float32)
|
||||
final_inpaint_raw = final_inpaint_feed[:, :3].astype(np.float32)
|
||||
sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
|
||||
final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
|
||||
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
|
||||
_, Hmask, Wmask = final_inpaint_mask.shape
|
||||
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[:, None]
|
||||
_, _, Hmask, Wmask = final_inpaint_mask.shape
|
||||
final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
|
||||
final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
|
||||
|
||||
def inpaint_only_post_processing(x):
|
||||
def inpaint_only_post_processing(x, i):
|
||||
if i >= final_inpaint_raw.shape[0]:
|
||||
i = 0
|
||||
_, H, W = x.shape
|
||||
if Hmask != H or Wmask != W:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
return x
|
||||
r = final_inpaint_raw.to(x.dtype).to(x.device)
|
||||
m = final_inpaint_mask.to(x.dtype).to(x.device)
|
||||
r = final_inpaint_raw[i].to(x.dtype).to(x.device)
|
||||
m = final_inpaint_mask[i].to(x.dtype).to(x.device)
|
||||
y = m * x.clip(0, 1) + (1 - m) * r
|
||||
y = y.clip(0, 1)
|
||||
return y
|
||||
|
|
@ -1034,13 +1162,15 @@ class Script(scripts.Script, metaclass=(
|
|||
final_feed = hr_control if hr_control is not None else control
|
||||
final_feed = final_feed.detach().cpu().numpy()
|
||||
final_feed = np.ascontiguousarray(final_feed).copy()
|
||||
final_feed = final_feed[0, 0, :, :].astype(np.float32)
|
||||
final_feed = final_feed[:, 0, :, :].astype(np.float32)
|
||||
final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
|
||||
Hfeed, Wfeed = final_feed.shape
|
||||
_, Hfeed, Wfeed = final_feed.shape
|
||||
|
||||
if 'luminance' in unit.module:
|
||||
|
||||
def recolor_luminance_post_processing(x):
|
||||
def recolor_luminance_post_processing(x, i):
|
||||
if i >= final_feed.shape[0]:
|
||||
i = 0
|
||||
C, H, W = x.shape
|
||||
if Hfeed != H or Wfeed != W or C != 3:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
|
|
@ -1048,7 +1178,7 @@ class Script(scripts.Script, metaclass=(
|
|||
h = x.detach().cpu().numpy().transpose((1, 2, 0))
|
||||
h = (h * 255).clip(0, 255).astype(np.uint8)
|
||||
h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
|
||||
h[:, :, 0] = final_feed
|
||||
h[:, :, 0] = final_feed[i]
|
||||
h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
|
||||
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
|
||||
y = torch.from_numpy(h).clip(0, 1).to(x)
|
||||
|
|
@ -1058,7 +1188,9 @@ class Script(scripts.Script, metaclass=(
|
|||
|
||||
if 'intensity' in unit.module:
|
||||
|
||||
def recolor_intensity_post_processing(x):
|
||||
def recolor_intensity_post_processing(x, i):
|
||||
if i >= final_feed.shape[0]:
|
||||
i = 0
|
||||
C, H, W = x.shape
|
||||
if Hfeed != H or Wfeed != W or C != 3:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
|
|
@ -1066,7 +1198,7 @@ class Script(scripts.Script, metaclass=(
|
|||
h = x.detach().cpu().numpy().transpose((1, 2, 0))
|
||||
h = (h * 255).clip(0, 255).astype(np.uint8)
|
||||
h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
|
||||
h[:, :, 2] = final_feed
|
||||
h[:, :, 2] = final_feed[i]
|
||||
h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
|
||||
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
|
||||
y = torch.from_numpy(h).clip(0, 1).to(x)
|
||||
|
|
@ -1157,7 +1289,7 @@ class Script(scripts.Script, metaclass=(
|
|||
images = kwargs.get('images', [])
|
||||
for post_processor in self.post_processors:
|
||||
for i in range(len(images)):
|
||||
images[i] = post_processor(images[i])
|
||||
images[i] = post_processor(images[i], i)
|
||||
return
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from modules import devices
|
|||
from scripts.adapter import PlugableAdapter, Adapter, StyleAdapter, Adapter_light
|
||||
from scripts.controlnet_lllite import PlugableControlLLLite
|
||||
from scripts.cldm import PlugableControlModel
|
||||
from scripts.controlnet_sparsectrl import PlugableSparseCtrlModel
|
||||
from scripts.controlmodel_ipadapter import PlugableIPAdapter
|
||||
from scripts.logging import logger
|
||||
from scripts.controlnet_diffusers import convert_from_diffuser_state_dict
|
||||
|
|
@ -132,6 +133,21 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
|
|||
network.to(devices.dtype_unet)
|
||||
return ControlModel(network, ControlModelType.ControlLoRA)
|
||||
|
||||
if "down_blocks.0.motion_modules.0.temporal_transformer.norm.weight" in state_dict: # sparsectrl
|
||||
config = copy.deepcopy(controlnet_default_config)
|
||||
if "input_hint_block.0.weight" in state_dict: # rgb
|
||||
config['use_simplified_condition_embedding'] = True
|
||||
config['conditioning_channels'] = 5
|
||||
else: # scribble
|
||||
config['use_simplified_condition_embedding'] = False
|
||||
config['conditioning_channels'] = 4
|
||||
|
||||
config['use_fp16'] = devices.dtype_unet == torch.float16
|
||||
|
||||
network = PlugableSparseCtrlModel(config, state_dict)
|
||||
network.to(devices.dtype_unet)
|
||||
return ControlModel(network, ControlModelType.SparseCtrl)
|
||||
|
||||
if "controlnet_cond_embedding.conv_in.weight" in state_dict: # diffusers
|
||||
state_dict = convert_from_diffuser_state_dict(state_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
from typing import Tuple, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from scripts.cldm import PlugableControlModel, ControlNet, zero_module, conv_nd, TimestepEmbedSequential
|
||||
|
||||
class PlugableSparseCtrlModel(PlugableControlModel):
    """Pluggable wrapper holding a SparseCtrl network.

    Reuses the PlugableControlModel interface but deliberately skips its
    __init__, which would construct a plain ControlNet; here the wrapped
    network is a SparseCtrl instead.
    """

    def __init__(self, config, state_dict=None):
        # Initialize only as a bare nn.Module — see class docstring.
        nn.Module.__init__(self)
        self.config = config
        # Build the network on CPU; it is moved to the target device later.
        self.control_model = SparseCtrl(**config).cpu()
        if state_dict is not None:
            # strict=False is forwarded to SparseCtrl.load_state_dict, which
            # partitions the checkpoint itself (motion modules vs. ControlNet).
            self.control_model.load_state_dict(state_dict, strict=False)
        self.gpu_component = None
|
||||
|
||||
|
||||
class CondEmbed(nn.Module):
    """Convolutional encoder mapping a conditioning image to embedding space.

    Structure: a conv_in layer, then for each adjacent pair in
    ``block_out_channels`` a stride-1 conv followed by a stride-2 conv
    (downsampling while widening channels), and a zero-initialized conv_out
    so the embedding contributes nothing at the start of training.
    NOTE(review): appears to mirror diffusers' ControlNetConditioningEmbedding
    — confirm against upstream if exact parity matters.
    """

    def __init__(
        self,
        dims: int,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = conv_nd(dims, conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        # One (stride-1, stride-2) conv pair per adjacent channel pair.
        for ch_in, ch_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(conv_nd(dims, ch_in, ch_in, kernel_size=3, padding=1))
            self.blocks.append(conv_nd(dims, ch_in, ch_out, kernel_size=3, padding=1, stride=2))

        # Zero-initialized output projection.
        self.conv_out = zero_module(conv_nd(dims, block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1))

    def forward(self, conditioning):
        # SiLU after every conv except the (zero-initialized) output layer.
        hidden = F.silu(self.conv_in(conditioning))
        for layer in self.blocks:
            hidden = F.silu(layer(hidden))
        return self.conv_out(hidden)
|
||||
|
||||
|
||||
class SparseCtrl(ControlNet):
    """ControlNet variant for AnimateDiff SparseCtrl.

    Differences from a plain ControlNet, as implemented here:
      * the condition embedding is replaced — either a single zero-initialized
        conv ("simplified") or a CondEmbed encoder — sized for
        ``conditioning_channels`` input channels;
      * temporal motion modules from the checkpoint are injected into the
        UNet-side input blocks during ``load_state_dict``;
      * ``forward`` feeds zeros in place of the latent input, so the output
        depends only on the sparse hint frames.
    """

    def __init__(self, use_simplified_condition_embedding=True, conditioning_channels=4, **kwargs):
        # hint_channels is unused by SparseCtrl, but ControlNet.__init__
        # needs a value, so pass 1 to avoid errors.
        super().__init__(hint_channels=1, **kwargs)
        self.use_simplified_condition_embedding = use_simplified_condition_embedding
        model_channels = kwargs.get("model_channels", 320)
        if use_simplified_condition_embedding:
            # Single zero-initialized conv straight to model_channels.
            hint_block = TimestepEmbedSequential(
                zero_module(conv_nd(self.dims, conditioning_channels, model_channels, kernel_size=3, padding=1)))
        else:
            # Full convolutional condition encoder.
            hint_block = TimestepEmbedSequential(
                CondEmbed(self.dims, model_channels, conditioning_channels=conditioning_channels))
        self.input_hint_block = hint_block

    def load_state_dict(self, state_dict, strict=False):
        # Partition the checkpoint: motion-module weights feed a MotionWrapper,
        # everything else is regular ControlNet state.
        motion_sd = {key: value for key, value in state_dict.items() if "motion_modules" in key}
        control_sd = {key: value for key, value in state_dict.items() if "motion_modules" not in key}

        super().load_state_dict(control_sd, strict=True)

        # Local import to avoid a hard dependency when AnimateDiff is absent.
        from scripts.animatediff_mm import MotionWrapper, MotionModuleType
        motion_wrapper = MotionWrapper("", "", MotionModuleType.SparseCtrl)
        motion_wrapper.load_state_dict(motion_sd, strict=True)

        # Append each temporal module to its corresponding input block
        # (two motion modules per down block).
        for module_index, block_index in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
            down_block = motion_wrapper.down_blocks[module_index // 2]
            motion_module = down_block.motion_modules[module_index % 2]
            self.input_blocks[block_index].append(motion_module)

    @staticmethod
    def create_cond_mask(control_image_index: List[int], control_image_latents: torch.Tensor, video_length: int):
        """Scatter control latents into a video-length tensor plus a mask channel.

        Frames listed in ``control_image_index`` receive the corresponding
        latents (extras beyond len(control_image_index) are ignored); all other
        frames stay zero. A one-channel mask marking the controlled frames is
        concatenated along the channel dimension.
        """
        device = control_image_latents.device
        dtype = control_image_latents.dtype
        cond = torch.zeros((video_length, *control_image_latents.shape[1:]), device=device, dtype=dtype)
        cond[control_image_index] = control_image_latents[:len(control_image_index)]
        mask = torch.zeros((video_length, 1, *cond.shape[2:]), device=device, dtype=dtype)
        mask[control_image_index] = 1.0
        return torch.cat([cond, mask], dim=1)

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        # Discard the latent input entirely: SparseCtrl conditions only on the
        # sparse hint frames.
        return super().forward(torch.zeros_like(x, device=x.device), hint, timesteps, context, y=y, **kwargs)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from scripts.logging import logger
|
||||
|
||||
version_flag = 'v1.1.440'
|
||||
version_flag = 'v1.1.441'
|
||||
|
||||
logger.info(f"ControlNet {version_flag}")
|
||||
# A smart trick to know if user has updated as well as if user has restarted terminal.
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class ControlModelType(Enum):
|
|||
IPAdapter = "IPAdapter, Hu Ye"
|
||||
Controlllite = "Controlllite, Kohya"
|
||||
InstantID = "InstantID, Qixun Wang"
|
||||
SparseCtrl = "SparseCtrl, Yuwei Guo"
|
||||
|
||||
def is_controlnet(self) -> bool:
|
||||
"""Returns whether the control model should be treated as ControlNet."""
|
||||
|
|
|
|||
|
|
@ -319,6 +319,10 @@ def select_control_type(
|
|||
filtered_preprocessor_list += [
|
||||
x for x in preprocessor_list if "invert" in x.lower()
|
||||
]
|
||||
if pattern in ["sparsectrl"]:
|
||||
filtered_preprocessor_list += [
|
||||
x for x in preprocessor_list if "scribble" in x.lower()
|
||||
]
|
||||
filtered_model_list = [
|
||||
model for model in all_models
|
||||
if model.lower() == "none" or
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Optional, Any
|
|||
from scripts.logging import logger
|
||||
from scripts.enums import ControlModelType, AutoMachine, HiResFixOption
|
||||
from scripts.controlmodel_ipadapter import ImageEmbed
|
||||
from scripts.controlnet_sparsectrl import SparseCtrl
|
||||
from modules import devices, lowvram, shared, scripts
|
||||
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding, make_beta_schedule
|
||||
|
|
@ -384,8 +385,12 @@ class UnetHook(nn.Module):
|
|||
vae_output = vae_cache.get(x)
|
||||
if vae_output is None:
|
||||
with devices.autocast():
|
||||
vae_output = p.sd_model.encode_first_stage(x)
|
||||
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
|
||||
vae_output = torch.stack([
|
||||
p.sd_model.get_first_stage_encoding(
|
||||
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
|
||||
)[0].to(img.device)
|
||||
for img in x
|
||||
])
|
||||
if torch.all(torch.isnan(vae_output)).item():
|
||||
logger.info('ControlNet find Nans in the VAE encoding. \n '
|
||||
'Now ControlNet will automatically retry.\n '
|
||||
|
|
@ -393,8 +398,12 @@ class UnetHook(nn.Module):
|
|||
devices.dtype_vae = torch.float32
|
||||
x = x.to(devices.dtype_vae)
|
||||
p.sd_model.first_stage_model.to(devices.dtype_vae)
|
||||
vae_output = p.sd_model.encode_first_stage(x)
|
||||
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
|
||||
vae_output = torch.stack([
|
||||
p.sd_model.get_first_stage_encoding(
|
||||
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
|
||||
)[0].to(img.device)
|
||||
for img in x
|
||||
])
|
||||
vae_cache.set(x, vae_output)
|
||||
logger.info(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
|
||||
latent = vae_output
|
||||
|
|
@ -571,7 +580,7 @@ class UnetHook(nn.Module):
|
|||
controlnet_context = context
|
||||
|
||||
# ControlNet inpaint protocol
|
||||
if hint.shape[1] == 4:
|
||||
if hint.shape[1] == 4 and not isinstance(control_model, SparseCtrl):
|
||||
c = hint[:, 0:3, :, :]
|
||||
m = hint[:, 3:4, :, :]
|
||||
m = (m > 0.5).float()
|
||||
|
|
|
|||
|
|
@ -1316,6 +1316,7 @@ preprocessor_filters = {
|
|||
"T2I-Adapter": "none",
|
||||
"IP-Adapter": "ip-adapter_clip_sd15",
|
||||
"Instant_ID": "instant_id",
|
||||
"SparseCtrl": "none",
|
||||
}
|
||||
|
||||
preprocessor_filters_aliases = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue