409 lines
21 KiB
Python
409 lines
21 KiB
Python
# This helper script is responsible for ControlNet/Deforum integration
|
|
# https://github.com/Mikubill/sd-webui-controlnet — controlnet repo
|
|
|
|
import os, sys
|
|
import gradio as gr
|
|
import scripts
|
|
import modules.scripts as scrpts
|
|
from PIL import Image
|
|
import numpy as np
|
|
from modules.processing import process_images
|
|
import importlib
|
|
from .rich import console
|
|
from rich.table import Table
|
|
from rich import box
|
|
from modules import scripts
|
|
|
|
cnet = None
|
|
|
|
def find_controlnet():
|
|
global cnet
|
|
global cnet_import_failure_count
|
|
if cnet is not None:
|
|
return cnet
|
|
try:
|
|
cnet = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')
|
|
print(f"\033[0;32m*Deforum ControlNet support: enabled*\033[0m")
|
|
return True
|
|
except Exception as e:
|
|
# the tab will be disactivated anyway, so we don't need the error message
|
|
return None
|
|
|
|
gradio_compat = True
|
|
try:
|
|
from distutils.version import LooseVersion
|
|
from importlib_metadata import version
|
|
if LooseVersion(version("gradio")) < LooseVersion("3.10"):
|
|
gradio_compat = False
|
|
except ImportError:
|
|
pass
|
|
|
|
# svgsupports
|
|
svgsupport = False
|
|
try:
|
|
import io
|
|
import base64
|
|
from svglib.svglib import svg2rlg
|
|
from reportlab.graphics import renderPM
|
|
svgsupport = True
|
|
except ImportError:
|
|
pass
|
|
|
|
def ControlnetArgs():
|
|
controlnet_enabled = False
|
|
controlnet_guess_mode = False
|
|
controlnet_rgbbgr_mode = False
|
|
controlnet_lowvram = False
|
|
controlnet_module = "none"
|
|
controlnet_model = "None"
|
|
controlnet_weight = 1.0
|
|
controlnet_guidance_strength = 1.0
|
|
blendFactorMax = "0:(0.35)"
|
|
blendFactorSlope = "0:(0.25)"
|
|
tweening_frames_schedule = "0:(20)"
|
|
color_correction_factor = "0:(0.075)"
|
|
return locals()
|
|
|
|
def setup_controlnet_ui_raw():
|
|
cnet = find_controlnet()
|
|
cn_models = cnet.get_models()
|
|
# since cn preprocessors don't seem to be provided in the API rn, hardcode the names list
|
|
cn_preprocessors = [
|
|
"none",
|
|
"canny",
|
|
"depth",
|
|
"depth_leres",
|
|
"hed",
|
|
"mlsd",
|
|
"normal_map",
|
|
"openpose",
|
|
"openpose_hand",
|
|
"clip_vision",
|
|
"color",
|
|
"pidinet",
|
|
"scribble",
|
|
"fake_scribble",
|
|
"segmentation",
|
|
"binary",
|
|
]
|
|
|
|
# Already under an accordion
|
|
refresh_symbol = '\U0001f504' # 🔄
|
|
switch_values_symbol = '\U000021C5' # ⇅
|
|
model_dropdowns = []
|
|
infotext_fields = []
|
|
# Main part
|
|
class ToolButton(gr.Button, gr.components.FormComponent):
|
|
"""Small button with single emoji as text, fits inside gradio forms"""
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__(variant="tool", **kwargs)
|
|
|
|
def get_block_name(self):
|
|
return "button"
|
|
|
|
# # Copying the main ControlNet widgets while getting rid of static elements such as the scribble pad
|
|
with gr.Row():
|
|
controlnet_enabled = gr.Checkbox(label='Enable', value=False, interactive=True)
|
|
controlnet_guess_mode = gr.Checkbox(label='Guess Mode', value=False, visible=False, interactive=True)
|
|
controlnet_invert_image = gr.Checkbox(label='Invert colors', value=False, visible=False, interactive=True)
|
|
controlnet_rgbbgr_mode = gr.Checkbox(label='RGB to BGR', value=False, visible=False, interactive=True)
|
|
controlnet_lowvram = gr.Checkbox(label='Low VRAM', value=False, visible=False, interactive=True)
|
|
|
|
def refresh_all_models(*inputs):
|
|
cn_models = cnet.get_models(update=True)
|
|
dd = inputs[0]
|
|
selected = dd if dd in cn_models else "None"
|
|
return gr.Dropdown.update(value=selected, choices=cn_models)
|
|
|
|
with gr.Row(visible=False) as cn_mod_row:
|
|
controlnet_module = gr.Dropdown(cn_preprocessors, label=f"Preprocessor", value="none", interactive=True)
|
|
controlnet_model = gr.Dropdown(cn_models, label=f"Model", value="None", interactive=True)
|
|
refresh_models = ToolButton(value=refresh_symbol)
|
|
refresh_models.click(refresh_all_models, controlnet_model, controlnet_model)
|
|
#ctrls += (refresh_models, )
|
|
with gr.Row(visible=False) as cn_weight_row:
|
|
controlnet_weight = gr.Slider(label=f"Weight", value=1.0, minimum=0.0, maximum=2.0, step=.05, interactive=True)
|
|
controlnet_guidance_start = gr.Slider(label="Guidance start", value=0.0, minimum=0.0, maximum=1.0, interactive=True)
|
|
controlnet_guidance_end = gr.Slider(label="Guidance end", value=1.0, minimum=0.0, maximum=1.0, interactive=True)
|
|
#ctrls += (controlnet_module, controlnet_model, controlnet_weight,)
|
|
model_dropdowns.append(controlnet_model)
|
|
|
|
# advanced options
|
|
controlnet_advanced = gr.Column(visible=False)
|
|
with controlnet_advanced:
|
|
controlnet_processor_res = gr.Slider(label="Annotator resolution", value=64, minimum=64, maximum=2048, interactive=False)
|
|
controlnet_threshold_a = gr.Slider(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False)
|
|
controlnet_threshold_b = gr.Slider(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False)
|
|
|
|
if gradio_compat:
|
|
controlnet_module.change(build_sliders, inputs=[controlnet_module], outputs=[controlnet_processor_res, controlnet_threshold_a, controlnet_threshold_b, controlnet_advanced])
|
|
|
|
infotext_fields.extend([
|
|
(controlnet_module, f"ControlNet Preprocessor"),
|
|
(controlnet_model, f"ControlNet Model"),
|
|
(controlnet_weight, f"ControlNet Weight"),
|
|
])
|
|
|
|
with gr.Row(visible=False) as cn_env_row:
|
|
controlnet_resize_mode = gr.Radio(choices=["Envelope (Outer Fit)", "Scale to Fit (Inner Fit)", "Just Resize"], value="Scale to Fit (Inner Fit)", label="Resize Mode", interactive=True)
|
|
|
|
with gr.Row(visible=False) as cn_vid_settings_row:
|
|
controlnet_overwrite_frames = gr.Checkbox(label='Overwrite input frames', value=True, interactive=True)
|
|
controlnet_vid_path = gr.Textbox(value='', label="ControlNet Input Video Path", interactive=True)
|
|
controlnet_mask_vid_path = gr.Textbox(value='', label="ControlNet Mask Video Path", interactive=True)
|
|
|
|
# Video input to be fed into ControlNet
|
|
#input_video_url = gr.Textbox(source='upload', type='numpy', tool='sketch') # TODO
|
|
controlnet_input_video_chosen_file = gr.File(label="ControlNet Video Input", interactive=True, file_count="single", file_types=["video"], elem_id="controlnet_input_video_chosen_file", visible=False)
|
|
controlnet_input_video_mask_chosen_file = gr.File(label="ControlNet Video Mask Input", interactive=True, file_count="single", file_types=["video"], elem_id="controlnet_input_video_mask_chosen_file", visible=False)
|
|
|
|
cn_hide_output_list = [controlnet_guess_mode,controlnet_invert_image,controlnet_rgbbgr_mode,controlnet_lowvram,cn_mod_row,cn_weight_row,cn_env_row,cn_vid_settings_row,controlnet_input_video_chosen_file,controlnet_input_video_mask_chosen_file]
|
|
for cn_output in cn_hide_output_list:
|
|
controlnet_enabled.change(fn=hide_ui_by_cn_status, inputs=controlnet_enabled,outputs=cn_output)
|
|
|
|
return locals()
|
|
|
|
|
|
def setup_controlnet_ui():
|
|
if not find_controlnet():
|
|
gr.HTML("""
|
|
<a style='target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet'>ControlNet not found. Please install it :)</a>
|
|
""", elem_id='controlnet_not_found_html_msg')
|
|
return {}
|
|
|
|
try:
|
|
return setup_controlnet_ui_raw()
|
|
except Exception as e:
|
|
print(f"'ControlNet UI setup failed due to '{e}'!")
|
|
gr.HTML(f"""
|
|
Failed to setup ControlNet UI, check the reason in your commandline log. Please, downgrade your CN extension to <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/archive/c9340671d6d59e5a79fc404f78f747f969f87374.zip'>c9340671d6d59e5a79fc404f78f747f969f87374</a> or report the problem <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/issues'>here</a>.
|
|
""", elem_id='controlnet_not_found_html_msg')
|
|
return {}
|
|
|
|
|
|
def controlnet_component_names():
|
|
if not find_controlnet():
|
|
return []
|
|
|
|
controlnet_args_names = str(r'''controlnet_input_video_chosen_file, controlnet_input_video_mask_chosen_file,
|
|
controlnet_overwrite_frames,controlnet_vid_path,controlnet_mask_vid_path,
|
|
controlnet_enabled, controlnet_guess_mode, controlnet_invert_image, controlnet_rgbbgr_mode, controlnet_lowvram,
|
|
controlnet_module, controlnet_model,
|
|
controlnet_weight, controlnet_guidance_start, controlnet_guidance_end,
|
|
controlnet_processor_res,
|
|
controlnet_threshold_a, controlnet_threshold_b, controlnet_resize_mode'''
|
|
).replace("\n", "").replace("\r", "").replace(" ", "").split(',')
|
|
|
|
return controlnet_args_names
|
|
|
|
def controlnet_infotext():
|
|
return """Requires the <a style='color:SteelBlue;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet'>ControlNet</a> extension to be installed.</p>
|
|
<p style="margin-top:0.2em">
|
|
*Work In Progress*. All params below are going to be keyframable at some point. If you want to speedup the integration, *especially* if you want to bring new features sooner, <a style='color:Violet;' target='_blank' href='https://github.com/deforum-art/deforum-for-automatic1111-webui/'>join Deforum's development</a>. 😉
|
|
</p>
|
|
<p">
|
|
If you previously downgraded the CN extension to use it in Deforum, upgrade it to the latest version for the API communication to work; also, you don't need to store your CN models in deforum/models anymore. We got complaints that fp16 models sometimes don't work for an unknown reason, so use <a style='color:DarkGreen;' target='_blank' href='https://huggingface.co/lllyasviel/ControlNet/tree/main/models'>the full precision ones</a> until it's fixed. Note that CN's API may breakingly change at any time. The recommended CN version is <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/archive/c9340671d6d59e5a79fc404f78f747f969f87374.zip'>c9340671d6d59e5a79fc404f78f747f969f87374</a>. If Deforum crashes due to CN updates, go <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/issues'>here</a> and report your problem.
|
|
</p>
|
|
"""
|
|
|
|
def is_controlnet_enabled(controlnet_args):
|
|
return 'controlnet_enabled' in vars(controlnet_args) and controlnet_args.controlnet_enabled
|
|
|
|
def process_with_controlnet(p, args, anim_args, loop_args, controlnet_args, root, is_img2img = True, frame_idx = 1):
|
|
cnet = find_controlnet()
|
|
|
|
controlnet_frame_path = os.path.join(args.outdir, 'controlnet_inputframes', f"{frame_idx:09}.jpg")
|
|
controlnet_mask_frame_path = os.path.join(args.outdir, 'controlnet_maskframes', f"{frame_idx:09}.jpg")
|
|
|
|
print(f'Reading ControlNet base frame {frame_idx} at {controlnet_frame_path}')
|
|
print(f'Reading ControlNet mask frame {frame_idx} at {controlnet_mask_frame_path}')
|
|
|
|
cn_mask_np = None
|
|
cn_image_np = None
|
|
|
|
if not os.path.exists(controlnet_frame_path) and not os.path.exists(controlnet_mask_frame_path):
|
|
print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m')
|
|
return
|
|
|
|
if os.path.exists(controlnet_frame_path):
|
|
cn_image_np = np.array(Image.open(controlnet_frame_path).convert("RGB")).astype('uint8')
|
|
|
|
if os.path.exists(controlnet_mask_frame_path):
|
|
cn_mask_np = np.array(Image.open(controlnet_mask_frame_path).convert("RGB")).astype('uint8')
|
|
|
|
table = Table(title="ControlNet params",padding=0, box=box.ROUNDED)
|
|
|
|
# TODO: auto infer the names and the values for the table
|
|
field_names = []
|
|
field_names += ["module", "model", "weight", "inv", "guide_start", "guide_end", "guess", "resize", "rgb_bgr", "proc res", "thr a", "thr b"]
|
|
for field_name in field_names:
|
|
table.add_column(field_name, justify="center")
|
|
|
|
cn_model_name = str(controlnet_args.controlnet_model)
|
|
|
|
rows = []
|
|
rows += [controlnet_args.controlnet_module, cn_model_name[len('control_'):] if 'control_' in cn_model_name else cn_model_name, controlnet_args.controlnet_weight, controlnet_args.controlnet_invert_image, controlnet_args.controlnet_guidance_start, controlnet_args.controlnet_guidance_end, controlnet_args.controlnet_guess_mode, controlnet_args.controlnet_resize_mode, controlnet_args.controlnet_rgbbgr_mode, controlnet_args.controlnet_processor_res, controlnet_args.controlnet_threshold_a, controlnet_args.controlnet_threshold_b]
|
|
rows = [str(x) for x in rows]
|
|
|
|
table.add_row(*rows)
|
|
|
|
console.print(table)
|
|
|
|
p.scripts = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
|
|
|
cnu = {
|
|
"enabled":True,
|
|
"module":controlnet_args.controlnet_module,
|
|
"model":controlnet_args.controlnet_model,
|
|
"weight":controlnet_args.controlnet_weight,
|
|
"image":{'image': cn_image_np, 'mask': cn_mask_np} if cn_mask_np is not None else cn_image_np,
|
|
"invert_image":controlnet_args.controlnet_invert_image,
|
|
"guess_mode":controlnet_args.controlnet_guess_mode,
|
|
"resize_mode":controlnet_args.controlnet_resize_mode,
|
|
"rgbbgr_mode":controlnet_args.controlnet_rgbbgr_mode,
|
|
"low_vram":controlnet_args.controlnet_lowvram,
|
|
"processor_res":controlnet_args.controlnet_processor_res,
|
|
"threshold_a":controlnet_args.controlnet_threshold_a,
|
|
"threshold_b":controlnet_args.controlnet_threshold_b,
|
|
"guidance_start":controlnet_args.controlnet_guidance_start,
|
|
"guidance_end":controlnet_args.controlnet_guidance_end,
|
|
}
|
|
|
|
p.script_args = (
|
|
cnu["enabled"],
|
|
cnu["module"],
|
|
cnu["model"],
|
|
cnu["weight"],
|
|
cnu["invert_image"],
|
|
cnu["image"],
|
|
cnu["guess_mode"],
|
|
cnu["resize_mode"],
|
|
cnu["rgbbgr_mode"],
|
|
cnu["low_vram"],
|
|
cnu["processor_res"],
|
|
cnu["threshold_a"],
|
|
cnu["threshold_b"],
|
|
cnu["guidance_start"],
|
|
cnu["guidance_end"],
|
|
)
|
|
|
|
cn_units = [
|
|
cnet.ControlNetUnit(**cnu),
|
|
]
|
|
|
|
cnet.update_cn_script_in_processing(p, cn_units, is_img2img=is_img2img, is_ui=False)
|
|
|
|
import pathlib
|
|
from .video_audio_utilities import vid2frames
|
|
|
|
def unpack_controlnet_vids(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root):
|
|
if controlnet_args.controlnet_input_video_chosen_file is not None and len(controlnet_args.controlnet_input_video_chosen_file.name) > 0 or len(controlnet_args.controlnet_vid_path) > 0:
|
|
print(f'Unpacking ControlNet base video')
|
|
# create a folder for the video input frames to live in
|
|
mask_in_frame_path = os.path.join(args.outdir, 'controlnet_inputframes')
|
|
os.makedirs(mask_in_frame_path, exist_ok=True)
|
|
|
|
# save the video frames from mask video
|
|
print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...")
|
|
vid2frames(video_path=controlnet_args.controlnet_vid_path if len(controlnet_args.controlnet_vid_path) > 0 else controlnet_args.controlnet_input_video_chosen_file.name, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=controlnet_args.controlnet_overwrite_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame, numeric_files_output=True)
|
|
|
|
print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}")
|
|
print(f'ControlNet base video unpacked!')
|
|
|
|
if controlnet_args.controlnet_input_video_mask_chosen_file is not None and len(controlnet_args.controlnet_input_video_mask_chosen_file.name) > 0 or len(controlnet_args.controlnet_mask_vid_path) > 0:
|
|
print(f'Unpacking ControlNet video mask')
|
|
# create a folder for the video input frames to live in
|
|
mask_in_frame_path = os.path.join(args.outdir, 'controlnet_maskframes')
|
|
os.makedirs(mask_in_frame_path, exist_ok=True)
|
|
|
|
# save the video frames from mask video
|
|
print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...")
|
|
vid2frames(video_path=controlnet_args.controlnet_mask_vid_path if len(controlnet_args.controlnet_mask_vid_path) > 0 else controlnet_args.controlnet_input_video_mask_chosen_file.name, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=controlnet_args.controlnet_overwrite_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame, numeric_files_output=True)
|
|
|
|
print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}")
|
|
print(f'ControlNet video mask unpacked!')
|
|
|
|
def hide_ui_by_cn_status(choice):
|
|
return gr.update(visible=True) if choice else gr.update(visible=False)
|
|
|
|
def build_sliders(module):
|
|
if module == "canny":
|
|
return [
|
|
gr.update(label="Annotator resolution", value=512, minimum=64, maximum=2048, step=1, interactive=True),
|
|
gr.update(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1, interactive=True),
|
|
gr.update(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1, interactive=True),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "mlsd": #Hough
|
|
return [
|
|
gr.update(label="Hough Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True),
|
|
gr.update(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01, interactive=True),
|
|
gr.update(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01, interactive=True),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module in ["hed", "fake_scribble"]:
|
|
return [
|
|
gr.update(label="HED Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module in ["openpose", "openpose_hand", "segmentation"]:
|
|
return [
|
|
gr.update(label="Annotator Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "depth":
|
|
return [
|
|
gr.update(label="Midas Resolution", minimum=64, maximum=2048, value=384, step=1, interactive=True),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module in ["depth_leres", "depth_leres_boost"]:
|
|
return [
|
|
gr.update(label="LeReS Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True),
|
|
gr.update(label="Remove Near %", value=0, minimum=0, maximum=100, step=0.1, interactive=True),
|
|
gr.update(label="Remove Background %", value=0, minimum=0, maximum=100, step=0.1, interactive=True),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "normal_map":
|
|
return [
|
|
gr.update(label="Normal Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True),
|
|
gr.update(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01, interactive=True),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "binary":
|
|
return [
|
|
gr.update(label="Annotator resolution", value=512, minimum=64, maximum=2048, step=1, interactive=True),
|
|
gr.update(label="Binary threshold", minimum=0, maximum=255, value=0, step=1, interactive=True),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "color":
|
|
return [
|
|
gr.update(label="Annotator Resolution", value=512, minimum=64, maximum=2048, step=8, interactive=True),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
elif module == "none":
|
|
return [
|
|
gr.update(label="Normal Resolution", value=64, minimum=64, maximum=2048, interactive=False),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=False)
|
|
]
|
|
else:
|
|
return [
|
|
gr.update(label="Annotator resolution", value=512, minimum=64, maximum=2048, step=1, interactive=True),
|
|
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
|
gr.update(visible=True)
|
|
]
|
|
|