diff --git a/.release b/.release index ef7f772..f9353ad 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.6.1 +v22.6.2 diff --git a/README.md b/README.md index 2465d90..e26bda7 100644 --- a/README.md +++ b/README.md @@ -503,6 +503,9 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b ## Change History +* 2024/02/17 (v22.6.2) +- Fix issue with Lora Extract GUI + * 2024/02/15 (v22.6.1) - Add support for multi-gpu parameters in the GUI under the "Parameters > Advanced" tab. - Significant rewrite of how parameters are created in the code. I hope I did not break anything in the process... Will make the code easier to update. diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 9aaad07..7dde723 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -95,7 +95,7 @@ def extract_lora( def gradio_extract_lora_tab(headless=False): def change_sdxl(sdxl): - return gr(visible=sdxl), gr(visible=sdxl) + return gr.Dropdown(visible=sdxl), gr.Dropdown(visible=sdxl) diff --git a/tools/extract_locon.py b/tools/extract_locon.py index a610766..b0bc57f 100644 --- a/tools/extract_locon.py +++ b/tools/extract_locon.py @@ -1,4 +1,5 @@ import os, sys + sys.path.insert(0, os.getcwd()) import argparse @@ -6,87 +7,125 @@ import argparse def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "base_model", help="The model which use it to train the dreambooth model", - default='', type=str + "base_model", + help="The model which use it to train the dreambooth model", + default="", + type=str, ) parser.add_argument( - "db_model", help="the dreambooth model you want to extract the locon", - default='', type=str + "db_model", + help="the dreambooth model you want to extract the locon", + default="", + type=str, ) parser.add_argument( - "output_name", help="the output model", - default='./out.pt', type=str + "output_name", help="the output model", default="./out.pt", type=str ) parser.add_argument( - "--is_v2", help="Your base/db model is sd v2 or not", - default=False, action="store_true" + "--is_v2", + help="Your base/db model is sd v2 or not", + default=False, + action="store_true", ) parser.add_argument( - "--device", help="Which device you want to use to extract the locon", - default='cpu', type=str + "--is_sdxl", + help="Your base/db model is sdxl or not", + default=False, + action="store_true", ) parser.add_argument( - "--mode", + "--device", + help="Which device you want to use to extract the locon", + default="cpu", + type=str, + ) + parser.add_argument( + "--mode", help=( - 'extraction mode, can be "fixed", "threshold", "ratio", "quantile". ' + 'extraction mode, can be "full", "fixed", "threshold", "ratio", "quantile". ' 'If not "fixed", network_dim and conv_dim will be ignored' ), - default='fixed', type=str + default="fixed", + type=str, ) parser.add_argument( - "--safetensors", help='use safetensors to save locon model', - default=True, action="store_true" + "--safetensors", + help="use safetensors to save locon model", + default=False, + action="store_true", ) parser.add_argument( - "--linear_dim", help="network dim for linear layer in fixed mode", - default=1, type=int + "--linear_dim", + help="network dim for linear layer in fixed mode", + default=1, + type=int, ) parser.add_argument( - "--conv_dim", help="network dim for conv layer in fixed mode", - default=1, type=int + "--conv_dim", + help="network dim for conv layer in fixed mode", + default=1, + type=int, ) parser.add_argument( - "--linear_threshold", help="singular value threshold for linear layer in threshold mode", - default=0., type=float + "--linear_threshold", + help="singular value threshold for linear layer in threshold mode", + default=0.0, + type=float, ) parser.add_argument( - "--conv_threshold", help="singular value threshold for conv layer in threshold mode", - default=0., type=float + "--conv_threshold", + help="singular value threshold for conv layer in threshold mode", + default=0.0, + type=float, ) parser.add_argument( - "--linear_ratio", help="singular ratio for linear layer in ratio mode", - default=0., type=float + "--linear_ratio", + help="singular ratio for linear layer in ratio mode", + default=0.0, + type=float, ) parser.add_argument( - "--conv_ratio", help="singular ratio for conv layer in ratio mode", - default=0., type=float + "--conv_ratio", + help="singular ratio for conv layer in ratio mode", + default=0.0, + type=float, ) parser.add_argument( - "--linear_quantile", help="singular value quantile for linear layer quantile mode", - default=1., type=float + "--linear_quantile", + help="singular value quantile for linear layer quantile mode", + default=1.0, + type=float, ) parser.add_argument( - "--conv_quantile", help="singular value quantile for conv layer quantile mode", - default=1., type=float + "--conv_quantile", + help="singular value quantile for conv layer quantile mode", + default=1.0, + type=float, ) parser.add_argument( - "--use_sparse_bias", help="enable sparse bias", - default=False, action="store_true" + "--use_sparse_bias", + help="enable sparse bias", + default=False, + action="store_true", ) parser.add_argument( - "--sparsity", help="sparsity for sparse bias", - default=0.98, type=float + "--sparsity", help="sparsity for sparse bias", default=0.98, type=float ) parser.add_argument( - "--disable_cp", help="don't use cp decomposition", - default=False, action="store_true" + "--disable_cp", + help="don't use cp decomposition", + default=False, + action="store_true", ) return parser.parse_args() + + ARGS = get_args() from lycoris.utils import extract_diff from lycoris.kohya.model_utils import load_models_from_stable_diffusion_checkpoint +from lycoris.kohya.sdxl_model_util import load_models_from_sdxl_checkpoint import torch from safetensors.torch import save_file @@ -94,36 +133,58 @@ from safetensors.torch import save_file def main(): args = ARGS - base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) - db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) - + if args.is_sdxl: + base = load_models_from_sdxl_checkpoint(None, args.base_model, args.device) + db = load_models_from_sdxl_checkpoint(None, args.db_model, args.device) + else: + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) + linear_mode_param = { - 'fixed': args.linear_dim, - 'threshold': args.linear_threshold, - 'ratio': args.linear_ratio, - 'quantile': args.linear_quantile, + "fixed": args.linear_dim, + "threshold": args.linear_threshold, + "ratio": args.linear_ratio, + "quantile": args.linear_quantile, + "full": None, }[args.mode] conv_mode_param = { - 'fixed': args.conv_dim, - 'threshold': args.conv_threshold, - 'ratio': args.conv_ratio, - 'quantile': args.conv_quantile, + "fixed": args.conv_dim, + "threshold": args.conv_threshold, + "ratio": args.conv_ratio, + "quantile": args.conv_quantile, + "full": None, }[args.mode] - + + if args.is_sdxl: + db_tes = [db[0], db[1]] + db_unet = db[3] + base_tes = [base[0], base[1]] + base_unet = base[3] + else: + db_tes = [db[0]] + db_unet = db[2] + base_tes = [base[0]] + base_unet = base[2] + state_dict = extract_diff( - base, db, + base_tes, + db_tes, + base_unet, + db_unet, args.mode, - linear_mode_param, conv_mode_param, - args.device, - args.use_sparse_bias, args.sparsity, - not args.disable_cp + linear_mode_param, + conv_mode_param, + args.device, + args.use_sparse_bias, + args.sparsity, + not args.disable_cp, ) - + if args.safetensors: save_file(state_dict, args.output_name) else: torch.save(state_dict, args.output_name) -if __name__ == '__main__': +if __name__ == "__main__": main() \ No newline at end of file