49 lines
2.0 KiB
Python
49 lines
2.0 KiB
Python
from scripts.external_code import ControlNetUnit
|
|
from scripts.logging import logger
|
|
from modules.processing import StableDiffusionProcessing
|
|
from modules import shared
|
|
|
|
|
|
def add_animate_diff_batch_input(
    p: StableDiffusionProcessing, unit: ControlNetUnit
) -> ControlNetUnit:
    """AnimateDiff + ControlNet batch processing.

    Parse ``unit.batch_images`` and populate ``unit`` with the control images
    (and optional masks) for an AnimateDiff batch.

    ``unit.batch_images`` is a multi-line string. Its first line is the
    directory of control images. Each following line may be one of:

    * ``mask: <dir>``: the directory of mask images. It is stored on
      ``unit.batch_mask_dir``.
    * ``keyframe: <i>, <j>, ...``: comma-separated integer indices. They are
      stored as a list of ints on ``unit.batch_keyframe_idx``.

    Lines matching neither prefix are ignored.

    Args:
        p: Processing object, passed to every callable in
            ``unit.batch_modifiers`` (if the unit has any).
        unit: ControlNet unit to populate; ``unit.is_animate_diff_batch``
            must be truthy.

    Returns:
        The same ``unit``, mutated in place:

        * ``unit.batch_image_files`` holds the list of image paths after all
          batch modifiers have been applied.
        * ``unit.image`` holds one ``{"image": path, "mask": path_or_None}``
          dict per image.

    Raises:
        AssertionError: If ``unit.is_animate_diff_batch`` is falsy.
        ValueError: If a ``keyframe:`` entry is not an integer.
    """
    assert unit.is_animate_diff_batch

    # splitlines() (unlike split("\n")) also removes the "\r" left by Windows
    # line endings, which would otherwise end up in the image directory path.
    # `or [""]` keeps the old behaviour for an empty string ("".split("\n")
    # yields [""], whereas "".splitlines() yields []).
    batch_parameters = unit.batch_images.splitlines() or [""]
    # Strip the image directory just like the mask directory below.
    batch_image_dir = batch_parameters[0].strip()
    logger.info(
        f"AnimateDiff + ControlNet {unit.module} receive the following parameters:"
    )
    logger.info(f"\tbatch control images: {batch_image_dir}")

    for raw_parameter in batch_parameters[1:]:
        # Tolerate leading/trailing whitespace around each parameter line.
        parameter = raw_parameter.strip()
        if parameter.startswith("mask:"):
            unit.batch_mask_dir = parameter[len("mask:") :].strip()
            logger.info(f"\tbatch control mask: {unit.batch_mask_dir}")
        elif parameter.startswith("keyframe:"):
            keyframe_spec = parameter[len("keyframe:") :]
            # Skip empty entries so a trailing comma ("0, 8,") does not make
            # int("") raise.
            unit.batch_keyframe_idx = [
                int(entry.strip())
                for entry in keyframe_spec.split(",")
                if entry.strip()
            ]
            logger.info(f"\tbatch control keyframe index: {unit.batch_keyframe_idx}")

    batch_image_files = shared.listfiles(batch_image_dir)
    # Each modifier receives the current file list and `p`, and returns the
    # list that the next step uses.
    for batch_modifier in getattr(unit, "batch_modifiers", None) or []:
        batch_image_files = batch_modifier(batch_image_files, p)
    unit.batch_image_files = batch_image_files

    # The mask directory comes either from a `mask:` line above or from an
    # attribute already present on the unit. List it once, outside the
    # per-image loop. Only list it when there are images to pair it with,
    # as before.
    batch_mask_files = []
    batch_mask_dir = getattr(unit, "batch_mask_dir", None)
    if batch_mask_dir is not None and batch_image_files:
        batch_mask_files = shared.listfiles(batch_mask_dir)
        if not batch_mask_files:
            # Previously this case crashed with an IndexError on
            # batch_mask_files[0]; degrade to "no mask" and say so.
            logger.warning(
                f"\tbatch control mask directory {batch_mask_dir} contains no "
                f"files; masks will not be used."
            )

    # With at least as many masks as images, pair them by index. Otherwise
    # reuse the first mask for every image.
    one_mask_per_image = len(batch_mask_files) >= len(batch_image_files)

    images = []
    for idx, image_path in enumerate(batch_image_files):
        mask_path = None
        if batch_mask_files:
            mask_path = batch_mask_files[idx if one_mask_per_image else 0]
        images.append(
            {
                "image": image_path,
                "mask": mask_path,
            }
        )
    unit.image = images
    return unit
|