mirror of https://github.com/bmaltais/kohya_ss
Update extract LoRA and add sdpa
parent
101d2638e2
commit
a9ec90c40a
|
|
@ -494,4 +494,6 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
|
|||
- Add missing LR number of cycles and LR power to Dreambooth and TI scripts
|
||||
- Fix issue with conv_block_dims and conv_block_alphas
|
||||
- Fix 0 noise offset issue
|
||||
- Implement Stop training button on LoRA
|
||||
- Implement Stop training button on LoRA
|
||||
- Add support to extract LoRA from SDXL finetuned models
|
||||
- Add support for the PagedAdamW8bit and PagedLion8bit optimizers. These require a new version of bitsandbytes, so success on some systems may vary. I had to uninstall all my NVIDIA drivers and other CUDA toolkit installs, delete all CUDA variable references, and re-install CUDA toolkit v11.8.0 to get things to work... so not super easy.
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# Command 1: merge_captions_to_metadata.py
|
||||
$captionExtension = "--caption_extension=.txt"
|
||||
$sourceDir1 = "d:\test\1_1960-1969"
|
||||
$targetFile1 = "d:\test\1_1960-1969/meta_cap.json"
|
||||
|
||||
# Command 2: prepare_buckets_latents.py
|
||||
$targetLatentFile = "d:\test\1_1960-1969/meta_lat.json"
|
||||
$modelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"
|
||||
|
||||
./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $sourceDir1 $targetFile1 --full_path
|
||||
./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $sourceDir1 $targetFile1 $targetLatentFile $modelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path
|
||||
|
|
@ -89,6 +89,7 @@ def save_configuration(
|
|||
caption_extension,
|
||||
# use_8bit_adam,
|
||||
xformers,
|
||||
sdpa,
|
||||
clip_skip,
|
||||
save_state,
|
||||
resume,
|
||||
|
|
@ -209,6 +210,7 @@ def open_configuration(
|
|||
caption_extension,
|
||||
# use_8bit_adam,
|
||||
xformers,
|
||||
sdpa,
|
||||
clip_skip,
|
||||
save_state,
|
||||
resume,
|
||||
|
|
@ -326,6 +328,7 @@ def train_model(
|
|||
caption_extension,
|
||||
# use_8bit_adam,
|
||||
xformers,
|
||||
sdpa,
|
||||
clip_skip,
|
||||
save_state,
|
||||
resume,
|
||||
|
|
@ -575,6 +578,7 @@ def train_model(
|
|||
gradient_checkpointing=gradient_checkpointing,
|
||||
full_fp16=full_fp16,
|
||||
xformers=xformers,
|
||||
spda=sdpa,
|
||||
# use_8bit_adam=use_8bit_adam,
|
||||
keep_tokens=keep_tokens,
|
||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||
|
|
@ -866,6 +870,7 @@ def finetune_tab(headless=False):
|
|||
source_model.save_model_as,
|
||||
basic_training.caption_extension,
|
||||
advanced_training.xformers,
|
||||
advanced_training.sdpa,
|
||||
advanced_training.clip_skip,
|
||||
advanced_training.save_state,
|
||||
advanced_training.resume,
|
||||
|
|
|
|||
|
|
@ -101,13 +101,14 @@ class AdvancedTraining:
|
|||
# use_8bit_adam = gr.Checkbox(
|
||||
# label='Use 8bit adam', value=False, visible=False
|
||||
# )
|
||||
self.xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
|
||||
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||
self.min_snr_gamma = gr.Slider(
|
||||
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
|
||||
)
|
||||
with gr.Row():
|
||||
self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
|
||||
self.bucket_no_upscale = gr.Checkbox(
|
||||
label="Don't upscale bucket resolution", value=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -735,6 +735,10 @@ def run_cmd_advanced_training(**kwargs):
|
|||
xformers = kwargs.get('xformers')
|
||||
if xformers:
|
||||
run_cmd += ' --xformers'
|
||||
|
||||
sdpa = kwargs.get('sdpa')
|
||||
if sdpa:
|
||||
run_cmd += ' --sdpa'
|
||||
|
||||
persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
|
||||
if persistent_data_loader_workers:
|
||||
|
|
@ -856,3 +860,18 @@ def check_duplicate_filenames(folder_path, image_extension = ['.gif', '.png', '.
|
|||
print(f"Current file: {full_path}")
|
||||
else:
|
||||
filenames[filename] = full_path
|
||||
|
||||
def is_file_writable(file_path):
|
||||
if not os.path.exists(file_path):
|
||||
# print(f"File '{file_path}' does not exist.")
|
||||
return True
|
||||
|
||||
try:
|
||||
log.warning(f"File '{file_path}' already exist... it will be overwritten...")
|
||||
# Check if the file can be opened in write mode (which implies it's not open by another process)
|
||||
with open(file_path, 'a'):
|
||||
pass
|
||||
return True
|
||||
except IOError:
|
||||
log.warning(f"File '{file_path}' can't be written to...")
|
||||
return False
|
||||
|
|
@ -4,8 +4,8 @@ import subprocess
|
|||
import os
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_any_file_path,
|
||||
get_file_path,
|
||||
is_file_writable
|
||||
)
|
||||
|
||||
from library.custom_logging import setup_logging
|
||||
|
|
@ -27,25 +27,31 @@ def extract_lora(
|
|||
save_precision,
|
||||
dim,
|
||||
v2,
|
||||
sdxl,
|
||||
conv_dim,
|
||||
clamp_quantile,
|
||||
min_diff,
|
||||
device,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model_tuned == '':
|
||||
msgbox('Invalid finetuned model file')
|
||||
log.info('Invalid finetuned model file')
|
||||
return
|
||||
|
||||
if model_org == '':
|
||||
msgbox('Invalid base model file')
|
||||
log.info('Invalid base model file')
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(model_tuned):
|
||||
msgbox('The provided finetuned model is not a file')
|
||||
log.info('The provided finetuned model is not a file')
|
||||
return
|
||||
|
||||
if not os.path.isfile(model_org):
|
||||
msgbox('The provided base model is not a file')
|
||||
log.info('The provided base model is not a file')
|
||||
return
|
||||
|
||||
if not is_file_writable(save_to):
|
||||
return
|
||||
|
||||
run_cmd = (
|
||||
|
|
@ -61,6 +67,10 @@ def extract_lora(
|
|||
run_cmd += f' --conv_dim {conv_dim}'
|
||||
if v2:
|
||||
run_cmd += f' --v2'
|
||||
if sdxl:
|
||||
run_cmd += f' --sdxl'
|
||||
run_cmd += f' --clamp_quantile {clamp_quantile}'
|
||||
run_cmd += f' --min_diff {min_diff}'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
|
|
@ -160,7 +170,19 @@ def gradio_extract_lora_tab(headless=False):
|
|||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
clamp_quantile = gr.Number(
|
||||
label='Clamp Quantile',
|
||||
value=1,
|
||||
interactive=True,
|
||||
)
|
||||
min_diff = gr.Number(
|
||||
label='Minimum difference',
|
||||
value=0.01,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
||||
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
|
||||
device = gr.Dropdown(
|
||||
label='Device',
|
||||
choices=[
|
||||
|
|
@ -182,7 +204,10 @@ def gradio_extract_lora_tab(headless=False):
|
|||
save_precision,
|
||||
dim,
|
||||
v2,
|
||||
sdxl,
|
||||
conv_dim,
|
||||
clamp_quantile,
|
||||
min_diff,
|
||||
device,
|
||||
],
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -12,10 +12,8 @@ import library.model_util as model_util
|
|||
import library.sdxl_model_util as sdxl_model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-4
|
||||
|
||||
# CLAMP_QUANTILE = 1
|
||||
# MIN_DIFF = 1e-2
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
|
|
@ -91,9 +89,9 @@ def svd(args):
|
|||
diff = module_t.weight - module_o.weight
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > args.min_diff:
|
||||
text_encoder_different = True
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {args.min_diff}")
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
|
|
@ -149,7 +147,7 @@ def svd(args):
|
|||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
hi_val = torch.quantile(dist, args.clamp_quantile)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
|
|
@ -243,6 +241,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--clamp_quantile",
|
||||
type=float,
|
||||
default=1,
|
||||
help="Quantile clamping value, float, (0-1). Defailt = 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_diff",
|
||||
type=float,
|
||||
default=1,
|
||||
help="Minimum difference betwen finetuned model and base to consider them different enough to extract, float, (0-1). Defailt = 0.01",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue