xyz plot support (#431)

Co-authored-by: zappityzap <zappityzap@proton.me>
pull/444/head
zappityzap 2024-02-17 11:47:13 -08:00 committed by GitHub
parent 2573b2b513
commit f279033b91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 151 additions and 3 deletions

View File

@ -14,6 +14,7 @@ from scripts.animatediff_prompt import AnimateDiffPromptSchedule
from scripts.animatediff_output import AnimateDiffOutput
from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup, supported_save_formats
from scripts.animatediff_infotext import update_infotext, infotext_pasted
from scripts.animatediff_xyz import patch_xyz, xyz_attrs
script_dir = scripts.basedir()
motion_module.set_script_dir(script_dir)
@ -52,6 +53,11 @@ class AnimateDiffScript(scripts.Script):
if p.is_api and isinstance(params, dict):
self.ad_params = AnimateDiffProcess(**params)
params = self.ad_params
# apply XYZ settings
params.apply_xyz()
xyz_attrs.clear()
if params.enable:
logger.info("AnimateDiff process start.")
params.set_p(p)
@ -290,7 +296,9 @@ def on_ui_settings():
section=s3_selection,
),
)
patch_xyz()
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_after_component(AnimateDiffUiGroup.on_after_component)
script_callbacks.on_before_ui(AnimateDiffUiGroup.on_before_ui)

View File

@ -1,6 +1,7 @@
import base64
import datetime
from pathlib import Path
import traceback
import imageio.v3 as imageio
import numpy as np
@ -18,7 +19,9 @@ from scripts.animatediff_ui import AnimateDiffProcess
class AnimateDiffOutput:
def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
video_paths = []
logger.info("Merging images into GIF.")
first_frames = []
from_xyz = any("xyz_grid" in frame.filename for frame in traceback.extract_stack())
logger.info(f"Saving output formats: {', '.join(params.format)}")
date = datetime.datetime.now().strftime('%Y-%m-%d')
output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{date}")
output_dir.mkdir(parents=True, exist_ok=True)
@ -27,7 +30,9 @@ class AnimateDiffOutput:
# frame interpolation replaces video_list with interpolated frames
# so make a copy instead of a slice (reference), to avoid modifying res
frame_list = [image.copy() for image in res.images[i : i + params.video_length]]
if from_xyz:
first_frames.append(res.images[i].copy())
seq = images.get_next_sequence_number(output_dir, "")
filename_suffix = f"-{params.request_id}" if params.request_id else ""
filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}{filename_suffix}"
@ -43,6 +48,9 @@ class AnimateDiffOutput:
res.images = video_paths if not p.is_api else (self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else []))
# replace results with first frame of each video so xyz grid draws correctly
if from_xyz:
res.images = first_frames
def _add_reverse(self, params: AnimateDiffProcess, frame_list: list):
if params.video_length <= params.batch_size and params.closed_loop in ['A']:

View File

@ -9,6 +9,8 @@ from modules.processing import StableDiffusionProcessing
from scripts.animatediff_mm import mm_animatediff as motion_module
from scripts.animatediff_i2ibatch import animatediff_i2ibatch
from scripts.animatediff_lcm import AnimateDiffLCM
from scripts.animatediff_logger import logger_animatediff as logger
from scripts.animatediff_xyz import xyz_attrs
supported_save_formats = ["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"]
@ -145,6 +147,10 @@ class AnimateDiffProcess:
if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", False):
p.do_not_save_samples = True
def apply_xyz(self):
for k, v in xyz_attrs.items():
setattr(self, k, v)
class AnimateDiffUiGroup:
txt2img_submit_button = None

126
scripts/animatediff_xyz.py Normal file
View File

@ -0,0 +1,126 @@
import sys
from types import ModuleType
from typing import Optional
from modules import scripts
from scripts.animatediff_logger import logger_animatediff as logger
xyz_attrs: dict = {}
def patch_xyz():
xyz_module = find_xyz_module()
if xyz_module is None:
logger.warning("XYZ module not found.")
return
MODULE = "[AnimateDiff]"
xyz_module.axis_options.extend([
xyz_module.AxisOption(
label=f"{MODULE} Enabled",
type=str_to_bool,
apply=apply_state("enable"),
choices=choices_bool),
xyz_module.AxisOption(
label=f"{MODULE} Motion Module",
type=str,
apply=apply_state("model")),
xyz_module.AxisOption(
label=f"{MODULE} Video length",
type=int_or_float,
apply=apply_state("video_length")),
xyz_module.AxisOption(
label=f"{MODULE} FPS",
type=int_or_float,
apply=apply_state("fps")),
xyz_module.AxisOption(
label=f"{MODULE} Use main seed",
type=str_to_bool,
apply=apply_state("use_main_seed"),
choices=choices_bool),
xyz_module.AxisOption(
label=f"{MODULE} Closed loop",
type=str,
apply=apply_state("closed_loop"),
choices=lambda: ["N", "R-P", "R+P", "A"]),
xyz_module.AxisOption(
label=f"{MODULE} Batch size",
type=int_or_float,
apply=apply_state("batch_size")),
xyz_module.AxisOption(
label=f"{MODULE} Stride",
type=int_or_float,
apply=apply_state("stride")),
xyz_module.AxisOption(
label=f"{MODULE} Overlap",
type=int_or_float,
apply=apply_state("overlap")),
xyz_module.AxisOption(
label=f"{MODULE} Interp",
type=str_to_bool,
apply=apply_state("interp"),
choices=choices_bool),
xyz_module.AxisOption(
label=f"{MODULE} Interp X",
type=int_or_float,
apply=apply_state("interp_x")),
xyz_module.AxisOption(
label=f"{MODULE} Video path",
type=str,
apply=apply_state("video_path")),
xyz_module.AxisOptionImg2Img(
label=f"{MODULE} Latent power",
type=int_or_float,
apply=apply_state("latent_power")),
xyz_module.AxisOptionImg2Img(
label=f"{MODULE} Latent scale",
type=int_or_float,
apply=apply_state("latent_scale")),
xyz_module.AxisOptionImg2Img(
label=f"{MODULE} Latent power last",
type=int_or_float,
apply=apply_state("latent_power_last")),
xyz_module.AxisOptionImg2Img(
label=f"{MODULE} Latent scale last",
type=int_or_float,
apply=apply_state("latent_scale_last")),
])
def apply_state(k, key_map=None):
def callback(_p, v, _vs):
if key_map is not None:
v = key_map[v]
xyz_attrs[k] = v
return callback
def str_to_bool(string):
string = str(string)
if string in ["None", ""]:
return None
elif string.lower() in ["true", "1"]:
return True
elif string.lower() in ["false", "0"]:
return False
else:
raise ValueError(f"Could not convert string to boolean: {string}")
def int_or_float(string):
try:
return int(string)
except ValueError:
return float(string)
def choices_bool():
return ["False", "True"]
def find_xyz_module() -> Optional[ModuleType]:
for data in scripts.scripts_data:
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
return data.module
return None