From 821cdda125a1a1303b0b313bd35f13b49bf8b6a6 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 3 Dec 2023 11:52:20 -0500 Subject: [PATCH] Update Lycoris --- .release | 2 +- README.md | 7 +- library/extract_lycoris_locon_gui.py | 11 +- library/merge_lycoris_gui.py | 7 +- requirements.txt | 2 +- tools/lycoris_locon_extract.py | 175 ++++++++++++++++++--------- tools/merge_lycoris.py | 124 ++++++++++++------- 7 files changed, 224 insertions(+), 104 deletions(-) diff --git a/.release b/.release index 8013ac7..6a112b8 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.2.1 +v22.2.2 diff --git a/README.md b/README.md index 828b633..fc335a2 100644 --- a/README.md +++ b/README.md @@ -651,7 +651,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b ## Change History -* 2023/11/?? (v22.2.1) +* 2023/12/03 (v22.2.2) +- Update Lycoris module to 2.0.0 (https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/README.md) +- Update Lycoris merge and extract tools +- Remove the annoying, unnecessary warning about local pip modules. + +* 2023/11/20 (v22.2.1) - Fix issue with `Debiased Estimation loss` not getting properly loaded from json file. Oups. 
* 2023/11/15 (v22.2.0) diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index d3c19da..65ffc0c 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -25,6 +25,7 @@ def extract_lycoris_locon( base_model, output_name, device, + is_sdxl, is_v2, mode, linear_dim, @@ -58,6 +59,8 @@ def extract_lycoris_locon( return run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' + if is_sdxl: + run_cmd += f' --is_sdxl' if is_v2: run_cmd += f' --is_v2' run_cmd += f' --device {device}' @@ -196,10 +199,13 @@ def gradio_extract_lycoris_locon_tab(headless=False): value='cuda', interactive=True, ) + + is_sdxl = gr.Checkbox(label='is SDXL', value=False, interactive=True) + is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) mode = gr.Dropdown( label='Mode', - choices=['fixed', 'threshold', 'ratio', 'quantile'], + choices=['fixed', 'full', 'quantile', 'ratio', 'threshold'], value='fixed', interactive=True, ) @@ -211,6 +217,7 @@ def gradio_extract_lycoris_locon_tab(headless=False): value=1, step=1, interactive=True, + info="network dim for linear layer in fixed mode", ) conv_dim = gr.Slider( minimum=1, @@ -219,6 +226,7 @@ def gradio_extract_lycoris_locon_tab(headless=False): value=1, step=1, interactive=True, + info="network dim for conv layer in fixed mode", ) with gr.Row(visible=False) as threshold: linear_threshold = gr.Slider( @@ -312,6 +320,7 @@ def gradio_extract_lycoris_locon_tab(headless=False): base_model, output_name, device, + is_sdxl, is_v2, mode, linear_dim, diff --git a/library/merge_lycoris_gui.py b/library/merge_lycoris_gui.py index 7d56f1e..bb084ff 100644 --- a/library/merge_lycoris_gui.py +++ b/library/merge_lycoris_gui.py @@ -26,6 +26,7 @@ def merge_lycoris( output_name, dtype, device, + is_sdxl, is_v2, ): log.info('Merge model...') @@ -37,6 +38,8 @@ def merge_lycoris( run_cmd += f' --weight {weight}' run_cmd += f' --device {device}' run_cmd += f' --dtype 
{dtype}' + if is_sdxl: + run_cmd += f' --is_sdxl' if is_v2: run_cmd += f' --is_v2' @@ -149,12 +152,13 @@ def gradio_merge_lycoris_tab(headless=False): label='Device', choices=[ 'cpu', - # 'cuda', + 'cuda', ], value='cpu', interactive=True, ) + is_sdxl = gr.Checkbox(label='is sdxl', value=False, interactive=True) is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) merge_button = gr.Button('Merge model') @@ -168,6 +172,7 @@ def gradio_merge_lycoris_tab(headless=False): output_name, dtype, device, + is_sdxl, is_v2, ], show_progress=False, diff --git a/requirements.txt b/requirements.txt index 216e961..ffc2e38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ huggingface-hub==0.15.1 # for loading Diffusers' SDXL invisible-watermark==0.2.0 lion-pytorch==0.0.6 -lycoris_lora==1.9.0 +lycoris_lora==2.0.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/tools/lycoris_locon_extract.py b/tools/lycoris_locon_extract.py index 671aa5d..28b25ee 100644 --- a/tools/lycoris_locon_extract.py +++ b/tools/lycoris_locon_extract.py @@ -1,4 +1,5 @@ import os, sys + sys.path.insert(0, os.getcwd()) import argparse @@ -6,87 +7,125 @@ import argparse def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "base_model", help="The model which use it to train the dreambooth model", - default='', type=str + "base_model", + help="The model which use it to train the dreambooth model", + default="", + type=str, ) parser.add_argument( - "db_model", help="the dreambooth model you want to extract the locon", - default='', type=str + "db_model", + help="the dreambooth model you want to extract the locon", + default="", + type=str, ) parser.add_argument( - "output_name", help="the output model", - default='./out.pt', type=str + "output_name", help="the output model", default="./out.pt", type=str ) parser.add_argument( - "--is_v2", help="Your base/db model is sd v2 or not", - default=False, action="store_true" + "--is_v2", + 
help="Your base/db model is sd v2 or not", + default=False, + action="store_true", ) parser.add_argument( - "--device", help="Which device you want to use to extract the locon", - default='cpu', type=str + "--is_sdxl", + help="Your base/db model is sdxl or not", + default=False, + action="store_true", ) parser.add_argument( - "--mode", + "--device", + help="Which device you want to use to extract the locon", + default="cpu", + type=str, + ) + parser.add_argument( + "--mode", help=( - 'extraction mode, can be "fixed", "threshold", "ratio", "quantile". ' + 'extraction mode, can be "full", "fixed", "threshold", "ratio", "quantile". ' 'If not "fixed", network_dim and conv_dim will be ignored' ), - default='fixed', type=str + default="fixed", + type=str, ) parser.add_argument( - "--safetensors", help='use safetensors to save locon model', - default=False, action="store_true" + "--safetensors", + help="use safetensors to save locon model", + default=False, + action="store_true", ) parser.add_argument( - "--linear_dim", help="network dim for linear layer in fixed mode", - default=1, type=int + "--linear_dim", + help="network dim for linear layer in fixed mode", + default=1, + type=int, ) parser.add_argument( - "--conv_dim", help="network dim for conv layer in fixed mode", - default=1, type=int + "--conv_dim", + help="network dim for conv layer in fixed mode", + default=1, + type=int, ) parser.add_argument( - "--linear_threshold", help="singular value threshold for linear layer in threshold mode", - default=0., type=float + "--linear_threshold", + help="singular value threshold for linear layer in threshold mode", + default=0.0, + type=float, ) parser.add_argument( - "--conv_threshold", help="singular value threshold for conv layer in threshold mode", - default=0., type=float + "--conv_threshold", + help="singular value threshold for conv layer in threshold mode", + default=0.0, + type=float, ) parser.add_argument( - "--linear_ratio", help="singular ratio for linear layer 
in ratio mode", - default=0., type=float + "--linear_ratio", + help="singular ratio for linear layer in ratio mode", + default=0.0, + type=float, ) parser.add_argument( - "--conv_ratio", help="singular ratio for conv layer in ratio mode", - default=0., type=float + "--conv_ratio", + help="singular ratio for conv layer in ratio mode", + default=0.0, + type=float, ) parser.add_argument( - "--linear_quantile", help="singular value quantile for linear layer quantile mode", - default=1., type=float + "--linear_quantile", + help="singular value quantile for linear layer quantile mode", + default=1.0, + type=float, ) parser.add_argument( - "--conv_quantile", help="singular value quantile for conv layer quantile mode", - default=1., type=float + "--conv_quantile", + help="singular value quantile for conv layer quantile mode", + default=1.0, + type=float, ) parser.add_argument( - "--use_sparse_bias", help="enable sparse bias", - default=False, action="store_true" + "--use_sparse_bias", + help="enable sparse bias", + default=False, + action="store_true", ) parser.add_argument( - "--sparsity", help="sparsity for sparse bias", - default=0.98, type=float + "--sparsity", help="sparsity for sparse bias", default=0.98, type=float ) parser.add_argument( - "--disable_cp", help="don't use cp decomposition", - default=False, action="store_true" + "--disable_cp", + help="don't use cp decomposition", + default=False, + action="store_true", ) return parser.parse_args() + + ARGS = get_args() from lycoris.utils import extract_diff from lycoris.kohya.model_utils import load_models_from_stable_diffusion_checkpoint +from lycoris.kohya.sdxl_model_util import load_models_from_sdxl_checkpoint import torch from safetensors.torch import save_file @@ -94,36 +133,58 @@ from safetensors.torch import save_file def main(): args = ARGS - base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) - db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) - + if 
args.is_sdxl: + base = load_models_from_sdxl_checkpoint(None, args.base_model, args.device) + db = load_models_from_sdxl_checkpoint(None, args.db_model, args.device) + else: + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) + linear_mode_param = { - 'fixed': args.linear_dim, - 'threshold': args.linear_threshold, - 'ratio': args.linear_ratio, - 'quantile': args.linear_quantile, + "fixed": args.linear_dim, + "threshold": args.linear_threshold, + "ratio": args.linear_ratio, + "quantile": args.linear_quantile, + "full": None, }[args.mode] conv_mode_param = { - 'fixed': args.conv_dim, - 'threshold': args.conv_threshold, - 'ratio': args.conv_ratio, - 'quantile': args.conv_quantile, + "fixed": args.conv_dim, + "threshold": args.conv_threshold, + "ratio": args.conv_ratio, + "quantile": args.conv_quantile, + "full": None, }[args.mode] - + + if args.is_sdxl: + db_tes = [db[0], db[1]] + db_unet = db[3] + base_tes = [base[0], base[1]] + base_unet = base[3] + else: + db_tes = [db[0]] + db_unet = db[2] + base_tes = [base[0]] + base_unet = base[2] + state_dict = extract_diff( - base, db, + base_tes, + db_tes, + base_unet, + db_unet, args.mode, - linear_mode_param, conv_mode_param, - args.device, - args.use_sparse_bias, args.sparsity, - not args.disable_cp + linear_mode_param, + conv_mode_param, + args.device, + args.use_sparse_bias, + args.sparsity, + not args.disable_cp, ) - + if args.safetensors: save_file(state_dict, args.output_name) else: torch.save(state_dict, args.output_name) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/tools/merge_lycoris.py b/tools/merge_lycoris.py index b29c8dc..25f673d 100644 --- a/tools/merge_lycoris.py +++ b/tools/merge_lycoris.py @@ -1,4 +1,5 @@ import os, sys + sys.path.insert(0, os.getcwd()) import argparse @@ -6,80 +7,119 @@ import argparse def get_args(): parser = 
argparse.ArgumentParser() parser.add_argument( - "base_model", help="The model you want to merge with loha", - default='', type=str + "base_model", help="The model you want to merge with loha", default="", type=str ) parser.add_argument( - "lycoris_model", help="the lyco model you want to merge into sd model", - default='', type=str + "lycoris_model", + help="the lyco model you want to merge into sd model", + default="", + type=str, ) parser.add_argument( - "output_name", help="the output model", - default='./out.pt', type=str + "output_name", help="the output model", default="./out.pt", type=str ) parser.add_argument( - "--is_v2", help="Your base model is sd v2 or not", - default=False, action="store_true" + "--is_v2", + help="Your base model is sd v2 or not", + default=False, + action="store_true", ) parser.add_argument( - "--device", help="Which device you want to use to merge the weight", - default='cpu', type=str + "--is_sdxl", + help="Your base/db model is sdxl or not", + default=False, + action="store_true", ) parser.add_argument( - "--dtype", help='dtype to save', - default='float', type=str + "--device", + help="Which device you want to use to merge the weight", + default="cpu", + type=str, ) + parser.add_argument("--dtype", help="dtype to save", default="float", type=str) parser.add_argument( - "--weight", help='weight for the lyco model to merge', - default='1.0', type=float + "--weight", help="weight for the lyco model to merge", default="1.0", type=float ) return parser.parse_args() -ARGS = get_args() + + +args = ARGS = get_args() from lycoris.utils import merge from lycoris.kohya.model_utils import ( load_models_from_stable_diffusion_checkpoint, save_stable_diffusion_checkpoint, - load_file + load_file, +) +from lycoris.kohya.sdxl_model_util import ( + load_models_from_sdxl_checkpoint, + save_stable_diffusion_checkpoint as save_sdxl_checkpoint, ) import torch +@torch.no_grad() def main(): - base = 
load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model) - if ARGS.lycoris_model.rsplit('.', 1)[-1] == 'safetensors': + if args.is_sdxl: + base = load_models_from_sdxl_checkpoint( + None, args.base_model, map_location=args.device + ) + else: + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + if ARGS.lycoris_model.rsplit(".", 1)[-1] == "safetensors": lyco = load_file(ARGS.lycoris_model) else: lyco = torch.load(ARGS.lycoris_model) - - dtype_str = ARGS.dtype.replace('fp', 'float').replace('bf', 'bfloat') + + dtype_str = ARGS.dtype.replace("fp", "float").replace("bf", "bfloat") dtype = { - 'float': torch.float, - 'float16': torch.float16, - 'float32': torch.float32, - 'float64': torch.float64, - 'bfloat': torch.bfloat16, - 'bfloat16': torch.bfloat16, + "float": torch.float, + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat": torch.bfloat16, + "bfloat16": torch.bfloat16, }.get(dtype_str, None) if dtype is None: raise ValueError(f'Cannot Find the dtype "{dtype}"') - - merge( - base, - lyco, - ARGS.weight, - ARGS.device - ) - - save_stable_diffusion_checkpoint( - ARGS.is_v2, ARGS.output_name, - base[0], base[2], - None, 0, 0, dtype, - base[1] - ) + + if args.is_sdxl: + base_tes = [base[0], base[1]] + base_unet = base[3] + else: + base_tes = [base[0]] + base_unet = base[2] + + merge(base_tes, base_unet, lyco, ARGS.weight, ARGS.device) + + if args.is_sdxl: + save_sdxl_checkpoint( + ARGS.output_name, + base[0].cpu(), + base[1].cpu(), + base[3].cpu(), + 0, + 0, + None, + base[2], + getattr(base[1], "logit_scale", None), + dtype, + ) + else: + save_stable_diffusion_checkpoint( + ARGS.is_v2, + ARGS.output_name, + base[0].cpu(), + base[2].cpu(), + None, + 0, + 0, + dtype, + base[1], + ) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main()