def add_animate_diff_batch_input(
    p: "StableDiffusionProcessing", unit: "ControlNetUnit"
) -> "ControlNetUnit":
    """Expand an AnimateDiff batch unit into per-frame image/mask entries.

    ``unit.batch_images`` is read as newline-separated text. The first line
    is the directory holding the control images. Each following line may be
    one of:

    * ``mask:<directory>`` -- stored on ``unit.batch_mask_dir``.
    * ``keyframe:<comma-separated integers>`` -- stored as a list of ints on
      ``unit.batch_keyframe_idx``.

    Lines matching neither prefix are ignored.

    The image directory is listed with ``shared.listfiles`` and the result is
    passed through every callable in ``unit.batch_modifiers`` (if that
    attribute exists), each called as ``modifier(files, p)``. The final list
    is stored on ``unit.batch_image_files``, and ``unit.image`` is rebuilt as
    a list of ``{"image": path, "mask": path_or_None}`` dicts.

    Mask pairing: when the mask directory holds at least as many files as
    there are images, image ``i`` gets mask ``i``; when it holds fewer, every
    image gets the first mask; when it holds none, or no mask directory is
    set, every mask is ``None``.

    Args:
        p: Processing object; only forwarded to the batch modifiers.
        unit: The unit to update in place. ``unit.is_animate_diff_batch``
            must be true.

    Returns:
        The same ``unit`` object, after mutation.

    Raises:
        AssertionError: If ``unit.is_animate_diff_batch`` is false.
        ValueError: If a keyframe entry is not an integer.
    """
    assert unit.is_animate_diff_batch

    # Strip every line so Windows "\r\n" endings do not leak a "\r" into the
    # directory path and stray indentation cannot hide a parameter prefix.
    batch_parameters = [line.strip() for line in unit.batch_images.split("\n")]
    batch_image_dir = batch_parameters[0]
    logger.info(
        f"AnimateDiff + ControlNet {unit.module} receive the following parameters:"
    )
    logger.info(f"\tbatch control images: {batch_image_dir}")
    for ad_cn_batch_parameter in batch_parameters[1:]:
        if ad_cn_batch_parameter.startswith("mask:"):
            unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip()
            logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
        elif ad_cn_batch_parameter.startswith("keyframe:"):
            raw_keyframes = ad_cn_batch_parameter[len("keyframe:"):]
            # Skip blank entries so a trailing comma ("0, 8,") is tolerated;
            # anything else that is not an integer still raises ValueError.
            unit.batch_keyframe_idx = [
                int(token.strip())
                for token in raw_keyframes.split(",")
                if token.strip()
            ]
            logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")

    batch_image_files = shared.listfiles(batch_image_dir)
    for batch_modifier in getattr(unit, "batch_modifiers", []):
        batch_image_files = batch_modifier(batch_image_files, p)
    unit.batch_image_files = batch_image_files

    # List the mask directory once instead of once per frame. It is only
    # listed when there are images to pair with, so the directory is never
    # touched for an empty batch.
    batch_mask_files = []
    batch_mask_dir = getattr(unit, "batch_mask_dir", None)
    if batch_mask_dir and len(batch_image_files) > 0:
        batch_mask_files = shared.listfiles(batch_mask_dir)
        if not batch_mask_files:
            logger.warning(
                f"\tbatch control mask directory {batch_mask_dir} contains no files; "
                "masks are ignored."
            )
    one_mask_per_image = len(batch_mask_files) >= len(batch_image_files)

    unit.image = []
    for idx, image_path in enumerate(batch_image_files):
        if not batch_mask_files:
            mask_path = None
        elif one_mask_per_image:
            mask_path = batch_mask_files[idx]
        else:
            # Fewer masks than images: reuse the first mask for every frame.
            mask_path = batch_mask_files[0]
        unit.image.append(
            {
                "image": image_path,
                "mask": mask_path,
            }
        )
    return unit
- unit_is_ad_batch = getattr(unit, "animatediff_batch", False) - if unit_is_ad_batch: - batch_parameters = unit.batch_images.split("\n") - batch_image_dir = batch_parameters[0] - logger.info(f"AnimateDiff + ControlNet {unit.module} receive the following parameters:") - logger.info(f"\tbatch control images: {batch_image_dir}") - for ad_cn_batch_parameter in batch_parameters[1:]: - if ad_cn_batch_parameter.startswith("mask:"): - unit.batch_mask_dir = ad_cn_batch_parameter[len("mask:"):].strip() - logger.info(f"\tbatch control mask: {unit.batch_mask_dir}") - elif ad_cn_batch_parameter.startswith("keyframe:"): - unit.batch_keyframe_idx = ad_cn_batch_parameter[len("keyframe:"):].strip() - unit.batch_keyframe_idx = [int(b_i.strip()) for b_i in unit.batch_keyframe_idx.split(',')] - logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}") - batch_image_files = shared.listfiles(batch_image_dir) - for batch_modifier in getattr(unit, 'batch_modifiers', []): - batch_image_files = batch_modifier(batch_image_files, p) - unit.batch_image_files = batch_image_files - unit.image = [] - for idx, image_path in enumerate(batch_image_files): - mask_path = None - if getattr(unit, "batch_mask_dir", None) is not None: - batch_mask_files = shared.listfiles(unit.batch_mask_dir) - if len(batch_mask_files) >= len(batch_image_files): - mask_path = batch_mask_files[idx] - else: - mask_path = batch_mask_files[0] - unit.image.append({ - "image": image_path, - "mask": mask_path, - }) - image = parse_unit_image(unit) a1111_image = getattr(p, "init_images", [None])[0] @@ -676,7 +643,7 @@ class Script(scripts.Script, metaclass=( # Add mask logic if later there is a processor that accepts mask # on multiple inputs. 
input_image = [HWC3(decode_image(img['image'])) for img in image] - if unit_is_ad_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None: + if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None: for idx in range(len(input_image)): while len(image[idx]['mask'].shape) < 3: image[idx]['mask'] = image[idx]['mask'][..., np.newaxis] @@ -944,11 +911,12 @@ class Script(scripts.Script, metaclass=( bind_control_lora(unet, control_lora) p.controlnet_control_loras.append(control_lora) + if unit.is_animate_diff_batch: + unit = add_animate_diff_batch_input(p, unit) input_image, resize_mode = Script.choose_input_image(p, unit, idx) - is_cn_ad_batch = getattr(unit, "animatediff_batch", False) cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None) if isinstance(input_image, list): - assert unit.accepts_multiple_inputs() or is_cn_ad_batch + assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch input_images = input_image else: # Following operations are only for single input image. 
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) @@ -1013,14 +981,14 @@ class Script(scripts.Script, metaclass=( if control_model_type == ControlModelType.ReVision: control = control['image_embeds'] - if is_image and is_cn_ad_batch: # AnimateDiff save VRAM + if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM control = control.cpu() if hr_control is not None: hr_control = hr_control.cpu() return control, hr_control - def optional_tqdm(iterable, use_tqdm=is_cn_ad_batch): + def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch): from tqdm import tqdm return tqdm(iterable) if use_tqdm else iterable @@ -1028,7 +996,7 @@ class Script(scripts.Script, metaclass=( if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]: control = controls[0] hr_control = hr_controls[0] - elif is_cn_ad_batch or control_model_type in [ControlModelType.SparseCtrl]: + elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]: def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx): if unit.accepts_multiple_inputs(): ip_adapter_image_emb_cond = []