sd-webui-animatediff/scripts/animatediff_ui.py

316 lines
12 KiB
Python

import os
import cv2
import gradio as gr
from scripts.animatediff_mm import mm_animatediff as motion_module
class ToolButton(gr.Button, gr.components.FormComponent):
    """Compact button intended to carry a single emoji label inside gradio forms.

    Behaves like ``gr.Button`` except that the ``"tool"`` variant is always
    applied and the block name reported is ``"button"``.
    """

    def __init__(self, **options):
        # variant is fixed to "tool"; supplying variant= as well raises TypeError
        # (duplicate keyword argument).
        super().__init__(variant="tool", **options)

    def get_block_name(self):
        # Reuse the regular "button" block name.
        return "button"
class AnimateDiffProcess:
    """Settings for one AnimateDiff generation request.

    The positional parameter order of ``__init__`` matters: the UI registers this
    class itself as the click handler and passes the values returned by
    ``get_list`` positionally.
    """

    # Attribute names used by both txt2img and img2img, in constructor order.
    _COMMON_FIELDS = (
        "model",
        "enable",
        "video_length",
        "fps",
        "loop_number",
        "closed_loop",
        "batch_size",
        "stride",
        "overlap",
        "format",
        "interp",
        "interp_x",
        "reverse",
        "video_source",
        "video_path",
    )
    # Attribute names only exposed by the img2img UI, in constructor order.
    _IMG2IMG_FIELDS = (
        "latent_power",
        "latent_scale",
        "last_frame",
        "latent_power_last",
        "latent_scale_last",
    )

    def __init__(
        self,
        model="mm_sd_v15_v2.ckpt",
        enable=False,
        video_length=0,
        fps=8,
        loop_number=0,
        closed_loop=False,
        batch_size=16,
        stride=1,
        overlap=-1,
        format=None,
        interp='Off',
        interp_x=10,
        reverse=None,
        video_source=None,
        video_path='',
        latent_power=1,
        latent_scale=32,
        last_frame=None,
        latent_power_last=1,
        latent_scale_last=32,
    ):
        """Store the settings.

        Args:
            model: motion-module file name.
            video_length: total frame count; 0 means "use batch_size" (see set_p).
            fps: playback frame rate; must be > 0.
            batch_size: context batch size.
            overlap: context overlap; -1 means batch_size // 4 (see set_p).
            format: save formats; ``None`` means ``["GIF", "PNG"]``.
            reverse: reverse options (the UI sends choice indices); ``None`` means ``[]``.
            latent_power, latent_scale, last_frame, latent_power_last,
            latent_scale_last: img2img-only options (semantics defined by the
                consumer of these settings — not visible here).
        """
        self.model = model
        self.enable = enable
        self.video_length = video_length
        self.fps = fps
        self.loop_number = loop_number
        self.closed_loop = closed_loop
        self.batch_size = batch_size
        self.stride = stride
        self.overlap = overlap
        # Build a fresh list per instance; a list default would be shared by all instances.
        self.format = ["GIF", "PNG"] if format is None else format
        self.interp = interp
        self.interp_x = interp_x
        self.reverse = [] if reverse is None else reverse
        self.video_source = video_source
        self.video_path = video_path
        self.latent_power = latent_power
        self.latent_scale = latent_scale
        self.last_frame = last_frame
        self.latent_power_last = latent_power_last
        self.latent_scale_last = latent_scale_last

    def get_list(self, is_img2img: bool):
        """Return the setting values in constructor order.

        The img2img-only fields are appended only when ``is_img2img`` is true.
        Field names are listed explicitly, so extra attributes (such as
        ``video_default``, added by ``set_p``) never leak into the result.
        """
        names = self._COMMON_FIELDS
        if is_img2img:
            names = names + self._IMG2IMG_FIELDS
        return [getattr(self, name) for name in names]

    def _check(self):
        """Validate the settings.

        Raises:
            AssertionError: if video_length < 0, fps <= 0, or none of
                GIF/MP4/PNG is selected. The checks are explicit raises rather
                than ``assert`` so they still run under ``python -O``.
        """
        if not (self.video_length >= 0 and self.fps > 0):
            raise AssertionError("Video length and FPS should be positive.")
        if set(["GIF", "MP4", "PNG"]).isdisjoint(self.format):
            raise AssertionError("At least one saving format should be selected.")

    def set_p(self, p):
        """Validate the settings and apply them to the processing object ``p``.

        Sets ``p.batch_size`` to the larger of video_length and batch_size.
        Resolves video_length == 0 to that value and records it in
        ``self.video_default``. Resolves overlap == -1 to batch_size // 4.
        Sets ``p.do_not_save_samples`` when PNG output is not requested.
        """
        self._check()
        p.batch_size = max(self.video_length, self.batch_size)
        if self.video_length == 0:
            self.video_length = p.batch_size
            self.video_default = True
        else:
            self.video_default = False
        if self.overlap == -1:
            self.overlap = self.batch_size // 4
        if "PNG" not in self.format:
            p.do_not_save_samples = True
class AnimateDiffUiGroup:
    """Build the AnimateDiff accordion and hook it into webui generation.

    ``on_after_component`` captures the txt2img / img2img Generate buttons as
    components are created. ``render`` builds the controls and calls
    ``register_unit``, which makes each Generate click pack the current control
    values into an ``AnimateDiffProcess`` stored in a ``gr.State``.
    """

    # Generate buttons captured by on_after_component; None until they are seen.
    txt2img_submit_button = None
    img2img_submit_button = None

    def __init__(self):
        # Default settings; render() rebinds each field to the component that edits it.
        self.params = AnimateDiffProcess()

    @staticmethod
    def _list_models(model_dir: str):
        """Return the sorted entries of ``model_dir``, skipping the ``.gitkeep`` placeholder."""
        # os.listdir order is arbitrary. Sorting keeps the dropdown order, and the
        # refresh fallback choice, deterministic.
        return sorted(f for f in os.listdir(model_dir) if f != ".gitkeep")

    def render(self, is_img2img: bool, model_dir: str):
        """Create the AnimateDiff controls and wire them to the Generate button.

        Args:
            is_img2img: build the img2img variant, which adds the latent
                power/scale sliders and the optional last-frame image.
            model_dir: directory scanned for motion modules; created if missing.

        Returns:
            The ``gr.State`` produced by ``register_unit``.
        """
        # makedirs also creates missing parent directories.
        os.makedirs(model_dir, exist_ok=True)
        elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-"
        model_list = self._list_models(model_dir)
        with gr.Accordion("AnimateDiff", open=False):
            with gr.Row():

                def refresh_models(*inputs):
                    # Re-scan model_dir. Keep the current choice if it still exists,
                    # otherwise fall back to the first model (or None if there are none).
                    new_model_list = self._list_models(model_dir)
                    dd = inputs[0]
                    if dd in new_model_list:
                        selected = dd
                    elif new_model_list:
                        selected = new_model_list[0]
                    else:
                        selected = None
                    return gr.Dropdown.update(choices=new_model_list, value=selected)

                self.params.model = gr.Dropdown(
                    choices=model_list,
                    value=(self.params.model if self.params.model in model_list else None),
                    label="Motion module",
                    type="value",
                    tooltip="Choose which motion module will be injected into the generation process.",
                    elem_id=f"{elemid_prefix}motion-module",
                )
                refresh_model = ToolButton(value="\U0001f504")
                refresh_model.click(
                    refresh_models, self.params.model, self.params.model
                )
            with gr.Row():
                self.params.enable = gr.Checkbox(
                    value=self.params.enable, label="Enable AnimateDiff",
                    elem_id=f"{elemid_prefix}enable"
                )
                self.params.video_length = gr.Number(
                    minimum=0,
                    value=self.params.video_length,
                    label="Number of frames",
                    precision=0,
                    tooltip="Total length of video in frames.",
                    elem_id=f"{elemid_prefix}video-length",
                )
                self.params.fps = gr.Number(
                    value=self.params.fps, label="FPS", precision=0,
                    tooltip="How many frames per second the gif will run.",
                    elem_id=f"{elemid_prefix}fps"
                )
                self.params.loop_number = gr.Number(
                    minimum=0,
                    value=self.params.loop_number,
                    label="Display loop number",
                    precision=0,
                    tooltip="How many times the animation will loop, a value of 0 will loop forever.",
                    elem_id=f"{elemid_prefix}loop-number",
                )
            with gr.Row():
                self.params.closed_loop = gr.Checkbox(
                    value=self.params.closed_loop,
                    label="Closed loop",
                    tooltip="If enabled, will try to make the last frame the same as the first frame.",
                    elem_id=f"{elemid_prefix}closed-loop",
                )
                self.params.batch_size = gr.Slider(
                    minimum=1,
                    maximum=32,
                    value=self.params.batch_size,
                    label="Context batch size",
                    step=1,
                    precision=0,
                    elem_id=f"{elemid_prefix}batch-size",
                )
                self.params.stride = gr.Number(
                    minimum=1,
                    value=self.params.stride,
                    label="Stride",
                    precision=0,
                    tooltip="",
                    elem_id=f"{elemid_prefix}stride",
                )
                self.params.overlap = gr.Number(
                    minimum=-1,
                    value=self.params.overlap,
                    label="Overlap",
                    precision=0,
                    tooltip="Number of frames to overlap in context.",
                    elem_id=f"{elemid_prefix}overlap",
                )
            with gr.Row():
                self.params.format = gr.CheckboxGroup(
                    choices=["GIF", "MP4", "PNG", "TXT"],
                    label="Save",
                    type="value",
                    tooltip="Which formats the animation should be saved in",
                    elem_id=f"{elemid_prefix}save-format",
                    value=self.params.format,
                )
                self.params.reverse = gr.CheckboxGroup(
                    choices=["Add Reverse Frame", "Remove head", "Remove tail"],
                    label="Reverse",
                    type="index",
                    tooltip="Reverse the resulting animation, remove the first and/or last frame from duplication.",
                    elem_id=f"{elemid_prefix}reverse",
                    value=self.params.reverse
                )
            with gr.Row():
                self.params.interp = gr.Radio(
                    choices=["Off", "FILM"],
                    label="Frame Interpolation",
                    tooltip="Interpolate between frames with Deforum's FILM implementation. Requires Deforum extension.",
                    elem_id=f"{elemid_prefix}interp-choice",
                    value=self.params.interp
                )
                self.params.interp_x = gr.Number(
                    value=self.params.interp_x, label="Interp X", precision=0,
                    tooltip="Replace each input frame with X interpolated output frames.",
                    elem_id=f"{elemid_prefix}interp-x"
                )
            self.params.video_source = gr.Video(
                value=self.params.video_source,
                label="Video source",
            )

            def update_fps(video_source):
                # Copy the uploaded video's frame rate into the FPS field.
                # Leave the field unchanged (a no-op gr.update()) when the video is
                # cleared, or when OpenCV reports no usable rate (0 for unreadable
                # files). Note that self.params.fps is the gr.Number component at
                # this point, not a number, so it must not be returned as a value.
                if not video_source:
                    return gr.update()
                cap = cv2.VideoCapture(video_source)
                try:
                    # round() so e.g. 29.97 fps becomes 30 rather than truncating to 29.
                    fps = round(cap.get(cv2.CAP_PROP_FPS))
                finally:
                    cap.release()
                return fps if fps > 0 else gr.update()

            self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps)
            self.params.video_path = gr.Textbox(
                value=self.params.video_path,
                label="Video path",
                tooltip="Paste path to video file.",
                elem_id=f"{elemid_prefix}video-path"
            )
            if is_img2img:
                with gr.Row():
                    self.params.latent_power = gr.Slider(
                        minimum=0.1,
                        maximum=10,
                        value=self.params.latent_power,
                        step=0.1,
                        label="Latent power",
                        tooltip="",
                        elem_id=f"{elemid_prefix}latent-power",
                    )
                    self.params.latent_scale = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=self.params.latent_scale,
                        label="Latent scale",
                        tooltip="",
                        elem_id=f"{elemid_prefix}latent-scale"
                    )
                    self.params.latent_power_last = gr.Slider(
                        minimum=0.1,
                        maximum=10,
                        value=self.params.latent_power_last,
                        step=0.1,
                        label="Optional latent power for last frame",
                        tooltip="",
                        elem_id=f"{elemid_prefix}latent-power-last",
                    )
                    self.params.latent_scale_last = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=self.params.latent_scale_last,
                        label="Optional latent scale for last frame",
                        tooltip="",
                        elem_id=f"{elemid_prefix}latent-scale-last"
                    )
                self.params.last_frame = gr.Image(
                    label="[Experiment] Optional last frame. Leave it blank if you do not need one.",
                    type="pil",
                )
            with gr.Row():
                unload = gr.Button(
                    value="Move motion module to CPU (default if lowvram)"
                )
                remove = gr.Button(value="Remove motion module from any memory")
                unload.click(fn=motion_module.unload)
                remove.click(fn=motion_module.remove)
        return self.register_unit(is_img2img)

    def register_unit(self, is_img2img: bool):
        """Hook the captured Generate button so each click builds an AnimateDiffProcess.

        The components returned by ``self.params.get_list(is_img2img)`` are passed
        positionally to ``AnimateDiffProcess``, and the new instance is stored in
        the returned ``gr.State``. The Generate button must already have been
        captured by ``on_after_component``; if it is still None, ``.click``
        raises AttributeError.
        """
        unit = gr.State()
        (
            AnimateDiffUiGroup.img2img_submit_button
            if is_img2img
            else AnimateDiffUiGroup.txt2img_submit_button
        ).click(
            fn=AnimateDiffProcess,
            inputs=self.params.get_list(is_img2img),
            outputs=unit,
            queue=False,
        )
        return unit

    @staticmethod
    def on_after_component(component, **_kwargs):
        """Remember the txt2img / img2img Generate buttons, identified by elem_id."""
        elem_id = getattr(component, "elem_id", None)
        if elem_id == "txt2img_generate":
            AnimateDiffUiGroup.txt2img_submit_button = component
            return
        if elem_id == "img2img_generate":
            AnimateDiffUiGroup.img2img_submit_button = component
            return