improve api (#468)

pull/469/head
Chengsong Zhang 2024-03-15 06:34:39 -05:00 committed by GitHub
parent 6314c17c2f
commit 74a2ae1be9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 15 deletions

View File

@ -27,6 +27,16 @@ def on_ui_settings():
section=section
).needs_restart()
)
shared.opts.add_option(
"animatediff_save_to_custom",
shared.OptionInfo(
True,
"Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{gif filename}/{date} "
"instead of stable-diffusion-webui/outputs/{ txt|img }2img-images/{date}/.",
gr.Checkbox,
section=section
)
)
shared.opts.add_option(
"animatediff_frame_extract_path",
shared.OptionInfo(
@ -47,12 +57,12 @@ def on_ui_settings():
)
)
shared.opts.add_option(
"animatediff_save_to_custom",
"animatediff_default_frame_extract_method",
shared.OptionInfo(
True,
"Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{gif filename}/{date} "
"instead of stable-diffusion-webui/outputs/{ txt|img }2img-images/{date}/.",
gr.Checkbox,
"ffmpeg",
"Default frame extraction method",
gr.Radio,
{"choices": ["ffmpeg", "opencv"]},
section=section
)
)

View File

@ -59,15 +59,13 @@ def get_controlnet_units(p: StableDiffusionProcessing):
if p.is_api and len(cn_units) > 0 and isinstance(cn_units[0], dict):
from scripts import external_code
from scripts.batch_hijack import InputMode
cn_units = external_code.get_all_units_in_processing(p)
for cn_unit in cn_units:
setattr(cn_unit, "input_mode", InputMode.BATCH)
setattr(cn_unit, "batch_images", None)
setattr(cn_unit, "batch_mask_dir", None)
setattr(cn_unit, "batch_input_gallery", None)
setattr(cn_unit, "batch_modifiers", [])
setattr(cn_unit, "animatediff_batch", True)
p.script_args[script.args_from:script.args_to] = cn_units
cn_units_dataclass = external_code.get_all_units_in_processing(p)
for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass):
if cn_unit_dataclass.image is None:
cn_unit_dataclass.input_mode = InputMode.BATCH
cn_unit_dataclass.batch_images = cn_unit_dict.get("batch_images", None)
p.script_args[script.args_from:script.args_to] = cn_units_dataclass
return [x for x in cn_units if x.enabled] if not p.is_api else cn_units
return []
@ -115,7 +113,10 @@ def extract_frames_from_video(params):
params.video_path = f"{data_path}/tmp/animatediff-frames"
params.video_path = os.path.join(params.video_path, f"{Path(params.video_source).stem}-{generate_random_hash()}")
try:
ffmpeg_extract_frames(params.video_source, params.video_path)
if shared.opts.data.get("animatediff_default_frame_extract_method", "ffmpeg") == "opencv":
cv2_extract_frames(params.video_source, params.video_path)
else:
ffmpeg_extract_frames(params.video_source, params.video_path)
except Exception as e:
logger.error(f"[AnimateDiff] Error extracting frames via ffmpeg: {e}, fall back to OpenCV.")
cv2_extract_frames(params.video_source, params.video_path)