Move ad batch input (#2720)

* Move ad batch input

* nit

* nit
pull/2707/head^2
Chenlei Hu 2024-03-31 03:31:12 +00:00 committed by GitHub
parent faddb8bb9c
commit 807a883ed6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 40 deletions

View File

@ -222,6 +222,10 @@ class ControlNetUnit:
"inpaint_crop_input_image",
]
@property
def is_animate_diff_batch(self) -> bool:
return getattr(self, "animatediff_batch", False)
def to_base64_nparray(encoding: str):
"""

View File

@ -0,0 +1,48 @@
from scripts.external_code import ControlNetUnit
from scripts.logging import logger
from modules.processing import StableDiffusionProcessing
from modules import shared
def add_animate_diff_batch_input(
p: StableDiffusionProcessing, unit: ControlNetUnit
) -> ControlNetUnit:
"""AnimateDiff + ControlNet batch processing."""
assert unit.is_animate_diff_batch
batch_parameters = unit.batch_images.split("\n")
batch_image_dir = batch_parameters[0]
logger.info(
f"AnimateDiff + ControlNet {unit.module} receive the following parameters:"
)
logger.info(f"\tbatch control images: {batch_image_dir}")
for ad_cn_batch_parameter in batch_parameters[1:]:
if ad_cn_batch_parameter.startswith("mask:"):
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:") :].strip()
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
elif ad_cn_batch_parameter.startswith("keyframe:"):
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:") :].strip()
unit.batch_keyframe_idx = [
int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(",")
]
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
batch_image_files = shared.listfiles(batch_image_dir)
for batch_modifier in getattr(unit, "batch_modifiers", []):
batch_image_files = batch_modifier(batch_image_files, p)
unit.batch_image_files = batch_image_files
unit.image = []
for idx, image_path in enumerate(batch_image_files):
mask_path = None
if getattr(unit, "batch_mask_dir", None) is not None:
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
if len(batch_mask_files) >= len(batch_image_files):
mask_path = batch_mask_files[idx]
else:
mask_path = batch_mask_files[0]
unit.image.append(
{
"image": image_path,
"mask": mask_path,
}
)
return unit

View File

@ -24,6 +24,7 @@ from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOpti
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
from scripts.animate_diff.batch import add_animate_diff_batch_input
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
from modules.images import save_image
from scripts.infotext import Infotext
@ -621,40 +622,6 @@ class Script(scripts.Script, metaclass=(
# 4 input image sources.
p_image_control = getattr(p, "image_control", None)
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
# AnimateDiff + ControlNet batch processing.
unit_is_ad_batch = getattr(unit, "animatediff_batch", False)
if unit_is_ad_batch:
batch_parameters = unit.batch_images.split("\n")
batch_image_dir = batch_parameters[0]
logger.info(f"AnimateDiff + ControlNet {unit.module} receive the following parameters:")
logger.info(f"\tbatch control images: {batch_image_dir}")
for ad_cn_batch_parameter in batch_parameters[1:]:
if ad_cn_batch_parameter.startswith("mask:"):
unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip()
logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
elif ad_cn_batch_parameter.startswith("keyframe:"):
unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:"):].strip()
unit.batch_keyframe_idx = [int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(',')]
logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")
batch_image_files = shared.listfiles(batch_image_dir)
for batch_modifier in getattr(unit, 'batch_modifiers', []):
batch_image_files = batch_modifier(batch_image_files, p)
unit.batch_image_files = batch_image_files
unit.image = []
for idx, image_path in enumerate(batch_image_files):
mask_path = None
if getattr(unit, "batch_mask_dir", None) is not None:
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
if len(batch_mask_files) >= len(batch_image_files):
mask_path = batch_mask_files[idx]
else:
mask_path = batch_mask_files[0]
unit.image.append({
"image": image_path,
"mask": mask_path,
})
image = parse_unit_image(unit)
a1111_image = getattr(p, "init_images", [None])[0]
@ -676,7 +643,7 @@ class Script(scripts.Script, metaclass=(
# Add mask logic if later there is a processor that accepts mask
# on multiple inputs.
input_image = [HWC3(decode_image(img['image'])) for img in image]
if unit_is_ad_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
for idx in range(len(input_image)):
while len(image[idx]['mask'].shape) < 3:
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
@ -944,11 +911,12 @@ class Script(scripts.Script, metaclass=(
bind_control_lora(unet, control_lora)
p.controlnet_control_loras.append(control_lora)
if unit.is_animate_diff_batch:
unit = add_animate_diff_batch_input(p, unit)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
is_cn_ad_batch = getattr(unit, "animatediff_batch", False)
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
if isinstance(input_image, list):
assert unit.accepts_multiple_inputs() or is_cn_ad_batch
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
input_images = input_image
else: # Following operations are only for single input image.
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
@ -1013,14 +981,14 @@ class Script(scripts.Script, metaclass=(
if control_model_type == ControlModelType.ReVision:
control = control['image_embeds']
if is_image and is_cn_ad_batch: # AnimateDiff save VRAM
if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM
control = control.cpu()
if hr_control is not None:
hr_control = hr_control.cpu()
return control, hr_control
def optional_tqdm(iterable, use_tqdm=is_cn_ad_batch):
def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
from tqdm import tqdm
return tqdm(iterable) if use_tqdm else iterable
@ -1028,7 +996,7 @@ class Script(scripts.Script, metaclass=(
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
control = controls[0]
hr_control = hr_controls[0]
elif is_cn_ad_batch or control_model_type in [ControlModelType.SparseCtrl]:
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs():
ip_adapter_image_emb_cond = []