# Mirror of https://github.com/bmaltais/kohya_ss (Python, 85 lines, 2.2 KiB).
import os, sys

# Prepend the current working directory to the module search path so that
# modules located there take precedence over installed ones.
# NOTE(review): presumably this is so the project-local imports further down
# (lycoris_utils, lycoris.kohya_model_utils) resolve when the script is run
# from the repository root — confirm the intended working directory.
sys.path.insert(0, os.getcwd())

import argparse
def get_args(argv=None):
    """Build the command-line parser for the merge tool and parse arguments.

    Positional arguments:
        base_model     the Stable Diffusion checkpoint to merge into.
        lycoris_model  the LyCORIS model to merge (``.safetensors`` or a file
                       readable by ``torch.load``).
        output_name    where to write the merged checkpoint; optional,
                       defaults to ``./out.pt``.

    Args:
        argv: list of argument strings to parse instead of the real command
            line. ``None`` (the default) parses ``sys.argv[1:]`` exactly as
            before.

    Returns:
        argparse.Namespace with ``base_model``, ``lycoris_model``,
        ``output_name``, ``is_v2``, ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()
    # Required positionals: argparse ignores ``default`` on these, so none is given.
    parser.add_argument(
        "base_model", help="The model you want to merge with the lycoris model",
        type=str,
    )
    parser.add_argument(
        "lycoris_model", help="the lyco model you want to merge into sd model",
        type=str,
    )
    # nargs='?' makes the positional optional so the advertised default is
    # actually used; on a required positional argparse never applies it.
    parser.add_argument(
        "output_name", help="the output model",
        nargs='?', default='./out.pt', type=str,
    )
    parser.add_argument(
        "--is_v2", help="Your base model is sd v2 or not",
        default=False, action="store_true",
    )
    parser.add_argument(
        "--device", help="Which device you want to use to merge the weight",
        default='cpu', type=str,
    )
    # Resolved to a torch dtype by main(); 'fp'/'bf' abbreviations are accepted there.
    parser.add_argument(
        "--dtype", help='dtype to save',
        default='float', type=str,
    )
    parser.add_argument(
        "--weight", help='weight for the lyco model to merge',
        default=1.0, type=float,
    )
    return parser.parse_args(argv)


# Parse the command line at module import time; main() reads its settings
# from this global. Note that importing this module therefore consumes
# sys.argv (and exits on bad arguments / --help).
# NOTE(review): the parse presumably happens before the imports below so that
# --help and argument errors are reported without first paying for the torch
# and lycoris imports — confirm before reordering.
ARGS = get_args()


from lycoris_utils import merge

from lycoris.kohya_model_utils import (
    load_models_from_stable_diffusion_checkpoint,
    save_stable_diffusion_checkpoint,
    load_file
)

import torch


def _resolve_dtype(name):
    """Translate a ``--dtype`` value into the matching ``torch.dtype``.

    Accepts the full names (``float``, ``float16``, ``float32``, ``float64``,
    ``bfloat``, ``bfloat16``) and the ``fp``/``bf`` abbreviations
    (``fp16``, ``fp32``, ``fp64``, ``bf16``, ...), case-insensitively.

    Raises:
        ValueError: if *name* does not correspond to a supported dtype.
    """
    key = name.lower()
    # Expand only a *leading* abbreviation: a blanket str.replace('bf', 'bfloat')
    # also hits the full spelling and turns 'bfloat16' into 'bfloatloat16'.
    if key.startswith('fp'):
        key = 'float' + key[2:]
    elif key.startswith('bf') and not key.startswith('bfloat'):
        key = 'bfloat' + key[2:]
    dtype = {
        'float': torch.float,
        'float16': torch.float16,
        'float32': torch.float32,
        'float64': torch.float64,
        'bfloat': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }.get(key)
    if dtype is None:
        # Report what the user actually passed (not the None lookup result).
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


def main():
    """Merge the LyCORIS model into the base checkpoint and save the result.

    All settings come from the module-level ``ARGS`` produced by ``get_args``.

    Raises:
        ValueError: if ``--dtype`` is not a supported dtype name.
    """
    # Validate --dtype before the slow model loading so a typo fails fast.
    dtype = _resolve_dtype(ARGS.dtype)

    base = load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model)

    if ARGS.lycoris_model.lower().endswith('.safetensors'):
        lyco = load_file(ARGS.lycoris_model)
    else:
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code — only merge models from trusted sources.
        # map_location='cpu' lets files saved from a GPU load on CPU-only
        # machines; merge() is handed ARGS.device separately.
        # NOTE(review): assumes merge() moves lyco weights to the target
        # device itself (as the safetensors path also needs) — confirm.
        lyco = torch.load(ARGS.lycoris_model, map_location='cpu')

    # The return value is unused, so merge() presumably updates `base` in
    # place — confirm against lycoris_utils.merge.
    merge(
        base,
        lyco,
        ARGS.weight,
        ARGS.device
    )

    # NOTE(review): `base` presumably is (text_encoder, vae, unet) from the
    # kohya loader, so this passes text encoder, unet, no ckpt path,
    # epoch 0, step 0, the save dtype, then the vae — confirm against
    # lycoris.kohya_model_utils.save_stable_diffusion_checkpoint.
    save_stable_diffusion_checkpoint(
        ARGS.is_v2, ARGS.output_name,
        base[0], base[2],
        None, 0, 0, dtype,
        base[1]
    )


# Run the merge only when executed as a script, not when imported.
if __name__ == '__main__':
    main()