mirror of https://github.com/bmaltais/kohya_ss
Add support to convert a model to an LCM model
parent
161b8b486f
commit
2e37d0de4c
|
|
@ -4,6 +4,7 @@ from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
|||
from library.verify_lora_gui import gradio_verify_lora_tab
|
||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||
from library.extract_lora_gui import gradio_extract_lora_tab
|
||||
from library.convert_lcm_gui import gradio_convert_lcm_tab
|
||||
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
||||
from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab
|
||||
from library.merge_lycoris_gui import gradio_merge_lycoris_tab
|
||||
|
|
@ -24,6 +25,7 @@ class LoRATools:
|
|||
'This section provide LoRA tools to help setup your dataset...'
|
||||
)
|
||||
gradio_extract_dylora_tab(headless=headless)
|
||||
gradio_convert_lcm_tab(headless=headless)
|
||||
gradio_extract_lora_tab(headless=headless)
|
||||
gradio_extract_lycoris_locon_tab(headless=headless)
|
||||
gradio_merge_lora_tab = GradioMergeLoRaTab()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
import subprocess
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
)
|
||||
from library.custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe"
|
||||
|
||||
|
||||
def convert_lcm(
    name,
    model_path,
    lora_scale,
    model_type
):
    """Run tools/lcm_convert.py to convert a checkpoint into an LCM model.

    Args:
        name: Path of the LCM .safetensors file to create.
        model_path: Path of the Stable Diffusion checkpoint to convert.
        lora_scale: Strength with which the LCM LoRA is fused in.
        model_type: "SD15", "SDXL" or "SSD-1B" (the misspelled "SD-1B"
            value historically emitted by the GUI dropdown is accepted too).
    """
    # Construct the command to run the external conversion script.
    run_cmd = f'{PYTHON} "{os.path.join("tools","lcm_convert.py")}"'
    run_cmd += f' --name "{name}"'
    run_cmd += f' --model "{model_path}"'
    run_cmd += f" --lora-scale {lora_scale}"

    # The branches are mutually exclusive, so use elif; also accept the
    # "SD-1B" label so the SSD-1B path is reachable from the GUI dropdown.
    if model_type == "SDXL":
        run_cmd += " --sdxl"
    elif model_type in ("SSD-1B", "SD-1B"):
        run_cmd += " --ssd-1b"

    log.info(run_cmd)

    # Run the command
    if os.name == "posix":
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)

    # Fixed log text: this utility converts a model, it does not extract one.
    log.info("Done converting...")
|
||||
|
||||
|
||||
def gradio_convert_lcm_tab(headless=False):
    """Build the "Convert to LCM" tab of the LoRA tools GUI.

    Args:
        headless: When True, hide the file-picker buttons, which need a
            desktop environment to open native dialogs.
    """
    with gr.Tab("Convert to LCM"):
        gr.Markdown("This utility convert a model to an LCM model.")
        # Hidden fields that carry file-type filters into the shared
        # file-picker helpers (get_file_path / get_saveasfilename_path).
        lora_ext = gr.Textbox(value="*.safetensors", visible=False)
        lora_ext_name = gr.Textbox(value="LCM model types", visible=False)
        model_ext = gr.Textbox(value="*.safetensors", visible=False)
        model_ext_name = gr.Textbox(value="Model types", visible=False)

        with gr.Row():
            model_path = gr.Textbox(
                label="Stable Diffusion model to convert to LCM",
                interactive=True,
            )
            button_model_path_file = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                visible=(not headless),
            )
            button_model_path_file.click(
                get_file_path,
                inputs=[model_path, model_ext, model_ext_name],
                outputs=model_path,
                show_progress=False,
            )

            name = gr.Textbox(
                label="Name of the new LCM model",
                placeholder="Path to the LCM file to create",
                interactive=True,
            )
            button_name = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                visible=(not headless),
            )
            button_name.click(
                get_saveasfilename_path,
                inputs=[name, lora_ext, lora_ext_name],
                outputs=name,
                show_progress=False,
            )

        with gr.Row():
            lora_scale = gr.Slider(
                label="Strength of the LCM",
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                interactive=True,
            )
            # with gr.Row():
            #     no_half = gr.Checkbox(label="Convert the new LCM model to FP32", value=False)
            # Bug fix: the dropdown previously offered the misspelled "SD-1B",
            # which never matched the "SSD-1B" check in convert_lcm(), so the
            # --ssd-1b flag of tools/lcm_convert.py was unreachable.
            model_type = gr.Dropdown(
                label="Model type", choices=["SD15", "SDXL", "SSD-1B"], value="SD15"
            )

        extract_button = gr.Button("Extract LCM")

        extract_button.click(
            convert_lcm,
            inputs=[
                name,
                model_path,
                lora_scale,
                model_type,
            ],
            show_progress=False,
        )
|
||||
|
|
@ -1284,7 +1284,7 @@ def lora_tab(
|
|||
train_on_input = gr.Checkbox(
|
||||
value=True,
|
||||
label="iA3 train on input",
|
||||
info="Set if we change the information going into the system (True) or the information coming out of it (False)."
|
||||
info="Set if we change the information going into the system (True) or the information coming out of it (False).",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ accelerate==0.23.0
|
|||
aiofiles==23.2.1
|
||||
altair==4.2.2
|
||||
dadaptation==3.1
|
||||
diffusers[torch]==0.21.4
|
||||
diffusers[torch]==0.24.0
|
||||
easygui==0.98.3
|
||||
einops==0.6.0
|
||||
fairscale==0.4.13
|
||||
ftfy==6.1.1
|
||||
gradio==3.36.1
|
||||
huggingface-hub==0.15.1
|
||||
huggingface-hub==0.19.4
|
||||
# for loading Diffusers' SDXL
|
||||
invisible-watermark==0.2.0
|
||||
lion-pytorch==0.0.6
|
||||
|
|
@ -21,6 +21,7 @@ lycoris_lora==2.0.0
|
|||
# for WD14 captioning (tensorflow)
|
||||
# tensorflow==2.14.0
|
||||
# for WD14 captioning (onnx)
|
||||
omegaconf==2.3.0
|
||||
onnx==1.14.1
|
||||
onnxruntime-gpu==1.16.0
|
||||
# onnxruntime==1.16.0
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ mkdir ".\logs\setup" > nul 2>&1
|
|||
call .\venv\Scripts\deactivate.bat
|
||||
|
||||
:: Calling external python program to check for local modules
|
||||
python .\setup\check_local_modules.py
|
||||
:: python .\setup\check_local_modules.py
|
||||
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
import argparse
|
||||
import torch
|
||||
from library.custom_logging import setup_logging
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, LCMScheduler
|
||||
from library.sdxl_model_util import convert_diffusers_unet_state_dict_to_sdxl, sdxl_original_unet, save_stable_diffusion_checkpoint, _load_state_dict_on_device as load_state_dict_on_device
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
# Initialize logging
|
||||
logger = setup_logging()
|
||||
|
||||
def parse_command_line_arguments():
    """Parse the lcm_convert command line and return the parsed namespace."""
    parser = argparse.ArgumentParser("lcm_convert")
    # Required input/output paths.
    required_string = {"required": True, "type": str}
    parser.add_argument("--name", help="Name of the new LCM model", **required_string)
    parser.add_argument("--model", help="A model to convert", **required_string)
    # Optional conversion knobs.
    parser.add_argument("--lora-scale", type=float, default=1.0, help="Strength of the LCM")
    for flag, description in (
        ("--sdxl", "Use SDXL models"),
        ("--ssd-1b", "Use SSD-1B models"),
    ):
        parser.add_argument(flag, action="store_true", help=description)
    return parser.parse_args()
|
||||
|
||||
def load_diffusion_pipeline(command_line_args):
    """Load the checkpoint at --model as a diffusers pipeline.

    SDXL and SSD-1B checkpoints both use the SDXL pipeline class; anything
    else is loaded as a plain Stable Diffusion pipeline.
    """
    wants_xl = command_line_args.sdxl or command_line_args.ssd_1b
    pipeline_class = StableDiffusionXLPipeline if wants_xl else StableDiffusionPipeline
    return pipeline_class.from_single_file(command_line_args.model)
|
||||
|
||||
def convert_and_save_diffusion_model(diffusion_pipeline, command_line_args):
    """Fuse the matching LCM LoRA into the pipeline and save the result as a
    single Stable Diffusion checkpoint at command_line_args.name.
    """
    # Replace the scheduler with the LCM scheduler, preserving its config.
    diffusion_pipeline.scheduler = LCMScheduler.from_config(diffusion_pipeline.scheduler.config)
    # Pick the Hub LCM-LoRA repo that matches the selected model family.
    lora_weight_file_path = "latent-consistency/lcm-lora-" + ("sdxl" if command_line_args.sdxl else "ssd-1b" if command_line_args.ssd_1b else "sdv1-5")
    diffusion_pipeline.load_lora_weights(lora_weight_file_path)
    # Bake the LoRA into the base weights at the requested strength.
    diffusion_pipeline.fuse_lora(lora_scale=command_line_args.lora_scale)

    diffusion_pipeline = diffusion_pipeline.to(dtype=torch.float16)
    logger.info("Saving file...")

    # NOTE(review): text_encoder_2 exists only on StableDiffusionXLPipeline;
    # on the non-SDXL path load_diffusion_pipeline returns a plain
    # StableDiffusionPipeline, so this attribute access (and the
    # SDXL-specific save below) looks like it would fail for SD15 — confirm.
    text_encoder_primary = diffusion_pipeline.text_encoder
    text_encoder_secondary = diffusion_pipeline.text_encoder_2
    variational_autoencoder = diffusion_pipeline.vae
    unet_network = diffusion_pipeline.unet

    # Drop the pipeline wrapper before re-materializing the U-Net to keep
    # peak memory down.
    del diffusion_pipeline

    # Convert the Diffusers U-Net state dict to the SDXL layout, then load it
    # into an empty-weight SdxlUNet2DConditionModel directly on the GPU.
    state_dict = convert_diffusers_unet_state_dict_to_sdxl(unet_network.state_dict())
    with init_empty_weights():
        unet_network = sdxl_original_unet.SdxlUNet2DConditionModel()

    load_state_dict_on_device(unet_network, state_dict, device="cuda", dtype=torch.float16)

    # The positional None arguments are parameters this conversion does not
    # supply — presumably epoch/step/ckpt-info metadata; verify against the
    # save_stable_diffusion_checkpoint signature in library.sdxl_model_util.
    save_stable_diffusion_checkpoint(
        command_line_args.name,
        text_encoder_primary,
        text_encoder_secondary,
        unet_network,
        None,
        None,
        None,
        variational_autoencoder,
        None,
        None,
        torch.float16,
    )

    logger.info("...done saving")
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, load the pipeline, convert and save.

    Any failure is caught and logged rather than propagated, so the script
    always exits normally.
    """
    args = parse_command_line_arguments()
    try:
        pipeline = load_diffusion_pipeline(args)
        convert_and_save_diffusion_model(pipeline, args)
    except Exception as error:
        logger.error(f"An error occurred: {error}")


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue