parent
faddb8bb9c
commit
807a883ed6
|
|
@ -222,6 +222,10 @@ class ControlNetUnit:
|
|||
"inpaint_crop_input_image",
|
||||
]
|
||||
|
||||
@property
|
||||
def is_animate_diff_batch(self) -> bool:
|
||||
return getattr(self, "animatediff_batch", False)
|
||||
|
||||
|
||||
def to_base64_nparray(encoding: str):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
from scripts.external_code import ControlNetUnit
|
||||
from scripts.logging import logger
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared
|
||||
|
||||
|
||||
def add_animate_diff_batch_input(
|
||||
p: StableDiffusionProcessing, unit: ControlNetUnit
|
||||
) -> ControlNetUnit:
|
||||
"""AnimateDiff + ControlNet batch processing."""
|
||||
assert unit.is_animate_diff_batch
|
||||
|
||||
batch_parameters = unit.batch_images.split("\n")
|
||||
batch_image_dir = batch_parameters[0]
|
||||
logger.info(
|
||||
f"AnimateDiff + ControlNet {unit.module} receive the following parameters:"
|
||||
)
|
||||
logger.info(f"\tbatch control images: {batch_image_dir}")
|
||||
for ad_cn_batch_parameter in batch_parameters[1:]:
|
||||
if ad_cn_batch_parameter.startswith("mask:"):
|
||||
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:") :].strip()
|
||||
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
|
||||
elif ad_cn_batch_parameter.startswith("keyframe:"):
|
||||
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:") :].strip()
|
||||
unit.batch_keyframe_idx = [
|
||||
int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(",")
|
||||
]
|
||||
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
|
||||
batch_image_files = shared.listfiles(batch_image_dir)
|
||||
for batch_modifier in getattr(unit, "batch_modifiers", []):
|
||||
batch_image_files = batch_modifier(batch_image_files, p)
|
||||
unit.batch_image_files = batch_image_files
|
||||
unit.image = []
|
||||
for idx, image_path in enumerate(batch_image_files):
|
||||
mask_path = None
|
||||
if getattr(unit, "batch_mask_dir", None) is not None:
|
||||
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
|
||||
if len(batch_mask_files) >= len(batch_image_files):
|
||||
mask_path = batch_mask_files[idx]
|
||||
else:
|
||||
mask_path = batch_mask_files[0]
|
||||
unit.image.append(
|
||||
{
|
||||
"image": image_path,
|
||||
"mask": mask_path,
|
||||
}
|
||||
)
|
||||
return unit
|
||||
|
|
@ -24,6 +24,7 @@ from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOpti
|
|||
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
|
||||
from scripts.controlnet_ui.photopea import Photopea
|
||||
from scripts.logging import logger
|
||||
from scripts.animate_diff.batch import add_animate_diff_batch_input
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
|
||||
from modules.images import save_image
|
||||
from scripts.infotext import Infotext
|
||||
|
|
@ -621,40 +622,6 @@ class Script(scripts.Script, metaclass=(
|
|||
# 4 input image sources.
|
||||
p_image_control = getattr(p, "image_control", None)
|
||||
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
|
||||
|
||||
# AnimateDiff + ControlNet batch processing.
|
||||
unit_is_ad_batch = getattr(unit, "animatediff_batch", False)
|
||||
if unit_is_ad_batch:
|
||||
batch_parameters = unit.batch_images.split("\n")
|
||||
batch_image_dir = batch_parameters[0]
|
||||
logger.info(f"AnimateDiff + ControlNet {unit.module} receive the following parameters:")
|
||||
logger.info(f"\tbatch control images: {batch_image_dir}")
|
||||
for ad_cn_batch_parameter in batch_parameters[1:]:
|
||||
if ad_cn_batch_parameter.startswith("mask:"):
|
||||
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip()
|
||||
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
|
||||
elif ad_cn_batch_parameter.startswith("keyframe:"):
|
||||
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:"):].strip()
|
||||
unit.batch_keyframe_idx = [int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(',')]
|
||||
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
|
||||
batch_image_files = shared.listfiles(batch_image_dir)
|
||||
for batch_modifier in getattr(unit, 'batch_modifiers', []):
|
||||
batch_image_files = batch_modifier(batch_image_files, p)
|
||||
unit.batch_image_files = batch_image_files
|
||||
unit.image = []
|
||||
for idx, image_path in enumerate(batch_image_files):
|
||||
mask_path = None
|
||||
if getattr(unit, "batch_mask_dir", None) is not None:
|
||||
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
|
||||
if len(batch_mask_files) >= len(batch_image_files):
|
||||
mask_path = batch_mask_files[idx]
|
||||
else:
|
||||
mask_path = batch_mask_files[0]
|
||||
unit.image.append({
|
||||
"image": image_path,
|
||||
"mask": mask_path,
|
||||
})
|
||||
|
||||
image = parse_unit_image(unit)
|
||||
a1111_image = getattr(p, "init_images", [None])[0]
|
||||
|
||||
|
|
@ -676,7 +643,7 @@ class Script(scripts.Script, metaclass=(
|
|||
# Add mask logic if later there is a processor that accepts mask
|
||||
# on multiple inputs.
|
||||
input_image = [HWC3(decode_image(img['image'])) for img in image]
|
||||
if unit_is_ad_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
|
||||
if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
|
||||
for idx in range(len(input_image)):
|
||||
while len(image[idx]['mask'].shape) < 3:
|
||||
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
|
||||
|
|
@ -944,11 +911,12 @@ class Script(scripts.Script, metaclass=(
|
|||
bind_control_lora(unet, control_lora)
|
||||
p.controlnet_control_loras.append(control_lora)
|
||||
|
||||
if unit.is_animate_diff_batch:
|
||||
unit = add_animate_diff_batch_input(p, unit)
|
||||
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
|
||||
is_cn_ad_batch = getattr(unit, "animatediff_batch", False)
|
||||
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
|
||||
if isinstance(input_image, list):
|
||||
assert unit.accepts_multiple_inputs() or is_cn_ad_batch
|
||||
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
|
||||
input_images = input_image
|
||||
else: # Following operations are only for single input image.
|
||||
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
|
||||
|
|
@ -1013,14 +981,14 @@ class Script(scripts.Script, metaclass=(
|
|||
if control_model_type == ControlModelType.ReVision:
|
||||
control = control['image_embeds']
|
||||
|
||||
if is_image and is_cn_ad_batch: # AnimateDiff save VRAM
|
||||
if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM
|
||||
control = control.cpu()
|
||||
if hr_control is not None:
|
||||
hr_control = hr_control.cpu()
|
||||
|
||||
return control, hr_control
|
||||
|
||||
def optional_tqdm(iterable, use_tqdm=is_cn_ad_batch):
|
||||
def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
|
||||
from tqdm import tqdm
|
||||
return tqdm(iterable) if use_tqdm else iterable
|
||||
|
||||
|
|
@ -1028,7 +996,7 @@ class Script(scripts.Script, metaclass=(
|
|||
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
|
||||
control = controls[0]
|
||||
hr_control = hr_controls[0]
|
||||
elif is_cn_ad_batch or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
|
||||
if unit.accepts_multiple_inputs():
|
||||
ip_adapter_image_emb_cond = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue