Add Stable LoRA support

pull/201/head
ExponentialML 2023-07-11 17:14:13 -07:00
parent 3f4a109a69
commit 2dad959835
6 changed files with 529 additions and 3 deletions

View File

@ -22,6 +22,7 @@ from t2v_helpers.args import get_outdir, process_args
import t2v_helpers.args as t2v_helpers_args
from modules import shared, sd_hijack, lowvram
from modules.shared import opts, devices, state
from stable_lora.scripts.lora_webui import gr_inputs_list, StableLoraScriptInstance
import os
pipe = None
@ -29,7 +30,7 @@ pipe = None
def setup_pipeline(model_name):
return TextToVideoSynthesis(get_model_location(model_name))
def process_modelscope(args_dict):
def process_modelscope(args_dict, extra_args=None):
args, video_args = process_args(args_dict)
global pipe
@ -63,6 +64,11 @@ def process_modelscope(args_dict):
if pipe is None or pipe is not None and args.model is not None and get_model_location(args.model) != pipe.model_dir:
pipe = setup_pipeline(args.model)
#TODO Wrap this in a list so that we can process this for future extensions.
stable_lora_processor = StableLoraScriptInstance
stable_lora_args = stable_lora_processor.process_extension_args(all_args=extra_args)
stable_lora_processor.process(pipe, *stable_lora_args)
pipe.keep_in_vram = opts.data.get("modelscope_deforum_keep_model_in_vram") if opts.data is not None and opts.data.get("modelscope_deforum_keep_model_in_vram") is not None else 'None'
device = devices.get_optimal_device()

View File

@ -0,0 +1,220 @@
import loralib as loralb
import torch
import glob
from safetensors.torch import load_file
from types import SimpleNamespace
from safetensors import safe_open
from einops import rearrange
import gradio as gr
import os
import json
from modules import images, script_callbacks
from modules.shared import opts, cmd_opts, state, cmd_opts, sd_model
from modules.sd_models import read_state_dict
from stable_lora.stable_utils.lora_processor import StableLoraProcessor
from t2v_helpers.extensions_utils import Text2VideoExtension
# Human-readable extension title and its snake_case identifier.
EXTENSION_TITLE = "Stable LoRA"
EXTENSION_NAME = EXTENSION_TITLE.replace(' ', '_').lower()

# Element ids for every Gradio input this extension creates, in the order
# they are wired into the UI and consumed by the process() method.
gr_inputs_list = [
    "lora_file_selection",
    "lora_alpha",
    "refresh_button",
    "use_bias",
    "use_linear",
    "use_conv",
    "use_emb",
    "use_time",
    "use_multiplier"
]

# Identity mapping of each id to itself, exposed as attribute-style access
# (GradioInputsIds.lora_alpha == "lora_alpha") for typo-safe references.
gr_inputs_dict = dict(zip(gr_inputs_list, gr_inputs_list))
GradioInputsIds = SimpleNamespace(**gr_inputs_dict)
class StableLoraScript(Text2VideoExtension, StableLoraProcessor):
    """Gradio UI and processing entry point for applying Stable LoRA files
    to the text-to-video pipeline.

    Combines the generic extension argument plumbing (Text2VideoExtension)
    with the LoRA weight-merge logic (StableLoraProcessor).
    """
    def __init__(self):
        StableLoraProcessor.__init__(self)
        Text2VideoExtension.__init__(self, EXTENSION_NAME, EXTENSION_TITLE)
        # Target device/dtype used when merging LoRA weights into the model
        # (read by StableLoraProcessor.process_lora).
        self.device = 'cuda'
        self.dtype = torch.float16

    def title(self):
        # Display title used by the host UI.
        return EXTENSION_TITLE

    def refresh_models(self, *args):
        """Rescan the LoRA directory and refresh the dropdown choices.

        *args is ignored; it only absorbs whatever Gradio passes along.
        """
        paths_with_metadata, lora_names = self.get_lora_files()
        self.lora_files = paths_with_metadata.copy()
        return gr.Dropdown.update(value=[], choices=lora_names)

    def ui(self):
        """Build the Gradio accordion for this extension.

        Returns the delimiter-wrapped list of input components (see
        Text2VideoExtension.return_ui_inputs); the order of return_args
        matches the parameter order of process().
        """
        paths_with_metadata, lora_names = self.get_lora_files()
        # Cache discovered LoRA metadata so process() can resolve names->paths.
        self.lora_files = paths_with_metadata.copy()
        REPOSITORY_LINK = "https://github.com/ExponentialML/Text-To-Video-Finetuning"
        with gr.Accordion(label=EXTENSION_TITLE, open=False) as stable_lora_section:
            with gr.Blocks(analytics_enabled=False):
                with gr.Row():
                    with gr.Column():
                        gr.HTML("<h2>Load a Trained LoRA File.</h2>")
                        gr.HTML(
                            """
                            <h3 style='color: crimson; font-weight: bold;'>
                            Only Stable LoRA files are supported.
                            </h3>
                            """
                        )
                        gr.HTML(f"""
                            <a href='{REPOSITORY_LINK}'>
                            To train a Stable LoRA file, use the finetune repository by clicking here.
                            </a>"""
                        )
                        # Multi-select: several LoRAs can be merged at once.
                        lora_files_selection = gr.Dropdown(
                            label="Available Models",
                            elem_id=GradioInputsIds.lora_file_selection,
                            choices=lora_names,
                            value=[],
                            multiselect=True,
                        )
                        lora_alpha = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=1,
                            step=0.05,
                            elem_id=GradioInputsIds.lora_alpha,
                            label="LoRA Weight"
                        )
                        refresh_button = gr.Button(
                            value="Refresh Models",
                            elem_id=GradioInputsIds.refresh_button
                        )
                        refresh_button.click(
                            self.refresh_models,
                            lora_files_selection,
                            lora_files_selection
                        )
                        # Hidden for now (visible=False); defaults all enabled.
                        with gr.Accordion(label="Advanced Settings", open=False, visible=False):
                            with gr.Column():
                                use_bias = gr.Checkbox(label="Enable Bias", elem_id=GradioInputsIds.use_bias, value=lambda: True)
                                use_linear = gr.Checkbox(label="Enable Linears", elem_id=GradioInputsIds.use_linear, value=lambda: True)
                                use_conv = gr.Checkbox(label="Enable Convolutions", elem_id=GradioInputsIds.use_conv, value=lambda: True)
                                use_emb = gr.Checkbox(label="Enable Embeddings", elem_id=GradioInputsIds.use_emb, value=lambda: True)
                                use_time = gr.Checkbox(label="Enable Time", elem_id=GradioInputsIds.use_time, value=lambda: True)
                            with gr.Column():
                                use_multiplier = gr.Number(
                                    label="Alpha Multiplier",
                                    elem_id=GradioInputsIds.use_multiplier,
                                    value=1,
                                )
        # Order here must match process()'s parameter order after `p`.
        return self.return_ui_inputs(
            return_args=[
                lora_files_selection,
                lora_alpha,
                use_bias,
                use_linear,
                use_conv,
                use_emb,
                use_multiplier,
                use_time
            ]
        )

    @torch.no_grad()
    def process(
        self,
        p,
        lora_files_selection,
        lora_alpha,
        use_bias,
        use_linear,
        use_conv,
        use_emb,
        use_multiplier,
        use_time
    ):
        """Apply (or re-apply / unload) the selected LoRAs to p.sd_model.

        p is the pipeline object; the remaining parameters mirror the UI
        inputs returned by ui(). Mutates p.sd_model weights in place.
        """
        # Get the list of LoRA files based off of filepath.
        lora_file_names = [x for x in lora_files_selection if x != "None"]
        # Rebuild the metadata cache if it is empty (e.g. first run).
        if len(self.lora_files) <= 0:
            paths_with_metadata, lora_names = self.get_lora_files()
            self.lora_files = paths_with_metadata.copy()
        lora_files = self.get_loras_to_process(lora_file_names)
        # Load multiple LoRAs
        lora_files_list = []
        # Load our advanced options in a list
        advanced_options = [
            use_bias,
            use_linear,
            use_conv,
            use_emb,
            use_multiplier,
            use_time
        ]
        # Save the previous alpha value so we can re-run the LoRA with new values.
        alpha_changed = self.handle_alpha_change(lora_alpha, p.sd_model)
        # If an advanced option changes, re-run with new options
        options_changed = self.handle_options_change(advanced_options, p.sd_model)
        # Check if we changed our LoRA models we are loading
        lora_changed = self.previous_lora_file_names != lora_file_names
        first_lora_init = not self.is_lora_loaded(p.sd_model)
        # If the LoRA is still loaded, unload it.
        self.handle_lora_start(lora_files, p.sd_model)
        # Freeze the model: merging is done on detached weight copies.
        p.sd_model.eval()
        for param in p.sd_model.parameters():
            if param.requires_grad:
                param.requires_grad_(False)
        can_use_lora = self.can_use_lora(p.sd_model)
        lora_params_changed = any([alpha_changed, lora_changed, options_changed])
        # Process LoRA
        if can_use_lora or lora_params_changed:
            if len(lora_files) == 0: return
            # Scale alpha by the multiplier and split it evenly across files.
            lora_alpha = (lora_alpha * use_multiplier) / len(lora_files)
            lora_files_list = self.load_loras_from_list(lora_files)
            args = [p, lora_files_list, use_bias, use_time, use_conv, use_emb, use_linear, lora_alpha]
            if lora_params_changed and not first_lora_init:
                self.log("Resetting weights to reflect changed options.")
                # Undo the previous merge (previous files + previous alpha)
                # before applying the new settings.
                undo_args = args.copy()
                undo_args[1], undo_args[-1] = self.undo_merge_preprocess()
                self.process_lora(*undo_args, undo_merge=True)
            self.process_lora(*args, undo_merge=False)
            # Record what was applied so future changes can be detected/undone.
            self.handle_after_lora_load(
                p.sd_model,
                lora_files,
                lora_file_names,
                advanced_options,
                alpha_changed,
                lora_alpha
            )
        if len(lora_files) > 0 and not all([can_use_lora, lora_params_changed]):
            self.log(f"Using loaded LoRAs: {', '.join(lora_file_names)}")

# Module-level singleton consumed by the webui and the modelscope processor.
StableLoraScriptInstance = StableLoraScript()

View File

@ -0,0 +1,240 @@
import os
import glob
import torch
from safetensors.torch import load_file
from safetensors import safe_open
from modules.shared import opts, cmd_opts, state, cmd_opts, sd_model
from modules import sd_hijack
from modules.sd_models import read_state_dict
class StableLoraProcessor:
    """Discovers Stable LoRA .safetensors files and merges / un-merges their
    low-rank weight deltas directly into a loaded model's modules.

    NOTE: `self.device`, `self.dtype` and `self.log` are not defined here;
    they are provided by the subclass that mixes this in (StableLoraScript).
    """
    def __init__(self):
        # Attribute name used as a flag on the model object to mark
        # "a LoRA is currently merged in".
        self.lora_loaded = 'lora_loaded'
        # Previously-applied settings, kept so changes can be detected and
        # the old merge undone before re-applying with new values.
        self.previous_lora_alpha = 1
        self.current_sd_checkpoint = ""
        self.previous_lora_file_names = []
        self.previous_advanced_options = []
        # Cached metadata dicts for discovered LoRA files.
        self.lora_files = []

    def get_lora_files(self):
        """Recursively scan cmd_opts.lora_dir for Stable LoRA safetensors.

        Returns (paths_with_metadata, lora_names): metadata dicts (with
        'path' and 'lora_name' keys added) and the matching display names.
        """
        paths_with_metadata = []
        paths = glob.glob(os.path.join(cmd_opts.lora_dir, '**/*.safetensors'), recursive=True)
        lora_names = []
        for lora_path in paths:
            with safe_open(lora_path, 'pt') as lora_file:
                metadata = lora_file.metadata()
                # Only files tagged as Stable LoRA text-to-video are accepted.
                if metadata is not None and 'stable_lora_text_to_video' in metadata.keys():
                    metadata['path'] = lora_path
                    metadata['lora_name'] = os.path.splitext(os.path.basename(lora_path))[0]
                    paths_with_metadata.append(metadata)
        if len(paths_with_metadata) > 0:
            lora_names = [x['lora_name'] for x in paths_with_metadata]
        return paths_with_metadata, lora_names

    def key_name_match(self, value, key, name):
        # True when `key` looks like "<name>.<value>[...]" for module `name`.
        return value in key and name == key.split(f".{value}")[0]

    def is_lora_match(self, key, name):
        return self.key_name_match('lora_A', key, name)

    def is_bias_match(self, key, name):
        return self.key_name_match("bias", key, name)

    # Smallest dim of a LoRA factor is its rank.
    def lora_rank(self, weight): return min(weight.shape)

    def get_lora_alpha(self, alpha):
        # Currently a pass-through; kept as a hook for alpha transforms.
        return alpha

    def process_lora_weight(self, weight, lora_weight, alpha, undo_merge=False):
        """Return a new Parameter with the scaled LoRA delta merged in
        (added), or subtracted back out when undo_merge is True."""
        new_weight = weight.detach().clone()
        if not undo_merge:
            new_weight += lora_weight.to(weight.device, weight.dtype) * alpha
        else:
            new_weight -= lora_weight.to(weight.device, weight.dtype) * alpha
        return torch.nn.Parameter(new_weight.to(weight.device, weight.dtype))

    def lora_linear_forward(
        self,
        weight,
        lora_A,
        lora_B,
        alpha,
        undo_merge=False,
        *args
    ):
        # Merge the rank-decomposed delta (B @ A) into a Linear weight.
        l_alpha = self.get_lora_alpha(alpha)
        lora_weight = (lora_B @ lora_A)
        return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge)

    def lora_conv_forward(
        self,
        weight,
        lora_A,
        lora_B,
        alpha,
        undo_merge=False,
        is_temporal=False,
        *args
    ):
        l_alpha = self.get_lora_alpha(alpha)
        view_shape = weight.shape
        if is_temporal:
            # assumes the 3D conv delta can be viewed as (i, o, k, k, 1) and
            # then averaged over the second-to-last axis — TODO confirm
            i, o, k = weight.shape[:3]
            view_shape = (i, o, k, k, 1)
        lora_weight = (lora_B @ lora_A).view(view_shape)
        if is_temporal:
            lora_weight = torch.mean(lora_weight, dim=-2, keepdim=True)
        return self.process_lora_weight(weight, lora_weight, l_alpha, undo_merge=undo_merge)

    def lora_emb_forward(self, lora_A, lora_B, alpha, undo_merge=False, *args):
        # Returns the scaled embedding delta (transposed); the caller is
        # responsible for merging it into the module. Note: undo_merge is
        # accepted but not used here.
        l_alpha = self.get_lora_alpha(alpha)
        return (lora_B @ lora_A).transpose(0, 1) * l_alpha

    def is_lora_loaded(self, sd_model):
        # The flag attribute only exists while a LoRA is merged in.
        return hasattr(sd_model, self.lora_loaded)

    def get_loras_to_process(self, lora_files):
        """Map selected display names to file paths via the metadata cache."""
        lora_files_to_load = []
        for file_name in lora_files:
            if len(self.lora_files) > 0:
                for f in self.lora_files:
                    if file_name == f['lora_name']:
                        lora_files_to_load.append(f['path'])
        return lora_files_to_load

    def handle_lora_load(self, sd_model, lora_files_list, set_lora_loaded=False):
        """Set or clear the 'LoRA merged' flag on the model; when clearing,
        also attempts to un-merge the previously applied LoRA."""
        if not hasattr(sd_model, self.lora_loaded) and set_lora_loaded:
            setattr(sd_model, self.lora_loaded, True)
        if self.is_lora_loaded(sd_model) and not set_lora_loaded:
            # NOTE(review): `p` is undefined in this scope, and process_lora
            # takes more positional arguments than are passed here — this
            # unload path will raise NameError when reached; verify the
            # intended call.
            self.process_lora(p, lora_files_list, undo_merge=True)
            delattr(sd_model, self.lora_loaded)

    def handle_alpha_change(self, lora_alpha, model):
        # True when alpha changed while a LoRA is already merged.
        return (lora_alpha != self.previous_lora_alpha) \
            and self.is_lora_loaded(model)

    def handle_options_change(self, options, model):
        # True when any advanced option changed while a LoRA is merged.
        return (options != self.previous_advanced_options) \
            and self.is_lora_loaded(model)

    def handle_lora_start(self, lora_files, model):
        # Nothing selected but a LoRA is still merged: unload it.
        if len(lora_files) == 0 and self.is_lora_loaded(model):
            self.handle_lora_load(model, lora_files, set_lora_loaded=False)
            self.log(f"Unloaded previously loaded LoRA files")
            return

    def can_use_lora(self, model):
        # A fresh merge is only allowed when no LoRA is currently merged.
        return not self.is_lora_loaded(model)

    def load_loras_from_list(self, lora_files):
        """Load each LoRA state dict from disk (safetensors or torch.load)."""
        lora_files_list = []
        for lora_file in lora_files:
            LORA_FILE = lora_file.split('/')[-1]
            LORA_DIR = cmd_opts.lora_dir
            LORA_PATH = f"{LORA_DIR}/{LORA_FILE}"
            # presumably a companion text-encoder LoRA file; the existence
            # flag is computed but currently unused — verify intent
            lora_model_text_path = f"{LORA_DIR}/text_{LORA_FILE}"
            lora_text_exists = os.path.exists(lora_model_text_path)
            is_safetensors = LORA_PATH.endswith('.safetensors')
            load_method = load_file if is_safetensors else torch.load
            lora_model = load_method(LORA_PATH)
            lora_files_list.append(lora_model)
        return lora_files_list

    def handle_after_lora_load(
        self,
        model,
        lora_files,
        lora_file_names,
        advanced_options,
        alpha_changed,
        lora_alpha
    ):
        """Record the just-applied settings and log a summary."""
        lora_summary = []
        self.handle_lora_load(model, lora_files, set_lora_loaded=True)
        self.previous_lora_file_names = lora_file_names
        self.previous_advanced_options = advanced_options
        self.previous_lora_alpha = lora_alpha
        for lora_file_name in lora_file_names:
            if self.is_lora_loaded(model):
                lora_summary.append(f"{lora_file_name.split('.')[0]}")
        if len(lora_summary) > 0:
            self.log("Using LoRA(s):", *lora_summary)
        if alpha_changed:
            self.log("Alpha changed successfully.")

    def undo_merge_preprocess(self):
        # Reload the previously merged LoRA files (and previous alpha) so
        # they can be subtracted before a new merge.
        previous_lora_files_list = self.get_loras_to_process(self.previous_lora_file_names)
        previous_lora_files = self.load_loras_from_list(previous_lora_files_list)
        return previous_lora_files, self.previous_lora_alpha

    @torch.autocast('cuda')
    def process_lora(
        self,
        p,
        lora_files_list,
        use_bias,
        use_time,
        use_conv,
        use_emb,
        use_linear,
        lora_alpha,
        undo_merge=False
    ):
        """Walk every module of p.sd_model and merge (or un-merge, when
        undo_merge=True) the matching LoRA weights from each state dict."""
        for n, m in p.sd_model.named_modules():
            for lora_model in lora_files_list:
                for k, v in lora_model.items():
                    # If there is bias in the LoRA, add it here.
                    if self.is_bias_match(k, n) and use_bias:
                        if m.bias is None:
                            m.bias = torch.nn.Parameter(v.to(self.device, dtype=self.dtype))
                        else:
                            # NOTE(review): nn.Parameter has no `.weight`
                            # attribute — this likely should assign to
                            # m.bias.data instead; confirm.
                            m.bias.weight = v.to(self.device, dtype=self.dtype)
                    if self.is_lora_match(k, n):
                        lora_A = lora_model[f"{n}.lora_A"].to(self.device, dtype=self.dtype)
                        lora_B = lora_model[f"{n}.lora_B"].to(self.device, dtype=self.dtype)
                        forward_args = [m.weight, lora_A, lora_B, lora_alpha]
                        if isinstance(m, torch.nn.Linear) and use_linear:
                            # squeeze a trailing singleton dim for 'proj'
                            # layers — assumes their factors are stored 3-D;
                            # TODO confirm against the training repo.
                            if 'proj' in n:
                                forward_args[1], forward_args[2] = map(lambda l: l.squeeze(-1), (lora_A, lora_B))
                            m.weight = self.lora_linear_forward(*forward_args, undo_merge=undo_merge)
                        if isinstance(m, torch.nn.Conv2d) and use_conv:
                            m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=False)
                        if isinstance(m, torch.nn.Conv3d) and use_conv and use_time:
                            m.weight = self.lora_conv_forward(*forward_args, undo_merge=undo_merge, is_temporal=True)
                        if isinstance(m, torch.nn.Embedding) and use_emb:
                            embedding_weight = self.lora_emb_forward(lora_A, lora_B, lora_alpha, undo_merge=undo_merge)
                            # NOTE(review): new_embedding_weight is created
                            # but never assigned back to the module — this
                            # branch looks incomplete/truncated; verify.
                            new_embedding_weight = torch.nn.Embedding.from_pretrained(embedding_weight)

View File

@ -0,0 +1,57 @@
import gradio as gr
class Text2VideoExtension(object):
    """
    A simple base class that sets a definitive way to process extensions.

    Subclasses wrap their Gradio inputs with `return_ui_inputs` so the host
    UI can later slice this extension's arguments back out of the combined
    argument list with `process_extension_args`.
    """
    def __init__(self, extension_name: str = '', extension_title: str = ''):
        # Machine-friendly identifier and human-readable display title.
        self.extension_name = extension_name
        self.extension_title = extension_title
        # Sentinel string used to bracket this extension's args inside the
        # combined UI argument list.
        self.return_args_delimiter = f"extension_{extension_name}"

    def return_ui_inputs(self, return_args=None):
        """
        All extensions should use this method to return Gradio inputs.
        This allows for tracking the inputs using a delimiter.
        Arguments are automatically processed and returned.

        Output: <my_extension_name> + [arg1, arg2, arg3] + <my_extension_name>
        """
        # Default is None rather than a mutable [] to avoid the shared
        # mutable-default-argument pitfall; a copy also protects the caller's
        # list from aliasing.
        return_args = [] if return_args is None else list(return_args)
        delimiter = gr.State(self.return_args_delimiter)
        return [delimiter] + return_args + [delimiter]

    def process_extension_args(self, all_args=None):
        """
        Processes all extension arguments and appends them into a list.
        The filtered arguments are piped into the extension's process method.

        Returns the values found between this extension's two delimiter
        markers; an empty list when the markers are absent.
        """
        can_append = False
        extension_args = []
        for value in (all_args or []):
            # First delimiter: start collecting (skip the marker itself).
            if value == self.return_args_delimiter and not can_append:
                can_append = True
                continue
            if can_append:
                # Second delimiter: this extension's span is done.
                if value == self.return_args_delimiter:
                    break
                extension_args.append(value)
        return extension_args

    def log(self, message: str = '', *args):
        """
        Print a log line prefixed with the extension title, in green.
        """
        OKGREEN = '\033[92m'
        ENDC = '\033[0m'
        title = self.extension_title
        # map(str, ...) so non-string extras don't crash the join.
        message = f"Extension {title}: {message} " + ', '.join(map(str, args))
        print(OKGREEN + message + ENDC)

View File

@ -27,7 +27,7 @@ def run(*args):
try:
print(f'text2video — The model selected is: {args_dict["model"]} ({args_dict["model_type"]}-like)')
if model_type == 'ModelScope':
vids_pack = process_modelscope(args_dict)
vids_pack = process_modelscope(args_dict, args)
elif model_type == 'VideoCrafter (WIP)':
vids_pack = process_videocrafter(args_dict)
else:

View File

@ -25,6 +25,8 @@ from t2v_helpers.render import run
import t2v_helpers.args as args
from t2v_helpers.args import setup_text2video_settings_dictionary
from webui import wrap_gradio_gpu_call
from stable_lora.scripts.lora_webui import StableLoraScriptInstance
StableLoraScript = StableLoraScriptInstance
def process(*args):
# weird PATH stuff
@ -46,6 +48,7 @@ def on_ui_tabs():
with gr.Row(elem_id='t2v-core').style(equal_height=False, variant='compact'):
with gr.Column(scale=1, variant='panel'):
components = setup_text2video_settings_dictionary()
stable_lora_ui = StableLoraScript.ui()
with gr.Column(scale=1, variant='compact'):
with gr.Row(elem_id=f"text2vid_generate_box", variant='compact', elem_classes="generate-box"):
interrupt = gr.Button('Interrupt', elem_id=f"text2vid_interrupt", elem_classes="generate-box-interrupt")
@ -91,7 +94,7 @@ def on_ui_tabs():
# , extra_outputs=[None, '', '']),
fn=wrap_gradio_gpu_call(process),
_js="submit_txt2vid",
inputs=[dummy_component1, dummy_component2] + [components[name] for name in args.get_component_names()],
inputs=[dummy_component1, dummy_component2] + [components[name] for name in args.get_component_names()] + stable_lora_ui,
outputs=[
dummy_component1, dummy_component2,
],