mirror of https://github.com/bmaltais/kohya_ss
Update extract LoRA and add sdpa
parent
101d2638e2
commit
a9ec90c40a
|
|
@ -494,4 +494,6 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
|
||||||
- Add missing LR number of cycles and LR power to Dreambooth and TI scripts
|
- Add missing LR number of cycles and LR power to Dreambooth and TI scripts
|
||||||
- Fix issue with conv_block_dims and conv_block_alphas
|
- Fix issue with conv_block_dims and conv_block_alphas
|
||||||
- Fix 0 noise offset issue
|
- Fix 0 noise offset issue
|
||||||
- Implement Stop training button on LoRA
|
- Implement Stop training button on LoRA
|
||||||
|
- Add support to extract LoRA from SDXL finetuned models
|
||||||
|
- Add support for PagedAdamW8bit and PagedLion8bit optimizer. Those require a new version of bitsandbytes so success on some systems may vary. I had to uninstall all my nvidia drivers and othe cuda toolkit install, delete all cuda variable references and re-install cuda toolkit v11.8.0 to get things to work... so not super easy.
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Command 1: merge_captions_to_metadata.py
|
||||||
|
$captionExtension = "--caption_extension=.txt"
|
||||||
|
$sourceDir1 = "d:\test\1_1960-1969"
|
||||||
|
$targetFile1 = "d:\test\1_1960-1969/meta_cap.json"
|
||||||
|
|
||||||
|
# Command 2: prepare_buckets_latents.py
|
||||||
|
$targetLatentFile = "d:\test\1_1960-1969/meta_lat.json"
|
||||||
|
$modelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"
|
||||||
|
|
||||||
|
./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $sourceDir1 $targetFile1 --full_path
|
||||||
|
./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $sourceDir1 $targetFile1 $targetLatentFile $modelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path
|
||||||
|
|
@ -89,6 +89,7 @@ def save_configuration(
|
||||||
caption_extension,
|
caption_extension,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
sdpa,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
|
|
@ -209,6 +210,7 @@ def open_configuration(
|
||||||
caption_extension,
|
caption_extension,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
sdpa,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
|
|
@ -326,6 +328,7 @@ def train_model(
|
||||||
caption_extension,
|
caption_extension,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
sdpa,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
|
|
@ -575,6 +578,7 @@ def train_model(
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
full_fp16=full_fp16,
|
full_fp16=full_fp16,
|
||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
|
spda=sdpa,
|
||||||
# use_8bit_adam=use_8bit_adam,
|
# use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
|
|
@ -866,6 +870,7 @@ def finetune_tab(headless=False):
|
||||||
source_model.save_model_as,
|
source_model.save_model_as,
|
||||||
basic_training.caption_extension,
|
basic_training.caption_extension,
|
||||||
advanced_training.xformers,
|
advanced_training.xformers,
|
||||||
|
advanced_training.sdpa,
|
||||||
advanced_training.clip_skip,
|
advanced_training.clip_skip,
|
||||||
advanced_training.save_state,
|
advanced_training.save_state,
|
||||||
advanced_training.resume,
|
advanced_training.resume,
|
||||||
|
|
|
||||||
|
|
@ -101,13 +101,14 @@ class AdvancedTraining:
|
||||||
# use_8bit_adam = gr.Checkbox(
|
# use_8bit_adam = gr.Checkbox(
|
||||||
# label='Use 8bit adam', value=False, visible=False
|
# label='Use 8bit adam', value=False, visible=False
|
||||||
# )
|
# )
|
||||||
self.xformers = gr.Checkbox(label='Use xformers', value=True)
|
self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
|
||||||
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||||
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
self.min_snr_gamma = gr.Slider(
|
self.min_snr_gamma = gr.Slider(
|
||||||
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
|
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
|
||||||
self.bucket_no_upscale = gr.Checkbox(
|
self.bucket_no_upscale = gr.Checkbox(
|
||||||
label="Don't upscale bucket resolution", value=True
|
label="Don't upscale bucket resolution", value=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -735,6 +735,10 @@ def run_cmd_advanced_training(**kwargs):
|
||||||
xformers = kwargs.get('xformers')
|
xformers = kwargs.get('xformers')
|
||||||
if xformers:
|
if xformers:
|
||||||
run_cmd += ' --xformers'
|
run_cmd += ' --xformers'
|
||||||
|
|
||||||
|
sdpa = kwargs.get('sdpa')
|
||||||
|
if sdpa:
|
||||||
|
run_cmd += ' --sdpa'
|
||||||
|
|
||||||
persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
|
persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
|
||||||
if persistent_data_loader_workers:
|
if persistent_data_loader_workers:
|
||||||
|
|
@ -856,3 +860,18 @@ def check_duplicate_filenames(folder_path, image_extension = ['.gif', '.png', '.
|
||||||
print(f"Current file: {full_path}")
|
print(f"Current file: {full_path}")
|
||||||
else:
|
else:
|
||||||
filenames[filename] = full_path
|
filenames[filename] = full_path
|
||||||
|
|
||||||
|
def is_file_writable(file_path):
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
# print(f"File '{file_path}' does not exist.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.warning(f"File '{file_path}' already exist... it will be overwritten...")
|
||||||
|
# Check if the file can be opened in write mode (which implies it's not open by another process)
|
||||||
|
with open(file_path, 'a'):
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
except IOError:
|
||||||
|
log.warning(f"File '{file_path}' can't be written to...")
|
||||||
|
return False
|
||||||
|
|
@ -4,8 +4,8 @@ import subprocess
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
from .common_gui import (
|
||||||
get_saveasfilename_path,
|
get_saveasfilename_path,
|
||||||
get_any_file_path,
|
|
||||||
get_file_path,
|
get_file_path,
|
||||||
|
is_file_writable
|
||||||
)
|
)
|
||||||
|
|
||||||
from library.custom_logging import setup_logging
|
from library.custom_logging import setup_logging
|
||||||
|
|
@ -27,25 +27,31 @@ def extract_lora(
|
||||||
save_precision,
|
save_precision,
|
||||||
dim,
|
dim,
|
||||||
v2,
|
v2,
|
||||||
|
sdxl,
|
||||||
conv_dim,
|
conv_dim,
|
||||||
|
clamp_quantile,
|
||||||
|
min_diff,
|
||||||
device,
|
device,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model_tuned == '':
|
if model_tuned == '':
|
||||||
msgbox('Invalid finetuned model file')
|
log.info('Invalid finetuned model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if model_org == '':
|
if model_org == '':
|
||||||
msgbox('Invalid base model file')
|
log.info('Invalid base model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(model_tuned):
|
if not os.path.isfile(model_tuned):
|
||||||
msgbox('The provided finetuned model is not a file')
|
log.info('The provided finetuned model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(model_org):
|
if not os.path.isfile(model_org):
|
||||||
msgbox('The provided base model is not a file')
|
log.info('The provided base model is not a file')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not is_file_writable(save_to):
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = (
|
run_cmd = (
|
||||||
|
|
@ -61,6 +67,10 @@ def extract_lora(
|
||||||
run_cmd += f' --conv_dim {conv_dim}'
|
run_cmd += f' --conv_dim {conv_dim}'
|
||||||
if v2:
|
if v2:
|
||||||
run_cmd += f' --v2'
|
run_cmd += f' --v2'
|
||||||
|
if sdxl:
|
||||||
|
run_cmd += f' --sdxl'
|
||||||
|
run_cmd += f' --clamp_quantile {clamp_quantile}'
|
||||||
|
run_cmd += f' --min_diff {min_diff}'
|
||||||
|
|
||||||
log.info(run_cmd)
|
log.info(run_cmd)
|
||||||
|
|
||||||
|
|
@ -160,7 +170,19 @@ def gradio_extract_lora_tab(headless=False):
|
||||||
step=1,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
clamp_quantile = gr.Number(
|
||||||
|
label='Clamp Quantile',
|
||||||
|
value=1,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
min_diff = gr.Number(
|
||||||
|
label='Minimum difference',
|
||||||
|
value=0.01,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
||||||
|
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
|
||||||
device = gr.Dropdown(
|
device = gr.Dropdown(
|
||||||
label='Device',
|
label='Device',
|
||||||
choices=[
|
choices=[
|
||||||
|
|
@ -182,7 +204,10 @@ def gradio_extract_lora_tab(headless=False):
|
||||||
save_precision,
|
save_precision,
|
||||||
dim,
|
dim,
|
||||||
v2,
|
v2,
|
||||||
|
sdxl,
|
||||||
conv_dim,
|
conv_dim,
|
||||||
|
clamp_quantile,
|
||||||
|
min_diff,
|
||||||
device,
|
device,
|
||||||
],
|
],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,8 @@ import library.model_util as model_util
|
||||||
import library.sdxl_model_util as sdxl_model_util
|
import library.sdxl_model_util as sdxl_model_util
|
||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
# CLAMP_QUANTILE = 1
|
||||||
CLAMP_QUANTILE = 0.99
|
# MIN_DIFF = 1e-2
|
||||||
MIN_DIFF = 1e-4
|
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
|
@ -91,9 +89,9 @@ def svd(args):
|
||||||
diff = module_t.weight - module_o.weight
|
diff = module_t.weight - module_o.weight
|
||||||
|
|
||||||
# Text Encoder might be same
|
# Text Encoder might be same
|
||||||
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
|
if not text_encoder_different and torch.max(torch.abs(diff)) > args.min_diff:
|
||||||
text_encoder_different = True
|
text_encoder_different = True
|
||||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
|
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {args.min_diff}")
|
||||||
|
|
||||||
diff = diff.float()
|
diff = diff.float()
|
||||||
diffs[lora_name] = diff
|
diffs[lora_name] = diff
|
||||||
|
|
@ -149,7 +147,7 @@ def svd(args):
|
||||||
Vh = Vh[:rank, :]
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
hi_val = torch.quantile(dist, args.clamp_quantile)
|
||||||
low_val = -hi_val
|
low_val = -hi_val
|
||||||
|
|
||||||
U = U.clamp(low_val, hi_val)
|
U = U.clamp(low_val, hi_val)
|
||||||
|
|
@ -243,6 +241,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
parser.add_argument(
|
||||||
|
"--clamp_quantile",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="Quantile clamping value, float, (0-1). Defailt = 1",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min_diff",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="Minimum difference betwen finetuned model and base to consider them different enough to extract, float, (0-1). Defailt = 0.01",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue