AnimateDiff 2.0.0-a ControlNet part (#2661)

* ad

* ad

* sparsectrl and keyframe, excluding ip-adapter emb interpolation

* fix ruff

* ipadapter prompt travel. The implementation is now finished; next step: test every feature

* sparsectrl works, ip-adapter still not working

* everything but i2i batch is working properly

* everything but i2i batch is working properly, fix ruff
pull/2678/head
Chengsong Zhang 2024-03-02 06:12:54 -06:00 committed by GitHub
parent 75ac3803fe
commit 461f8c5f16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 324 additions and 39 deletions

View File

@ -19,6 +19,16 @@ class BatchHijack:
self.postprocess_batch_callbacks = [self.on_postprocess_batch]
def img2img_process_batch_hijack(self, p, *args, **kwargs):
try:
from scripts.animatediff_utils import get_animatediff_arg
ad_params = get_animatediff_arg(p)
if ad_params and ad_params.enable:
ad_params.is_i2i_batch = True
from scripts.animatediff_i2ibatch import animatediff_i2i_batch
return animatediff_i2i_batch(p, *args, **kwargs)
except ImportError:
pass
cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
if not cn_is_batch:
return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
@ -31,6 +41,14 @@ class BatchHijack:
self.dispatch_callbacks(self.postprocess_batch_callbacks, p)
def processing_process_images_hijack(self, p, *args, **kwargs):
try:
from scripts.animatediff_utils import get_animatediff_arg
ad_params = get_animatediff_arg(p)
if ad_params and ad_params.enable:
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
except ImportError:
pass
if self.is_batch:
# we are in img2img batch tab, do a single batch iteration
return self.process_images_cn_batch(p, *args, **kwargs)

View File

@ -12,11 +12,13 @@ class ImageEmbed(NamedTuple):
"""Image embed for a single image."""
cond_emb: torch.Tensor
uncond_emb: torch.Tensor
bypass_average: bool = False
def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
assert cond_mark.ndim == 4
assert self.cond_emb.ndim == self.uncond_emb.ndim == 3
assert self.cond_emb.shape[0] == self.uncond_emb.shape[0] == 1
assert self.uncond_emb.shape[0] == 1 or self.cond_emb.shape[0] == self.uncond_emb.shape[0]
assert self.cond_emb.shape[0] == 1 or self.cond_emb.shape[0] == cond_mark.shape[0]
cond_mark = cond_mark[:, :, :, 0].to(self.cond_emb)
device = cond_mark.device
dtype = cond_mark.dtype
@ -26,7 +28,7 @@ class ImageEmbed(NamedTuple):
)
def average_of(*args: List[Tuple[torch.Tensor, torch.Tensor]]) -> "ImageEmbed":
conds, unconds = zip(*args)
conds, unconds, _ = zip(*args)
def average_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
return torch.sum(torch.stack(tensors), dim=0) / len(tensors)
return ImageEmbed(average_tensors(conds), average_tensors(unconds))
@ -603,11 +605,14 @@ class PlugableIPAdapter(torch.nn.Module):
self.dtype = dtype
self.ipadapter.to(device, dtype=self.dtype)
if isinstance(preprocessor_outputs, (list, tuple)):
preprocessor_outputs = preprocessor_outputs
if getattr(preprocessor_outputs, "bypass_average", False):
self.image_emb = preprocessor_outputs
else:
preprocessor_outputs = [preprocessor_outputs]
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
if isinstance(preprocessor_outputs, (list, tuple)):
preprocessor_outputs = preprocessor_outputs
else:
preprocessor_outputs = [preprocessor_outputs]
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
# From https://github.com/laksjdjf/IPAdapter-ComfyUI
if not self.sdxl:
number = 0 # index of to_kvs

View File

@ -3,7 +3,7 @@ import tracemalloc
import os
import logging
from collections import OrderedDict
from copy import copy
from copy import copy, deepcopy
from typing import Dict, Optional, Tuple, List, Union
import modules.scripts as scripts
from modules import shared, devices, script_callbacks, processing, masking, images
@ -16,7 +16,7 @@ from scripts import global_state, hook, external_code, batch_hijack, controlnet_
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.processor import HWC3, preprocessor_sliders_config
from scripts.controlnet_lllite import clear_all_lllite
from scripts.controlmodel_ipadapter import clear_all_ip_adapter
from scripts.controlmodel_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
@ -116,7 +116,7 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
elif isinstance(image['mask'], str):
if os.path.exists(image['mask']):
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8')
elif image['mask']:
image['mask'] = external_code.to_base64_nparray(image['mask'])
else:
@ -614,6 +614,40 @@ class Script(scripts.Script, metaclass=(
# 4 input image sources.
p_image_control = getattr(p, "image_control", None)
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
# AnimateDiff + ControlNet batch processing.
unit_is_ad_batch = getattr(unit, "animatediff_batch", False)
if unit_is_ad_batch:
batch_parameters = unit.batch_images.split("\n")
batch_image_dir = batch_parameters[0]
logger.info(f"AnimateDiff + ControlNet {unit.module} receive the following parameters:")
logger.info(f"\tbatch control images: {batch_image_dir}")
for ad_cn_batch_parameter in batch_parameters[1:]:
if ad_cn_batch_parameter.startswith("mask:"):
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip()
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
elif ad_cn_batch_parameter.startswith("keyframe:"):
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:"):].strip()
unit.batch_keyframe_idx = [int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(',')]
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
batch_image_files = shared.listfiles(batch_image_dir)
for batch_modifier in getattr(unit, 'batch_modifiers', []):
batch_image_files = batch_modifier(batch_image_files, p)
unit.batch_image_files = batch_image_files
unit.image = []
for idx, image_path in enumerate(batch_image_files):
mask_path = None
if getattr(unit, "batch_mask_dir", None) is not None:
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
if len(batch_mask_files) >= len(batch_image_files):
mask_path = batch_mask_files[idx]
else:
mask_path = batch_mask_files[0]
unit.image.append({
"image": image_path,
"mask": mask_path,
})
image = parse_unit_image(unit)
a1111_image = getattr(p, "init_images", [None])[0]
@ -635,6 +669,14 @@ class Script(scripts.Script, metaclass=(
# Add mask logic if later there is a processor that accepts mask
# on multiple inputs.
input_image = [HWC3(decode_image(img['image'])) for img in image]
if unit_is_ad_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
for idx in range(len(input_image)):
while len(image[idx]['mask'].shape) < 3:
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
if 'inpaint' in unit.module:
color = HWC3(image[idx]["image"])
alpha = image[idx]['mask'][:, :, 0:1]
input_image[idx] = np.concatenate([color, alpha], axis=2)
else:
input_image = HWC3(decode_image(image['image']))
if 'mask' in image and image['mask'] is not None:
@ -894,14 +936,21 @@ class Script(scripts.Script, metaclass=(
model_net, control_model_type = Script.load_control_model(p, unet, unit.model)
model_net.reset()
if model_net is not None and getattr(devices, "fp8", False) and control_model_type == ControlModelType.ControlNet:
for _module in model_net.modules(): # FIXME: let's only apply fp8 to ControlNet for now
if isinstance(_module, (torch.nn.Conv2d, torch.nn.Linear)):
_module.to(torch.float8_e4m3fn)
if control_model_type == ControlModelType.ControlLoRA:
control_lora = model_net.control_model
bind_control_lora(unet, control_lora)
p.controlnet_control_loras.append(control_lora)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
is_cn_ad_batch = getattr(unit, "animatediff_batch", False)
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
if isinstance(input_image, list):
assert unit.accepts_multiple_inputs()
assert unit.accepts_multiple_inputs() or is_cn_ad_batch
input_images = input_image
else: # Following operations are only for single input image.
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
@ -909,14 +958,15 @@ class Script(scripts.Script, metaclass=(
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
# inpaint_only+lama is special and required outpaint fix
_, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_image,
target_H=h,
target_W=w,
resize_mode=resize_mode,
)
input_images = [input_image]
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_images[0],
target_H=h,
target_W=w,
resize_mode=resize_mode,
)
# Preprocessor result may depend on numpy random operations, use the
# random seed in `StableDiffusionProcessing` to make the
# preprocessor result reproducable.
@ -964,12 +1014,88 @@ class Script(scripts.Script, metaclass=(
if control_model_type == ControlModelType.ReVision:
control = control['image_embeds']
if is_image and is_cn_ad_batch: # AnimateDiff save VRAM
control = control.cpu()
if hr_control is not None:
hr_control = hr_control.cpu()
return control, hr_control
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images]))
if len(controls) == len(hr_controls) == 1:
def optional_tqdm(iterable, use_tqdm=is_cn_ad_batch):
from tqdm import tqdm
return tqdm(iterable) if use_tqdm else iterable
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)]))
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
control = controls[0]
hr_control = hr_controls[0]
elif is_cn_ad_batch or control_model_type in [ControlModelType.SparseCtrl]:
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs():
ip_adapter_image_emb_cond = []
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
for c in cc:
c = model_net.get_image_emb(c) # noqa
ip_adapter_image_emb_cond.append(c.cond_emb)
c_cond = torch.cat(ip_adapter_image_emb_cond, dim=0)
c = ImageEmbed(c_cond, c.uncond_emb, True)
else:
c = torch.cat(cc, dim=0)
# SparseCtrl keyframe need to encode control image with VAE
if control_model_type == ControlModelType.SparseCtrl and \
model_net.control_model.use_simplified_condition_embedding: # noqa
c = UnetHook.call_vae_using_process(p, c)
# handle key frame control for different control methods
if cn_ad_keyframe_idx is not None or control_model_type in [ControlModelType.SparseCtrl]:
if control_model_type == ControlModelType.SparseCtrl:
# sparsectrl has its own embed generator
from scripts.controlnet_sparsectrl import SparseCtrl
if cn_ad_keyframe_idx is None:
cn_ad_keyframe_idx = [0]
logger.info(f"SparseCtrl: control images will be applied to frames: {cn_ad_keyframe_idx}")
else:
logger.info(f"SparseCtrl: control images will be applied to frames: {cn_ad_keyframe_idx}")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
elif unit.accepts_multiple_inputs():
# ip-adapter should do prompt travel
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
from scripts.animatediff_utils import get_animatediff_arg
ip_adapter_emb = c
c = c.cond_emb
c_full = torch.zeros((p.batch_size, *c.shape[1:]), dtype=c.dtype, device=c.device)
for i, idx in enumerate(cn_ad_keyframe_idx[:-1]):
c_full[idx:cn_ad_keyframe_idx[i + 1]] = c[i]
c_full[cn_ad_keyframe_idx[-1]:] = c[-1]
ad_params = get_animatediff_arg(p)
prompt_scheduler = deepcopy(ad_params.prompt_scheduler)
prompt_scheduler.prompt_map = {i: "" for i in cn_ad_keyframe_idx}
prompt_closed_loop = (ad_params.video_length > ad_params.batch_size) and (ad_params.closed_loop in ['R+P', 'A'])
c_full = prompt_scheduler.multi_cond(c_full, prompt_closed_loop)
if shared.opts.batch_cond_uncond:
c_full = torch.cat([c_full, c_full], dim=0)
c = ImageEmbed(c_full, ip_adapter_emb.uncond_emb, True)
else:
# normal CN should insert empty frames
logger.info(f"ControlNet: control images will be applied to frames: {cn_ad_keyframe_idx} where")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
c_full = torch.zeros((p.batch_size, *c.shape[1:]), dtype=c.dtype, device=c.device)
c_full[cn_ad_keyframe_idx] = c
c = c_full
# handle batch condition and unconditional
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs():
c = torch.cat([c, c], dim=0)
return c
control = ad_process_control(controls)
hr_control = ad_process_control(hr_controls) if hr_controls[0] is not None else None
if control_model_type == ControlModelType.SparseCtrl:
control_model_type = ControlModelType.ControlNet
else:
control = controls
hr_control = hr_controls
@ -1008,22 +1134,24 @@ class Script(scripts.Script, metaclass=(
final_inpaint_feed = hr_control if hr_control is not None else control
final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
final_inpaint_mask = final_inpaint_feed[:, 3, :, :].astype(np.float32)
final_inpaint_raw = final_inpaint_feed[:, :3].astype(np.float32)
sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
_, Hmask, Wmask = final_inpaint_mask.shape
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[:, None]
_, _, Hmask, Wmask = final_inpaint_mask.shape
final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
def inpaint_only_post_processing(x):
def inpaint_only_post_processing(x, i):
if i >= final_inpaint_raw.shape[0]:
i = 0
_, H, W = x.shape
if Hmask != H or Wmask != W:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
r = final_inpaint_raw.to(x.dtype).to(x.device)
m = final_inpaint_mask.to(x.dtype).to(x.device)
r = final_inpaint_raw[i].to(x.dtype).to(x.device)
m = final_inpaint_mask[i].to(x.dtype).to(x.device)
y = m * x.clip(0, 1) + (1 - m) * r
y = y.clip(0, 1)
return y
@ -1034,13 +1162,15 @@ class Script(scripts.Script, metaclass=(
final_feed = hr_control if hr_control is not None else control
final_feed = final_feed.detach().cpu().numpy()
final_feed = np.ascontiguousarray(final_feed).copy()
final_feed = final_feed[0, 0, :, :].astype(np.float32)
final_feed = final_feed[:, 0, :, :].astype(np.float32)
final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
Hfeed, Wfeed = final_feed.shape
_, Hfeed, Wfeed = final_feed.shape
if 'luminance' in unit.module:
def recolor_luminance_post_processing(x):
def recolor_luminance_post_processing(x, i):
if i >= final_feed.shape[0]:
i = 0
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
@ -1048,7 +1178,7 @@ class Script(scripts.Script, metaclass=(
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
h[:, :, 0] = final_feed
h[:, :, 0] = final_feed[i]
h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
@ -1058,7 +1188,9 @@ class Script(scripts.Script, metaclass=(
if 'intensity' in unit.module:
def recolor_intensity_post_processing(x):
def recolor_intensity_post_processing(x, i):
if i >= final_feed.shape[0]:
i = 0
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
@ -1066,7 +1198,7 @@ class Script(scripts.Script, metaclass=(
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
h[:, :, 2] = final_feed
h[:, :, 2] = final_feed[i]
h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
@ -1157,7 +1289,7 @@ class Script(scripts.Script, metaclass=(
images = kwargs.get('images', [])
for post_processor in self.post_processors:
for i in range(len(images)):
images[i] = post_processor(images[i])
images[i] = post_processor(images[i], i)
return
def postprocess(self, p, processed, *args):

View File

@ -8,6 +8,7 @@ from modules import devices
from scripts.adapter import PlugableAdapter, Adapter, StyleAdapter, Adapter_light
from scripts.controlnet_lllite import PlugableControlLLLite
from scripts.cldm import PlugableControlModel
from scripts.controlnet_sparsectrl import PlugableSparseCtrlModel
from scripts.controlmodel_ipadapter import PlugableIPAdapter
from scripts.logging import logger
from scripts.controlnet_diffusers import convert_from_diffuser_state_dict
@ -132,6 +133,21 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
network.to(devices.dtype_unet)
return ControlModel(network, ControlModelType.ControlLoRA)
if "down_blocks.0.motion_modules.0.temporal_transformer.norm.weight" in state_dict: # sparsectrl
config = copy.deepcopy(controlnet_default_config)
if "input_hint_block.0.weight" in state_dict: # rgb
config['use_simplified_condition_embedding'] = True
config['conditioning_channels'] = 5
else: # scribble
config['use_simplified_condition_embedding'] = False
config['conditioning_channels'] = 4
config['use_fp16'] = devices.dtype_unet == torch.float16
network = PlugableSparseCtrlModel(config, state_dict)
network.to(devices.dtype_unet)
return ControlModel(network, ControlModelType.SparseCtrl)
if "controlnet_cond_embedding.conv_in.weight" in state_dict: # diffusers
state_dict = convert_from_diffuser_state_dict(state_dict)

View File

@ -0,0 +1,99 @@
from typing import Tuple, List
import torch
from torch import nn
from torch.nn import functional as F
from scripts.cldm import PlugableControlModel, ControlNet, zero_module, conv_nd, TimestepEmbedSequential
class PlugableSparseCtrlModel(PlugableControlModel):
    """Plugable wrapper that exposes a SparseCtrl network through the
    ControlNet model interface used by the extension."""

    def __init__(self, config, state_dict=None):
        # Deliberately bypass PlugableControlModel.__init__: SparseCtrl builds
        # its own control model from `config` instead of the stock ControlNet.
        nn.Module.__init__(self)
        self.config = config
        network = SparseCtrl(**config).cpu()
        if state_dict is not None:
            # strict=False is nominal here; SparseCtrl.load_state_dict splits
            # the dict itself and applies its own strictness per part.
            network.load_state_dict(state_dict, strict=False)
        self.control_model = network
        self.gpu_component = None
class CondEmbed(nn.Module):
    """Multi-scale convolutional encoder for the conditioning image.

    Mirrors the diffusers-style ControlNet conditioning embedding: a conv-in
    layer, pairs of (same-scale, stride-2 downsample) convs per resolution
    step, and a zero-initialized conv-out so training starts as a no-op.
    """

    def __init__(
        self,
        dims: int,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()
        self.conv_in = conv_nd(dims, conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList()
        # Consecutive channel pairs drive one same-scale conv plus one
        # stride-2 conv that halves the spatial resolution.
        for ch_in, ch_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(conv_nd(dims, ch_in, ch_in, kernel_size=3, padding=1))
            self.blocks.append(conv_nd(dims, ch_in, ch_out, kernel_size=3, padding=1, stride=2))
        # zero_module => conv_out contributes nothing until trained.
        self.conv_out = zero_module(conv_nd(dims, block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1))

    def forward(self, conditioning):
        h = F.silu(self.conv_in(conditioning))
        for layer in self.blocks:
            h = F.silu(layer(h))
        return self.conv_out(h)
class SparseCtrl(ControlNet):
    """SparseCtrl control network (Guo et al.): a ControlNet variant that
    conditions an AnimateDiff video on a sparse set of key frames, injecting
    motion modules into the down blocks."""

    def __init__(self, use_simplified_condition_embedding=True, conditioning_channels=4, **kwargs):
        super().__init__(hint_channels=1, **kwargs) # we don't need hint_channels, but we need to set it to 1 to avoid errors
        self.use_simplified_condition_embedding = use_simplified_condition_embedding
        if use_simplified_condition_embedding:
            # "RGB" checkpoints: a single zero-initialized conv maps the
            # (latents + mask) hint straight to model_channels.
            self.input_hint_block = TimestepEmbedSequential(
                zero_module(conv_nd(self.dims, conditioning_channels, kwargs.get("model_channels", 320), kernel_size=3, padding=1)))
        else:
            # "scribble" checkpoints: full multi-scale CondEmbed encoder.
            self.input_hint_block = TimestepEmbedSequential(
                CondEmbed(
                    self.dims, kwargs.get("model_channels", 320),
                    conditioning_channels=conditioning_channels,))

    def load_state_dict(self, state_dict, strict=False):
        # Split the checkpoint into motion-module weights and plain
        # ControlNet weights, then load each part separately.
        # NOTE(review): the `strict` parameter is ignored — both sub-loads
        # below use strict=True regardless; confirm this is intentional.
        mm_dict = {}
        cn_dict = {}
        for k, v in state_dict.items():
            if "motion_modules" in k:
                mm_dict[k] = v
            else:
                cn_dict[k] = v
        super().load_state_dict(cn_dict, strict=True)
        # Local import to avoid a hard dependency on the AnimateDiff
        # extension at module-import time.
        from scripts.animatediff_mm import MotionWrapper, MotionModuleType
        sparsectrl_mm = MotionWrapper("", "", MotionModuleType.SparseCtrl)
        sparsectrl_mm.load_state_dict(mm_dict, strict=True)
        # Append one motion module after each of the 8 down-block positions;
        # the unet_idx list presumably mirrors the SD1.5 input-block layout
        # used by AnimateDiff (two modules per down block) — TODO confirm.
        for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
            mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
            mm_inject = getattr(sparsectrl_mm.down_blocks[mm_idx0], "motion_modules")[mm_idx1]
            self.input_blocks[unet_idx].append(mm_inject)

    @staticmethod
    def create_cond_mask(control_image_index: List[int], control_image_latents: torch.Tensor, video_length: int):
        """Scatter the key-frame latents over `video_length` frames and append
        a binary mask channel marking which frames carry a control image.

        Returns a tensor of shape (video_length, C+1, ...): zeros for
        unconditioned frames, with channel C set to 1.0 at key-frame indices.
        """
        hint_cond = torch.zeros((video_length, *control_image_latents.shape[1:]), device=control_image_latents.device, dtype=control_image_latents.dtype)
        hint_cond[control_image_index] = control_image_latents[:len(control_image_index)]
        hint_cond_mask = torch.zeros((hint_cond.shape[0], 1, *hint_cond.shape[2:]), device=control_image_latents.device, dtype=control_image_latents.dtype)
        hint_cond_mask[control_image_index] = 1.0
        return torch.cat([hint_cond, hint_cond_mask], dim=1)

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        # SparseCtrl conditions only through `hint`: the noisy latent input is
        # replaced with zeros before running the ControlNet forward pass.
        return super().forward(torch.zeros_like(x, device=x.device), hint, timesteps, context, y=y, **kwargs)

View File

@ -1,6 +1,6 @@
from scripts.logging import logger
version_flag = 'v1.1.440'
version_flag = 'v1.1.441'
logger.info(f"ControlNet {version_flag}")
# A smart trick to know if user has updated as well as if user has restarted terminal.

View File

@ -62,6 +62,7 @@ class ControlModelType(Enum):
IPAdapter = "IPAdapter, Hu Ye"
Controlllite = "Controlllite, Kohya"
InstantID = "InstantID, Qixun Wang"
SparseCtrl = "SparseCtrl, Yuwei Guo"
def is_controlnet(self) -> bool:
"""Returns whether the control model should be treated as ControlNet."""

View File

@ -319,6 +319,10 @@ def select_control_type(
filtered_preprocessor_list += [
x for x in preprocessor_list if "invert" in x.lower()
]
if pattern in ["sparsectrl"]:
filtered_preprocessor_list += [
x for x in preprocessor_list if "scribble" in x.lower()
]
filtered_model_list = [
model for model in all_models
if model.lower() == "none" or

View File

@ -8,6 +8,7 @@ from typing import Optional, Any
from scripts.logging import logger
from scripts.enums import ControlModelType, AutoMachine, HiResFixOption
from scripts.controlmodel_ipadapter import ImageEmbed
from scripts.controlnet_sparsectrl import SparseCtrl
from modules import devices, lowvram, shared, scripts
from ldm.modules.diffusionmodules.util import timestep_embedding, make_beta_schedule
@ -384,8 +385,12 @@ class UnetHook(nn.Module):
vae_output = vae_cache.get(x)
if vae_output is None:
with devices.autocast():
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
vae_output = torch.stack([
p.sd_model.get_first_stage_encoding(
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
)[0].to(img.device)
for img in x
])
if torch.all(torch.isnan(vae_output)).item():
logger.info('ControlNet find Nans in the VAE encoding. \n '
'Now ControlNet will automatically retry.\n '
@ -393,8 +398,12 @@ class UnetHook(nn.Module):
devices.dtype_vae = torch.float32
x = x.to(devices.dtype_vae)
p.sd_model.first_stage_model.to(devices.dtype_vae)
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
vae_output = torch.stack([
p.sd_model.get_first_stage_encoding(
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
)[0].to(img.device)
for img in x
])
vae_cache.set(x, vae_output)
logger.info(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
latent = vae_output
@ -571,7 +580,7 @@ class UnetHook(nn.Module):
controlnet_context = context
# ControlNet inpaint protocol
if hint.shape[1] == 4:
if hint.shape[1] == 4 and not isinstance(control_model, SparseCtrl):
c = hint[:, 0:3, :, :]
m = hint[:, 3:4, :, :]
m = (m > 0.5).float()

View File

@ -1316,6 +1316,7 @@ preprocessor_filters = {
"T2I-Adapter": "none",
"IP-Adapter": "ip-adapter_clip_sd15",
"Instant_ID": "instant_id",
"SparseCtrl": "none",
}
preprocessor_filters_aliases = {