# AnimateDiff UI for Stable Diffusion WebUI (sd-webui-animatediff).
from typing import List
|
|
|
|
import os
|
|
import cv2
|
|
import subprocess
|
|
import gradio as gr
|
|
|
|
from modules import shared
|
|
from modules.launch_utils import git
|
|
from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img
|
|
|
|
from scripts.animatediff_mm import mm_animatediff as motion_module
|
|
from scripts.animatediff_xyz import xyz_attrs
|
|
from scripts.animatediff_logger import logger_animatediff as logger
|
|
from scripts.animatediff_utils import get_controlnet_units, extract_frames_from_video
|
|
|
|
# Output containers offered in the "Save format" checkbox group. "TXT" is
# kept last on purpose: _check() validates against supported_save_formats[:-1],
# so selecting only "TXT" (infotext dump) does not count as a real output.
supported_save_formats: List[str] = ["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"]
|
|
|
|
class ToolButton(gr.Button, gr.components.FormComponent):
    """Single-emoji button compact enough to sit inline inside gradio forms."""

    def __init__(self, **kwargs):
        # The "tool" variant renders the small square icon-button style
        # used throughout the A1111 UI.
        super().__init__(variant="tool", **kwargs)

    def get_block_name(self):
        # Reuse the stock "button" frontend block for rendering/styling.
        return "button"
|
|
|
|
|
|
class AnimateDiffProcess:
    """Parameter container for a single AnimateDiff generation.

    The first 20 (txt2img) / 25 (img2img) attributes assigned in __init__
    correspond one-to-one, in order, with the Gradio components built by
    AnimateDiffUiGroup.render(); get_list() slices vars(self) positionally,
    so the assignment order below must not be changed.
    """

    def __init__(
        self,
        model="mm_sd15_v3.safetensors",  # motion module filename in the model dir
        enable=False,                    # master switch for the extension
        video_length=0,                  # frames to generate; 0 -> derive from p.batch_size
        fps=8,                           # output frames per second
        loop_number=0,                   # GIF display loop count; 0 loops forever
        closed_loop='R-P',               # closed-loop mode: N / R-P / R+P / A
        batch_size=16,                   # temporal context window size
        stride=1,                        # context stride
        overlap=-1,                      # context overlap; -1 -> auto (batch_size // 4)
        format=None,                     # save formats; None -> read the current user setting
        interp='Off',                    # optional FILM frame interpolation
        interp_x=10,                     # interpolation frame multiplier
        video_source=None,               # uploaded source video (gr.Video value)
        video_path='',                   # directory of input/extracted frames
        mask_path='',                    # directory of per-frame masks for ControlNet
        freeinit_enable=False,           # FreeInit smoothing toggle
        freeinit_filter="butterworth",   # FreeInit frequency filter type
        freeinit_ds=0.25,                # FreeInit spatial stop frequency
        freeinit_dt=0.25,                # FreeInit temporal stop frequency
        freeinit_iters=3,                # FreeInit iteration count
        latent_power=1,                  # i2v: latent power curve
        latent_scale=32,                 # i2v: latent scale
        last_frame=None,                 # i2v: optional last-frame image
        latent_power_last=1,             # i2v: latent power for the last frame
        latent_scale_last=32,            # i2v: latent scale for the last frame
        request_id='',                   # external request identifier (API)
        is_i2i_batch=False,              # set when running img2img batch mode
        video_default=False,             # True when video_length was auto-derived
        prompt_scheduler=None,           # prompt travel scheduler, assigned elsewhere
    ):
        self.model = model
        self.enable = enable
        self.video_length = video_length
        self.fps = fps
        self.loop_number = loop_number
        self.closed_loop = closed_loop
        self.batch_size = batch_size
        self.stride = stride
        self.overlap = overlap
        # Resolve the default lazily: a def-time default would be evaluated
        # once at import, freezing the user setting and sharing one mutable
        # list across all instances.
        self.format = format if format is not None else shared.opts.data.get("animatediff_default_save_formats", ["GIF", "PNG"])
        self.interp = interp
        self.interp_x = interp_x
        self.video_source = video_source
        self.video_path = video_path
        self.mask_path = mask_path
        self.freeinit_enable = freeinit_enable
        self.freeinit_filter = freeinit_filter
        self.freeinit_ds = freeinit_ds
        self.freeinit_dt = freeinit_dt
        self.freeinit_iters = freeinit_iters
        self.latent_power = latent_power
        self.latent_scale = latent_scale
        self.last_frame = last_frame
        self.latent_power_last = latent_power_last
        self.latent_scale_last = latent_scale_last

        # non-UI states (never surfaced as Gradio components; must stay
        # after the 25 UI attributes above for get_list() to work)
        self.request_id = request_id
        self.video_default = video_default
        self.is_i2i_batch = is_i2i_batch
        self.prompt_scheduler = prompt_scheduler

    def get_list(self, is_img2img: bool):
        """Return the ordered UI attribute values: 25 for img2img (which has
        the extra I2V latent controls), 20 for txt2img. Relies on dict
        insertion order of vars(self)."""
        return list(vars(self).values())[:(25 if is_img2img else 20)]

    def get_dict(self, is_img2img: bool):
        """Build the AnimateDiff infotext. Despite the name, this returns a
        formatted ``k: v`` string, not a dict."""
        infotext = {
            "model": self.model,
            "video_length": self.video_length,
            "fps": self.fps,
            "loop_number": self.loop_number,
            "closed_loop": self.closed_loop,
            "batch_size": self.batch_size,
            "stride": self.stride,
            "overlap": self.overlap,
            "interp": self.interp,
            "interp_x": self.interp_x,
            "freeinit_enable": self.freeinit_enable,
        }
        if self.request_id:
            infotext['request_id'] = self.request_id
        if motion_module.mm is not None and motion_module.mm.mm_hash is not None:
            infotext['mm_hash'] = motion_module.mm.mm_hash[:8]
        if is_img2img:
            infotext.update({
                "latent_power": self.latent_power,
                "latent_scale": self.latent_scale,
                "latent_power_last": self.latent_power_last,
                "latent_scale_last": self.latent_scale_last,
            })

        # Best effort: record the extension version via its git tag; a
        # failure (e.g. not a git checkout) only logs a warning.
        try:
            ad_git_tag = subprocess.check_output(
                [git, "-C", motion_module.get_model_dir(), "describe", "--tags"],
                shell=False, encoding='utf8').strip()
            infotext['version'] = ad_git_tag
        except Exception as e:
            logger.warning(f"Failed to get git tag for AnimateDiff: {e}")

        infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items())
        return infotext_str

    def get_param_names(self, is_img2img: bool):
        """Return the field names that participate in infotext copy-paste."""
        preserve = ["model", "enable", "video_length", "fps", "loop_number", "closed_loop", "batch_size", "stride", "overlap", "format", "interp", "interp_x"]
        if is_img2img:
            preserve.extend(["latent_power", "latent_power_last", "latent_scale", "latent_scale_last"])

        return preserve

    def _check(self):
        """Validate user input before generation; raises AssertionError."""
        # video_length == 0 is allowed: it means "derive from batch size".
        assert (
            self.video_length >= 0 and self.fps > 0
        ), "Video length and FPS should be positive."
        # "TXT" (the last supported format) alone does not count as a real
        # output format, hence the [:-1] slice.
        assert not set(supported_save_formats[:-1]).isdisjoint(
            self.format
        ), "At least one saving format should be selected."

    def apply_xyz(self):
        """Override fields with values injected by the XYZ-grid script."""
        for k, v in xyz_attrs.items():
            setattr(self, k, v)

    def set_p(self, p: StableDiffusionProcessing):
        """Validate parameters and configure `p` (and any enabled ControlNet
        units) for animation: batch size, frame count, CN frame sources."""
        self._check()
        # The whole animation is produced as one A1111 "batch".
        p.batch_size = max(self.video_length, self.batch_size)
        if self.video_length == 0:
            # Auto mode: follow the batch size; remember this for later
            # correction against ControlNet frame counts.
            self.video_length = p.batch_size
            self.video_default = True
        if self.overlap == -1:
            self.overlap = self.batch_size // 4
        if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", True):
            p.do_not_save_samples = True

        cn_units = get_controlnet_units(p)
        min_batch_in_cn = -1
        for cn_unit in cn_units:
            if not cn_unit.enabled:
                continue

            # batch path broadcast: units without their own image source fall
            # back to frames extracted from our video source / video path
            if (cn_unit.input_mode.name == 'SIMPLE' and cn_unit.image is None) or \
               (cn_unit.input_mode.name == 'BATCH' and not cn_unit.batch_images) or \
               (cn_unit.input_mode.name == 'MERGE' and not cn_unit.batch_input_gallery):
                if not self.video_path:
                    extract_frames_from_video(self)
                cn_unit.input_mode = cn_unit.input_mode.__class__.BATCH
                cn_unit.batch_images = self.video_path

            # mask path broadcast
            if cn_unit.input_mode.name == 'BATCH' and self.mask_path and not getattr(cn_unit, 'batch_mask_dir', False):
                cn_unit.batch_mask_dir = self.mask_path

            # find the minimum number of control images across CN batch units
            # NOTE(review): assumes batch_images is a str for every enabled
            # unit at this point — verify against sd-webui-controlnet defaults.
            cn_unit_batch_params = cn_unit.batch_images.split('\n')
            if cn_unit.input_mode.name == 'BATCH':
                cn_unit.animatediff_batch = True  # for A1111 sd-webui-controlnet
                if not any(cn_param.startswith("keyframe:") for cn_param in cn_unit_batch_params[1:]):
                    cn_unit_batch_num = len(shared.listfiles(cn_unit_batch_params[0]))
                    if min_batch_in_cn == -1 or cn_unit_batch_num < min_batch_in_cn:
                        min_batch_in_cn = cn_unit_batch_num

        if min_batch_in_cn != -1:
            # Clamp the animation to the shortest ControlNet image batch.
            self.fix_video_length(p, min_batch_in_cn)
            def cn_batch_modifier(batch_image_files: List[str], p: StableDiffusionProcessing):
                # Trim each CN unit's frame list to the final video length.
                return batch_image_files[:self.video_length]
            for cn_unit in cn_units:
                if cn_unit.input_mode.name == 'BATCH':
                    cur_batch_modifier = getattr(cn_unit, "batch_modifiers", [])
                    cur_batch_modifier.append(cn_batch_modifier)
                    cn_unit.batch_modifiers = cur_batch_modifier
            self.post_setup_cn_for_i2i_batch(p)
            logger.info(f"AnimateDiff + ControlNet will generate {self.video_length} frames.")

    def fix_video_length(self, p: StableDiffusionProcessing, min_batch_in_cn: int):
        """Clamp video_length / batch sizes to the frame count of the
        smallest ControlNet image batch."""
        # ensure that params.video_length <= video_length and params.batch_size <= video_length
        if self.video_length > min_batch_in_cn:
            self.video_length = min_batch_in_cn
            p.batch_size = min_batch_in_cn
        if self.batch_size > min_batch_in_cn:
            self.batch_size = min_batch_in_cn
        if self.video_default:
            # Auto-derived length always follows the CN batch exactly.
            self.video_length = min_batch_in_cn
            p.batch_size = min_batch_in_cn

    def post_setup_cn_for_i2i_batch(self, p: StableDiffusionProcessing):
        """For img2img batch mode, reconcile the frame count with the number
        of init images (and masks) actually supplied."""
        if not (self.is_i2i_batch and isinstance(p, StableDiffusionProcessingImg2Img)):
            return

        if len(p.init_images) > self.video_length:
            p.init_images = p.init_images[:self.video_length]
        if p.image_mask and isinstance(p.image_mask, list) and len(p.image_mask) > self.video_length:
            p.image_mask = p.image_mask[:self.video_length]
        if len(p.init_images) < self.video_length:
            # Fewer init images than frames: shrink the animation to match.
            self.video_length = len(p.init_images)
            p.batch_size = len(p.init_images)
        if len(p.init_images) < self.batch_size:
            self.batch_size = len(p.init_images)
|
|
|
|
|
|
class AnimateDiffUiGroup:
    """Builds the AnimateDiff accordion in the txt2img/img2img tabs and wires
    its components to an AnimateDiffProcess unit submitted with generation."""

    # Generate buttons captured from the main UI via on_after_component().
    txt2img_submit_button = None
    img2img_submit_button = None
    # NOTE(review): declared but never assigned in this file — verify usage.
    setting_sd_model_checkpoint = None
    # All instantiated UI groups (txt2img + img2img), used for model refresh.
    animatediff_ui_group = []

    def __init__(self):
        self.params = AnimateDiffProcess()
        AnimateDiffUiGroup.animatediff_ui_group.append(self)

        # Free-init filter choices offered in the FreeInit dropdown.
        self.filter_type_list = [
            "butterworth",
            "gaussian",
            "box",
            "ideal"
        ]

    def get_model_list(self):
        """List motion-module files, hiding those whose filename tags mark
        them incompatible with the currently loaded checkpoint family."""
        model_dir = motion_module.get_model_dir()
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir, exist_ok=True)
        def get_sd_rm_tag():
            # Filename tags to hide for the current base model family.
            if shared.sd_model.is_sdxl:
                return ["sd1"]
            elif shared.sd_model.is_sd2:
                return ["sd1", "xl"]
            elif shared.sd_model.is_sd1:
                return ["xl"]
            else:
                return []
        return [
            f for f in sorted(os.listdir(model_dir))
            if f != ".gitkeep" and not any(tag in f for tag in get_sd_rm_tag())
        ]

    def refresh_models(self, *inputs):
        """Re-scan the model directory, keeping the current selection when it
        still exists, otherwise falling back to the first entry (or None)."""
        new_model_list = self.get_model_list()
        dd = inputs[0]
        if dd in new_model_list:
            selected = dd
        elif len(new_model_list) > 0:
            selected = new_model_list[0]
        else:
            selected = None
        return gr.Dropdown.update(choices=new_model_list, value=selected)

    def render(self, is_img2img: bool, infotext_fields, paste_field_names):
        """Create all AnimateDiff Gradio components (replacing the plain
        values on self.params with components), register infotext copy-paste
        fields, and return the per-tab unit State."""
        elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-"
        with gr.Accordion("AnimateDiff", open=False):
            gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/docs/how-to-use.md#parameters) to read the documentation of each parameter.")
            with gr.Row():
                with gr.Row():
                    model_list = self.get_model_list()
                    self.params.model = gr.Dropdown(
                        choices=model_list,
                        value=(self.params.model if self.params.model in model_list else (model_list[0] if len(model_list) > 0 else None)),
                        label="Motion module",
                        type="value",
                        elem_id=f"{elemid_prefix}motion-module",
                    )
                    refresh_model = ToolButton(value="\U0001f504")
                    refresh_model.click(self.refresh_models, self.params.model, self.params.model)

                self.params.format = gr.CheckboxGroup(
                    choices=supported_save_formats,
                    label="Save format",
                    type="value",
                    elem_id=f"{elemid_prefix}save-format",
                    value=self.params.format,
                )
            with gr.Row():
                self.params.enable = gr.Checkbox(
                    value=self.params.enable, label="Enable AnimateDiff",
                    elem_id=f"{elemid_prefix}enable"
                )
                self.params.video_length = gr.Number(
                    minimum=0,
                    value=self.params.video_length,
                    label="Number of frames",
                    precision=0,
                    elem_id=f"{elemid_prefix}video-length",
                )
                self.params.fps = gr.Number(
                    value=self.params.fps, label="FPS", precision=0,
                    elem_id=f"{elemid_prefix}fps"
                )
                self.params.loop_number = gr.Number(
                    minimum=0,
                    value=self.params.loop_number,
                    label="Display loop number",
                    precision=0,
                    elem_id=f"{elemid_prefix}loop-number",
                )
            with gr.Row():
                self.params.closed_loop = gr.Radio(
                    choices=["N", "R-P", "R+P", "A"],
                    value=self.params.closed_loop,
                    label="Closed loop",
                    elem_id=f"{elemid_prefix}closed-loop",
                )
                self.params.batch_size = gr.Slider(
                    minimum=1,
                    maximum=32,
                    value=self.params.batch_size,
                    label="Context batch size",
                    step=1,
                    precision=0,
                    elem_id=f"{elemid_prefix}batch-size",
                )
                self.params.stride = gr.Number(
                    minimum=1,
                    value=self.params.stride,
                    label="Stride",
                    precision=0,
                    elem_id=f"{elemid_prefix}stride",
                )
                self.params.overlap = gr.Number(
                    minimum=-1,
                    value=self.params.overlap,
                    label="Overlap",
                    precision=0,
                    elem_id=f"{elemid_prefix}overlap",
                )
            with gr.Row():
                self.params.interp = gr.Radio(
                    choices=["Off", "FILM"],
                    label="Frame Interpolation",
                    elem_id=f"{elemid_prefix}interp-choice",
                    value=self.params.interp
                )
                self.params.interp_x = gr.Number(
                    value=self.params.interp_x, label="Interp X", precision=0,
                    elem_id=f"{elemid_prefix}interp-x"
                )
            with gr.Accordion("FreeInit Params", open=False):
                gr.Markdown(
                    """
                    Adjust to control the smoothness.
                    """
                )
                self.params.freeinit_enable = gr.Checkbox(
                    value=self.params.freeinit_enable,
                    label="Enable FreeInit",
                    elem_id=f"{elemid_prefix}freeinit-enable"
                )
                self.params.freeinit_filter = gr.Dropdown(
                    value=self.params.freeinit_filter,
                    label="Filter Type",
                    info="Default as Butterworth. To fix large inconsistencies, consider using Gaussian.",
                    choices=self.filter_type_list,
                    interactive=True,
                    elem_id=f"{elemid_prefix}freeinit-filter"
                )
                self.params.freeinit_ds = gr.Slider(
                    value=self.params.freeinit_ds,
                    minimum=0,
                    maximum=1,
                    step=0.125,
                    label="d_s",
                    info="Stop frequency for spatial dimensions (0.0-1.0)",
                    elem_id=f"{elemid_prefix}freeinit-ds"
                )
                self.params.freeinit_dt = gr.Slider(
                    value=self.params.freeinit_dt,
                    minimum=0,
                    maximum=1,
                    step=0.125,
                    label="d_t",
                    info="Stop frequency for temporal dimension (0.0-1.0)",
                    elem_id=f"{elemid_prefix}freeinit-dt"
                )
                self.params.freeinit_iters = gr.Slider(
                    value=self.params.freeinit_iters,
                    minimum=2,
                    maximum=5,
                    step=1,
                    label="FreeInit Iterations",
                    info="Larger value leads to smoother results & longer inference time.",
                    # Fixed: previously duplicated the d_t slider's elem_id
                    # ("freeinit-dt"), creating two components with the same
                    # DOM id and colliding ui-config entries.
                    elem_id=f"{elemid_prefix}freeinit-iters",
                )
            self.params.video_source = gr.Video(
                value=self.params.video_source,
                label="Video source",
            )
            def update_fps(video_source):
                # Mirror the source video's FPS into the FPS field; keep the
                # current value when the video is cleared.
                if video_source is not None and video_source != '':
                    cap = cv2.VideoCapture(video_source)
                    fps = int(cap.get(cv2.CAP_PROP_FPS))
                    cap.release()
                    return fps
                else:
                    return int(self.params.fps.value)
            self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps)
            def update_frames(video_source):
                # Mirror the source video's frame count into "Number of frames".
                if video_source is not None and video_source != '':
                    cap = cv2.VideoCapture(video_source)
                    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.release()
                    return frames
                else:
                    return int(self.params.video_length.value)
            self.params.video_source.change(update_frames, inputs=self.params.video_source, outputs=self.params.video_length)
            with gr.Row():
                self.params.video_path = gr.Textbox(
                    value=self.params.video_path,
                    label="Video path",
                    elem_id=f"{elemid_prefix}video-path"
                )
                self.params.mask_path = gr.Textbox(
                    value=self.params.mask_path,
                    label="Mask path",
                    visible=False,
                    elem_id=f"{elemid_prefix}mask-path"
                )
            if is_img2img:
                with gr.Accordion("I2V Traditional", open=False):
                    with gr.Row():
                        self.params.latent_power = gr.Slider(
                            minimum=0.1,
                            maximum=10,
                            value=self.params.latent_power,
                            step=0.1,
                            label="Latent power",
                            elem_id=f"{elemid_prefix}latent-power",
                        )
                        self.params.latent_scale = gr.Slider(
                            minimum=1,
                            maximum=128,
                            value=self.params.latent_scale,
                            label="Latent scale",
                            elem_id=f"{elemid_prefix}latent-scale"
                        )
                        self.params.latent_power_last = gr.Slider(
                            minimum=0.1,
                            maximum=10,
                            value=self.params.latent_power_last,
                            step=0.1,
                            label="Optional latent power for last frame",
                            elem_id=f"{elemid_prefix}latent-power-last",
                        )
                        self.params.latent_scale_last = gr.Slider(
                            minimum=1,
                            maximum=128,
                            value=self.params.latent_scale_last,
                            label="Optional latent scale for last frame",
                            elem_id=f"{elemid_prefix}latent-scale-last"
                        )
                    self.params.last_frame = gr.Image(
                        label="Optional last frame. Leave it blank if you do not need one.",
                        type="pil",
                    )
            with gr.Row():
                unload = gr.Button(value="Move motion module to CPU (default if lowvram)")
                remove = gr.Button(value="Remove motion module from any memory")
                unload.click(fn=motion_module.unload)
                remove.click(fn=motion_module.remove)

        # Set up controls to be copy-pasted using infotext
        fields = self.params.get_param_names(is_img2img)
        infotext_fields.extend((getattr(self.params, field), f"AnimateDiff {field}") for field in fields)
        paste_field_names.extend(f"AnimateDiff {field}" for field in fields)

        return self.register_unit(is_img2img)

    def register_unit(self, is_img2img: bool):
        """Create the unit State and bind this tab's Generate click to pack
        the current component values into a fresh AnimateDiffProcess."""
        # The class itself is only a placeholder value; the click handler
        # below replaces it with an instance built from the UI components.
        unit = gr.State(value=AnimateDiffProcess)
        (
            AnimateDiffUiGroup.img2img_submit_button
            if is_img2img
            else AnimateDiffUiGroup.txt2img_submit_button
        ).click(
            fn=AnimateDiffProcess,
            inputs=self.params.get_list(is_img2img),
            outputs=unit,
            queue=False,
        )
        return unit

    @staticmethod
    def on_after_component(component, **_kwargs):
        """Capture main-UI components as the WebUI builds them: the two
        Generate buttons, and the checkpoint selector for model refresh."""
        elem_id = getattr(component, "elem_id", None)

        if elem_id == "txt2img_generate":
            AnimateDiffUiGroup.txt2img_submit_button = component
            return

        if elem_id == "img2img_generate":
            AnimateDiffUiGroup.img2img_submit_button = component
            return

        if elem_id == "setting_sd_model_checkpoint":
            for group in AnimateDiffUiGroup.animatediff_ui_group:
                # NOTE(review): the original author reported this change
                # handler never takes effect; kept as-is pending investigation.
                component.change(
                    fn=group.refresh_models,
                    inputs=[group.params.model],
                    outputs=[group.params.model],
                    queue=False,
                )
            return
|
|
|