Add Stable LoRA support
parent
3f4a109a69
commit
2dad959835
|
|
@ -22,6 +22,7 @@ from t2v_helpers.args import get_outdir, process_args
|
|||
import t2v_helpers.args as t2v_helpers_args
|
||||
from modules import shared, sd_hijack, lowvram
|
||||
from modules.shared import opts, devices, state
|
||||
from stable_lora.scripts.lora_webui import gr_inputs_list, StableLoraScriptInstance
|
||||
import os
|
||||
|
||||
pipe = None
|
||||
|
|
@ -29,7 +30,7 @@ pipe = None
|
|||
def setup_pipeline(model_name):
|
||||
return TextToVideoSynthesis(get_model_location(model_name))
|
||||
|
||||
def process_modelscope(args_dict):
|
||||
def process_modelscope(args_dict, extra_args=None):
|
||||
args, video_args = process_args(args_dict)
|
||||
|
||||
global pipe
|
||||
|
|
@ -63,6 +64,11 @@ def process_modelscope(args_dict):
|
|||
if pipe is None or pipe is not None and args.model is not None and get_model_location(args.model) != pipe.model_dir:
|
||||
pipe = setup_pipeline(args.model)
|
||||
|
||||
#TODO Wrap this in a list so that we can process this for future extensions.
|
||||
stable_lora_processor = StableLoraScriptInstance
|
||||
stable_lora_args = stable_lora_processor.process_extension_args(all_args=extra_args)
|
||||
stable_lora_processor.process(pipe, *stable_lora_args)
|
||||
|
||||
pipe.keep_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None'
|
||||
|
||||
device = devices.get_optimal_device()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,220 @@
|
|||
import loralib as loralb
|
||||
import torch
|
||||
import glob
|
||||
|
||||
from safetensors.torch import load_file
|
||||
from types import SimpleNamespace
|
||||
from safetensors import safe_open
|
||||
from einops import rearrange
|
||||
import gradio as gr
|
||||
import os
|
||||
import json
|
||||
|
||||
from modules import images, script_callbacks
|
||||
from modules.shared import opts, cmd_opts, state, cmd_opts, sd_model
|
||||
from modules.sd_models import read_state_dict
|
||||
from stable_lora.stable_utils.lora_processor import StableLoraProcessor
|
||||
from t2v_helpers.extensions_utils import Text2VideoExtension
|
||||
|
||||
EXTENSION_TITLE = "Stable LoRA"
|
||||
EXTENSION_NAME = EXTENSION_TITLE.replace(' ', '_').lower()
|
||||
|
||||
gr_inputs_list = [
|
||||
"lora_file_selection",
|
||||
"lora_alpha",
|
||||
"refresh_button",
|
||||
"use_bias",
|
||||
"use_linear",
|
||||
"use_conv",
|
||||
"use_emb",
|
||||
"use_time",
|
||||
"use_multiplier"
|
||||
]
|
||||
|
||||
gr_inputs_dict = {v: v for v in gr_inputs_list}
|
||||
GradioInputsIds = SimpleNamespace(**gr_inputs_dict)
|
||||
|
||||
class StableLoraScript(Text2VideoExtension, StableLoraProcessor):
|
||||
|
||||
def __init__(self):
|
||||
StableLoraProcessor.__init__(self)
|
||||
Text2VideoExtension.__init__(self, EXTENSION_NAME, EXTENSION_TITLE)
|
||||
self.device = 'cuda'
|
||||
self.dtype = torch.float16
|
||||
|
||||
def title(self):
|
||||
return EXTENSION_TITLE
|
||||
|
||||
def refresh_models(self, *args):
|
||||
paths_with_metadata, lora_names = self.get_lora_files()
|
||||
self.lora_files = paths_with_metadata.copy()
|
||||
|
||||
return gr.Dropdown.update(value=[], choices=lora_names)
|
||||
|
||||
def ui(self):
|
||||
paths_with_metadata, lora_names = self.get_lora_files()
|
||||
self.lora_files = paths_with_metadata.copy()
|
||||
REPOSITORY_LINK = "https://github.com/ExponentialML/Text-To-Video-Finetuning"
|
||||
|
||||
with gr.Accordion(label=EXTENSION_TITLE, open=False) as stable_lora_section:
|
||||
with gr.Blocks(analytics_enabled=False):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.HTML("<h2>Load a Trained LoRA File.</h2>")
|
||||
gr.HTML(
|
||||
"""
|
||||
<h3 style='color: crimson; font-weight: bold;'>
|
||||
Only Stable LoRA files are supported.
|
||||
</h3>
|
||||
"""
|
||||
)
|
||||
gr.HTML(f"""
|
||||
<a href='{REPOSITORY_LINK}'>
|
||||
To train a Stable LoRA file, use the finetune repository by clicking here.
|
||||
</a>"""
|
||||
)
|
||||
lora_files_selection = gr.Dropdown(
|
||||
label="Available Models",
|
||||
elem_id=GradioInputsIds.lora_file_selection,
|
||||
choices=lora_names,
|
||||
value=[],
|
||||
multiselect=True,
|
||||
)
|
||||
lora_alpha = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
value=1,
|
||||
step=0.05,
|
||||
elem_id=GradioInputsIds.lora_alpha,
|
||||
label="LoRA Weight"
|
||||
)
|
||||
refresh_button = gr.Button(
|
||||
value="Refresh Models",
|
||||
elem_id=GradioInputsIds.refresh_button
|
||||
)
|
||||
refresh_button.click(
|
||||
self.refresh_models,
|
||||
lora_files_selection,
|
||||
lora_files_selection
|
||||
)
|
||||
with gr.Accordion(label="Advanced Settings", open=False, visible=False):
|
||||
with gr.Column():
|
||||
use_bias = gr.Checkbox(label="Enable Bias", elem_id=GradioInputsIds.use_bias, value=lambda: True)
|
||||
use_linear = gr.Checkbox(label="Enable Linears", elem_id=GradioInputsIds.use_linear, value=lambda: True)
|
||||
use_conv = gr.Checkbox(label="Enable Convolutions", elem_id=GradioInputsIds.use_conv, value=lambda: True)
|
||||
use_emb = gr.Checkbox(label="Enable Embeddings", elem_id=GradioInputsIds.use_emb, value=lambda: True)
|
||||
use_time = gr.Checkbox(label="Enable Time", elem_id=GradioInputsIds.use_time, value=lambda: True)
|
||||
with gr.Column():
|
||||
use_multiplier = gr.Number(
|
||||
label="Alpha Multiplier",
|
||||
elem_id=GradioInputsIds.use_multiplier,
|
||||
value=1,
|
||||
)
|
||||
|
||||
|
||||
return self.return_ui_inputs(
|
||||
return_args=[
|
||||
lora_files_selection,
|
||||
lora_alpha,
|
||||
use_bias,
|
||||
use_linear,
|
||||
use_conv,
|
||||
use_emb,
|
||||
use_multiplier,
|
||||
use_time
|
||||
]
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def process(
|
||||
self,
|
||||
p,
|
||||
lora_files_selection,
|
||||
lora_alpha,
|
||||
use_bias,
|
||||
use_linear,
|
||||
use_conv,
|
||||
use_emb,
|
||||
use_multiplier,
|
||||
use_time
|
||||
):
|
||||
|
||||
# Get the list of LoRA files based off of filepath.
|
||||
lora_file_names = [x for x in lora_files_selection if x != "None"]
|
||||
|
||||
if len(self.lora_files) <= 0:
|
||||
paths_with_metadata, lora_names = self.get_lora_files()
|
||||
self.lora_files = paths_with_metadata.copy()
|
||||
|
||||
lora_files = self.get_loras_to_process(lora_file_names)
|
||||
|
||||
# Load multiple LoRAs
|
||||
lora_files_list = []
|
||||
|
||||
# Load our advanced options in a list
|
||||
advanced_options = [
|
||||
use_bias,
|
||||
use_linear,
|
||||
use_conv,
|
||||
use_emb,
|
||||
use_multiplier,
|
||||
use_time
|
||||
]
|
||||
|
||||
# Save the previous alpha value so we can re-run the LoRA with new values.
|
||||
alpha_changed = self.handle_alpha_change(lora_alpha, p.sd_model)
|
||||
|
||||
# If an advanced option changes, re-run with new options
|
||||
options_changed = self.handle_options_change(advanced_options, p.sd_model)
|
||||
|
||||
# Check if we changed our LoRA models we are loading
|
||||
lora_changed = self.previous_lora_file_names != lora_file_names
|
||||
|
||||
first_lora_init = not self.is_lora_loaded(p.sd_model)
|
||||
|
||||
# If the LoRA is still loaded, unload it.
|
||||
self.handle_lora_start(lora_files, p.sd_model)
|
||||
|
||||
p.sd_model.eval()
|
||||
for param in p.sd_model.parameters():
|
||||
if param.requires_grad:
|
||||
param.requires_grad_(False)
|
||||
|
||||
can_use_lora = self.can_use_lora(p.sd_model)
|
||||
|
||||
lora_params_changed = any([alpha_changed, lora_changed, options_changed])
|
||||
|
||||
# Process LoRA
|
||||
if can_use_lora or lora_params_changed:
|
||||
|
||||
if len(lora_files) == 0: return
|
||||
|
||||
lora_alpha = (lora_alpha * use_multiplier) / len(lora_files)
|
||||
|
||||
lora_files_list = self.load_loras_from_list(lora_files)
|
||||
|
||||
args = [p, lora_files_list, use_bias, use_time, use_conv, use_emb, use_linear, lora_alpha]
|
||||
|
||||
if lora_params_changed and not first_lora_init:
|
||||
self.log("Resetting weights to reflect changed options.")
|
||||
|
||||
undo_args = args.copy()
|
||||
undo_args[1], undo_args[-1] = self.undo_merge_preprocess()
|
||||
|
||||
self.process_lora(*undo_args, undo_merge=True)
|
||||
|
||||
self.process_lora(*args, undo_merge=False)
|
||||
|
||||
self.handle_after_lora_load(
|
||||
p.sd_model,
|
||||
lora_files,
|
||||
lora_file_names,
|
||||
advanced_options,
|
||||
alpha_changed,
|
||||
lora_alpha
|
||||
)
|
||||
|
||||
if len(lora_files) > 0 and not all([can_use_lora, lora_params_changed]):
|
||||
self.log(f"Using loaded LoRAs: {', '.join(lora_file_names)}")
|
||||
|
||||
StableLoraScriptInstance = StableLoraScript()
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
import os
|
||||
import glob
|
||||
import torch
|
||||
|
||||
from safetensors.torch import load_file
|
||||
from safetensors import safe_open
|
||||
from modules.shared import opts, cmd_opts, state, cmd_opts, sd_model
|
||||
from modules import sd_hijack
|
||||
from modules.sd_models import read_state_dict
|
||||
|
||||
class StableLoraProcessor:
|
||||
def __init__(self):
|
||||
self.lora_loaded = 'lora_loaded'
|
||||
self.previous_lora_alpha = 1
|
||||
self.current_sd_checkpoint = ""
|
||||
self.previous_lora_file_names = []
|
||||
self.previous_advanced_options = []
|
||||
self.lora_files = []
|
||||
|
||||
def get_lora_files(self):
|
||||
paths_with_metadata = []
|
||||
paths = glob.glob(os.path.join(cmd_opts.lora_dir, '**/*.safetensors'), recursive=True)
|
||||
lora_names = []
|
||||
|
||||
for lora_path in paths:
|
||||
with safe_open(lora_path, 'pt') as lora_file:
|
||||
metadata = lora_file.metadata()
|
||||
if metadata is not None and 'stable_lora_text_to_video' in metadata.keys():
|
||||
metadata['path'] = lora_path
|
||||
metadata['lora_name'] = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
paths_with_metadata.append(metadata)
|
||||
|
||||
if len(paths_with_metadata) > 0:
|
||||
lora_names = [x['lora_name'] for x in paths_with_metadata]
|
||||
|
||||
return paths_with_metadata, lora_names
|
||||
|
||||
def key_name_match(self, value, key, name):
|
||||
return value in key and name == key.split(f".{value}")[0]
|
||||
|
||||
def is_lora_match(self, key, name):
|
||||
return self.key_name_match('lora_A', key, name)
|
||||
|
||||
def is_bias_match(self, key, name):
|
||||
return self.key_name_match("bias", key, name)
|
||||
|
||||
def lora_rank(self, weight): return min(weight.shape)
|
||||
|
||||
def get_lora_alpha(self, alpha):
|
||||
return alpha
|
||||
|
||||
def process_lora_weight(self, weight, lora_weight, alpha, undo_merge=False):
|
||||
new_weight = weight.detach().clone()
|
||||
|
||||
if not undo_merge:
|
||||
new_weight += lora_weight.to(weight.device, weight.dtype) * alpha
|
||||
else:
|
||||
new_weight -= lora_weight.to(weight.device, weight.dtype) * alpha
|
||||
|
||||
return torch.nn.Parameter(new_weight.to(weight.device, weight.dtype))
|
||||
|
||||
def lora_linear_forward(
|
||||
self,
|
||||
weight,
|
||||
lora_A,
|
||||
lora_B,
|
||||
alpha,
|
||||
undo_merge=False,
|
||||
*args
|
||||
):
|
||||
l_alpha = self.get_lora_alpha(alpha)
|
||||
lora_weight = (lora_B @ lora_A)
|
||||
|
||||
return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge)
|
||||
|
||||
def lora_conv_forward(
|
||||
self,
|
||||
weight,
|
||||
lora_A,
|
||||
lora_B,
|
||||
alpha,
|
||||
undo_merge=False,
|
||||
is_temporal=False,
|
||||
*args
|
||||
):
|
||||
l_alpha = self.get_lora_alpha(alpha)
|
||||
view_shape = weight.shape
|
||||
|
||||
if is_temporal:
|
||||
i, o, k = weight.shape[:3]
|
||||
view_shape = (i, o, k, k, 1)
|
||||
|
||||
lora_weight = (lora_B @ lora_A).view(view_shape)
|
||||
|
||||
if is_temporal:
|
||||
lora_weight = torch.mean(lora_weight, dim=-2, keepdim=True)
|
||||
|
||||
return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge)
|
||||
|
||||
def lora_emb_forward(self, lora_A, lora_B, alpha, undo_merge=False, *args):
|
||||
l_alpha = self.get_lora_alpha(alpha)
|
||||
|
||||
return (lora_B @ lora_A).transpose(0, 1) * l_alpha
|
||||
|
||||
def is_lora_loaded(self, sd_model):
|
||||
return hasattr(sd_model, self.lora_loaded)
|
||||
|
||||
def get_loras_to_process(self, lora_files):
|
||||
lora_files_to_load = []
|
||||
|
||||
for file_name in lora_files:
|
||||
if len(self.lora_files) > 0:
|
||||
for f in self.lora_files:
|
||||
if file_name == f['lora_name']:
|
||||
lora_files_to_load.append(f['path'])
|
||||
|
||||
return lora_files_to_load
|
||||
|
||||
def handle_lora_load(self, sd_model, lora_files_list, set_lora_loaded=False):
|
||||
if not hasattr(sd_model, self.lora_loaded) and set_lora_loaded:
|
||||
setattr(sd_model, self.lora_loaded, True)
|
||||
|
||||
if self.is_lora_loaded(sd_model) and not set_lora_loaded:
|
||||
self.process_lora(p, lora_files_list, undo_merge=True)
|
||||
delattr(sd_model, self.lora_loaded)
|
||||
|
||||
def handle_alpha_change(self, lora_alpha, model):
|
||||
return (lora_alpha != self.previous_lora_alpha) \
|
||||
and self.is_lora_loaded(model)
|
||||
|
||||
def handle_options_change(self, options, model):
|
||||
return (options != self.previous_advanced_options) \
|
||||
and self.is_lora_loaded(model)
|
||||
|
||||
def handle_lora_start(self, lora_files, model):
|
||||
if len(lora_files) == 0 and self.is_lora_loaded(model):
|
||||
self.handle_lora_load(model, lora_files, set_lora_loaded=False)
|
||||
|
||||
self.log(f"Unloaded previously loaded LoRA files")
|
||||
return
|
||||
|
||||
def can_use_lora(self, model):
|
||||
return not self.is_lora_loaded(model)
|
||||
|
||||
def load_loras_from_list(self, lora_files):
|
||||
lora_files_list = []
|
||||
|
||||
for lora_file in lora_files:
|
||||
LORA_FILE = lora_file.split('/')[-1]
|
||||
LORA_DIR = cmd_opts.lora_dir
|
||||
LORA_PATH = f"{LORA_DIR}/{LORA_FILE}"
|
||||
|
||||
lora_model_text_path = f"{LORA_DIR}/text_{LORA_FILE}"
|
||||
lora_text_exists = os.path.exists(lora_model_text_path)
|
||||
|
||||
is_safetensors = LORA_PATH.endswith('.safetensors')
|
||||
load_method = load_file if is_safetensors else torch.load
|
||||
|
||||
lora_model = load_method(LORA_PATH)
|
||||
|
||||
lora_files_list.append(lora_model)
|
||||
|
||||
return lora_files_list
|
||||
|
||||
def handle_after_lora_load(
|
||||
self,
|
||||
model,
|
||||
lora_files,
|
||||
lora_file_names,
|
||||
advanced_options,
|
||||
alpha_changed,
|
||||
lora_alpha
|
||||
):
|
||||
lora_summary = []
|
||||
self.handle_lora_load(model, lora_files, set_lora_loaded=True)
|
||||
self.previous_lora_file_names = lora_file_names
|
||||
self.previous_advanced_options = advanced_options
|
||||
self.previous_lora_alpha = lora_alpha
|
||||
|
||||
for lora_file_name in lora_file_names:
|
||||
if self.is_lora_loaded(model):
|
||||
lora_summary.append(f"{lora_file_name.split('.')[0]}")
|
||||
|
||||
if len(lora_summary) > 0:
|
||||
self.log("Using LoRA(s):", *lora_summary)
|
||||
|
||||
if alpha_changed:
|
||||
self.log("Alpha changed successfully.")
|
||||
|
||||
def undo_merge_preprocess(self):
|
||||
previous_lora_files_list = self.get_loras_to_process(self.previous_lora_file_names)
|
||||
previous_lora_files = self.load_loras_from_list(previous_lora_files_list)
|
||||
|
||||
return previous_lora_files, self.previous_lora_alpha
|
||||
|
||||
@torch.autocast('cuda')
|
||||
def process_lora(
|
||||
self,
|
||||
p,
|
||||
lora_files_list,
|
||||
use_bias,
|
||||
use_time,
|
||||
use_conv,
|
||||
use_emb,
|
||||
use_linear,
|
||||
lora_alpha,
|
||||
undo_merge=False
|
||||
):
|
||||
for n, m in p.sd_model.named_modules():
|
||||
for lora_model in lora_files_list:
|
||||
for k, v in lora_model.items():
|
||||
|
||||
# If there is bias in the LoRA, add it here.
|
||||
if self.is_bias_match(k, n) and use_bias:
|
||||
if m.bias is None:
|
||||
m.bias = torch.nn.Parameter(v.to(self.device, dtype=self.dtype))
|
||||
else:
|
||||
m.bias.weight = v.to(self.device, dtype=self.dtype)
|
||||
|
||||
if self.is_lora_match(k, n):
|
||||
lora_A = lora_model[f"{n}.lora_A"].to(self.device, dtype=self.dtype)
|
||||
lora_B = lora_model[f"{n}.lora_B"].to(self.device, dtype=self.dtype)
|
||||
|
||||
forward_args = [m.weight, lora_A, lora_B, lora_alpha]
|
||||
|
||||
if isinstance(m, torch.nn.Linear) and use_linear:
|
||||
if 'proj' in n:
|
||||
forward_args[1], forward_args[2] = map(lambda l: l.squeeze(-1), (lora_A, lora_B))
|
||||
|
||||
m.weight = self.lora_linear_forward(*forward_args, undo_merge=undo_merge)
|
||||
|
||||
if isinstance(m, torch.nn.Conv2d) and use_conv:
|
||||
m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=False)
|
||||
|
||||
if isinstance(m, torch.nn.Conv3d) and use_conv and use_time:
|
||||
m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=True)
|
||||
|
||||
if isinstance(m, torch.nn.Embedding) and use_emb:
|
||||
embedding_weight = self.lora_emb_forward(lora_A, lora_B, lora_alpha, undo_merge=undo_merge)
|
||||
new_embedding_weight = torch.nn.Embedding.from_pretrained(embedding_weight)
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import gradio as gr
|
||||
|
||||
class Text2VideoExtension(object):
|
||||
"""
|
||||
A simple base class that sets a definitive way to process extensions
|
||||
"""
|
||||
|
||||
def __init__(self, extension_name: str = '', extension_title: str = ''):
|
||||
|
||||
self.extension_name = extension_name
|
||||
self.extension_title = extension_title
|
||||
self.return_args_delimiter = f"extension_{extension_name}"
|
||||
|
||||
def return_ui_inputs(self, return_args: list = [] ):
|
||||
"""
|
||||
All extensions should use this method to return Gradio inputs.
|
||||
This allows for tracking the inputs using a delimiter.
|
||||
Arguments are automatically processed and returned.
|
||||
|
||||
Output: <my_extension_name> + [arg1, arg2, arg3] + <my_extension_name>
|
||||
"""
|
||||
|
||||
delimiter = gr.State(self.return_args_delimiter)
|
||||
return [delimiter] + return_args + [delimiter]
|
||||
|
||||
def process_extension_args(self, all_args: list = []):
|
||||
"""
|
||||
Processes all extension arguments and appends them into a list.
|
||||
The filtered arguments are piped into the extension's process method.
|
||||
"""
|
||||
|
||||
can_append = False
|
||||
extension_args = []
|
||||
|
||||
for value in all_args:
|
||||
if value == self.return_args_delimiter and not can_append:
|
||||
can_append = True
|
||||
continue
|
||||
|
||||
if can_append:
|
||||
if value == self.return_args_delimiter:
|
||||
break
|
||||
else:
|
||||
extension_args.append(value)
|
||||
|
||||
return extension_args
|
||||
|
||||
def log(self, message: str = '', *args):
|
||||
"""
|
||||
Choose to print a log specific to the extension.
|
||||
"""
|
||||
OKGREEN = '\033[92m'
|
||||
ENDC = '\033[0m'
|
||||
|
||||
title = self.extension_title
|
||||
message = f"Extension {title}: {message} " + ', '.join(args)
|
||||
print(OKGREEN + message + ENDC)
|
||||
|
|
@ -27,7 +27,7 @@ def run(*args):
|
|||
try:
|
||||
print(f'text2video — The model selected is: {args_dict["model"]} ({args_dict["model_type"]}-like)')
|
||||
if model_type == 'ModelScope':
|
||||
vids_pack = process_modelscope(args_dict)
|
||||
vids_pack = process_modelscope(args_dict, args)
|
||||
elif model_type == 'VideoCrafter (WIP)':
|
||||
vids_pack = process_videocrafter(args_dict)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ from t2v_helpers.render import run
|
|||
import t2v_helpers.args as args
|
||||
from t2v_helpers.args import setup_text2video_settings_dictionary
|
||||
from webui import wrap_gradio_gpu_call
|
||||
from stable_lora.scripts.lora_webui import StableLoraScriptInstance
|
||||
StableLoraScript = StableLoraScriptInstance
|
||||
|
||||
def process(*args):
|
||||
# weird PATH stuff
|
||||
|
|
@ -46,6 +48,7 @@ def on_ui_tabs():
|
|||
with gr.Row(elem_id='t2v-core').style(equal_height=False, variant='compact'):
|
||||
with gr.Column(scale=1, variant='panel'):
|
||||
components = setup_text2video_settings_dictionary()
|
||||
stable_lora_ui = StableLoraScript.ui()
|
||||
with gr.Column(scale=1, variant='compact'):
|
||||
with gr.Row(elem_id=f"text2vid_generate_box", variant='compact', elem_classes="generate-box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"text2vid_interrupt", elem_classes="generate-box-interrupt")
|
||||
|
|
@ -91,7 +94,7 @@ def on_ui_tabs():
|
|||
# , extra_outputs=[None, '', '']),
|
||||
fn=wrap_gradio_gpu_call(process),
|
||||
_js="submit_txt2vid",
|
||||
inputs=[dummy_component1, dummy_component2] + [components[name] for name in args.get_component_names()],
|
||||
inputs=[dummy_component1, dummy_component2] + [components[name] for name in args.get_component_names()] + stable_lora_ui,
|
||||
outputs=[
|
||||
dummy_component1, dummy_component2,
|
||||
],
|
||||
|
|
|
|||
Loading…
Reference in New Issue