Update extract LoRA and add sdpa

pull/1332/head
bmaltais 2023-07-25 20:11:32 -04:00
parent 101d2638e2
commit a9ec90c40a
7 changed files with 87 additions and 14 deletions

View File

@ -495,3 +495,5 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
- Fix issue with conv_block_dims and conv_block_alphas
- Fix 0 noise offset issue
- Implement Stop training button on LoRA
- Add support to extract LoRA from SDXL finetuned models
- Add support for PagedAdamW8bit and PagedLion8bit optimizer. Those require a new version of bitsandbytes, so success on some systems may vary. I had to uninstall all my nvidia drivers and other cuda toolkit installs, delete all cuda variable references and re-install cuda toolkit v11.8.0 to get things to work... so not super easy.

View File

@ -0,0 +1,11 @@
# Step 1: merge per-image .txt captions into a single metadata JSON.
$captionExtension = "--caption_extension=.txt"
$imageDir = "d:\test\1_1960-1969"
$captionMetaFile = "d:\test\1_1960-1969/meta_cap.json"

# Step 2: bucket the images and cache their latents using the SDXL base model.
$latentMetaFile = "d:\test\1_1960-1969/meta_lat.json"
$baseModelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"

./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $imageDir $captionMetaFile --full_path
./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $imageDir $captionMetaFile $latentMetaFile $baseModelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path

View File

@ -89,6 +89,7 @@ def save_configuration(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@ -209,6 +210,7 @@ def open_configuration(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@ -326,6 +328,7 @@ def train_model(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@ -575,6 +578,7 @@ def train_model(
gradient_checkpointing=gradient_checkpointing,
full_fp16=full_fp16,
xformers=xformers,
sdpa=sdpa,
# use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
@ -866,6 +870,7 @@ def finetune_tab(headless=False):
source_model.save_model_as,
basic_training.caption_extension,
advanced_training.xformers,
advanced_training.sdpa,
advanced_training.clip_skip,
advanced_training.save_state,
advanced_training.resume,

View File

@ -101,13 +101,14 @@ class AdvancedTraining:
# use_8bit_adam = gr.Checkbox(
# label='Use 8bit adam', value=False, visible=False
# )
self.xformers = gr.Checkbox(label='Use xformers', value=True)
self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
self.min_snr_gamma = gr.Slider(
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
)
with gr.Row():
self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
self.bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True
)

View File

@ -736,6 +736,10 @@ def run_cmd_advanced_training(**kwargs):
if xformers:
run_cmd += ' --xformers'
sdpa = kwargs.get('sdpa')
if sdpa:
run_cmd += ' --sdpa'
persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
if persistent_data_loader_workers:
run_cmd += ' --persistent_data_loader_workers'
@ -856,3 +860,18 @@ def check_duplicate_filenames(folder_path, image_extension = ['.gif', '.png', '.
print(f"Current file: {full_path}")
else:
filenames[filename] = full_path
def is_file_writable(file_path: str) -> bool:
    """Return True if *file_path* can be written to.

    A path that does not exist yet is considered writable (there is
    nothing to overwrite).  For an existing file, a warning is logged
    and the file is probed by opening it in append mode.

    NOTE(review): on POSIX, a successful append-mode open does NOT
    guarantee no other process has the file open — this check is only
    reliable for mandatory-locking platforms (e.g. Windows).
    """
    if not os.path.exists(file_path):
        # print(f"File '{file_path}' does not exist.")
        return True
    try:
        log.warning(f"File '{file_path}' already exist... it will be overwritten...")
        # Append mode raises OSError if the file cannot be opened for writing
        # (permission denied, or locked by another process on Windows).
        with open(file_path, 'a'):
            pass
    except OSError:
        # IOError is a deprecated alias of OSError; catching OSError also
        # covers PermissionError, which the original alias already did.
        log.warning(f"File '{file_path}' can't be written to...")
        return False
    return True

View File

@ -4,8 +4,8 @@ import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
is_file_writable
)
from library.custom_logging import setup_logging
@ -27,25 +27,31 @@ def extract_lora(
save_precision,
dim,
v2,
sdxl,
conv_dim,
clamp_quantile,
min_diff,
device,
):
# Check for caption_text_input
if model_tuned == '':
msgbox('Invalid finetuned model file')
log.info('Invalid finetuned model file')
return
if model_org == '':
msgbox('Invalid base model file')
log.info('Invalid base model file')
return
# Check if source model exist
if not os.path.isfile(model_tuned):
msgbox('The provided finetuned model is not a file')
log.info('The provided finetuned model is not a file')
return
if not os.path.isfile(model_org):
msgbox('The provided base model is not a file')
log.info('The provided base model is not a file')
return
if not is_file_writable(save_to):
return
run_cmd = (
@ -61,6 +67,10 @@ def extract_lora(
run_cmd += f' --conv_dim {conv_dim}'
if v2:
run_cmd += f' --v2'
if sdxl:
run_cmd += f' --sdxl'
run_cmd += f' --clamp_quantile {clamp_quantile}'
run_cmd += f' --min_diff {min_diff}'
log.info(run_cmd)
@ -160,7 +170,19 @@ def gradio_extract_lora_tab(headless=False):
step=1,
interactive=True,
)
clamp_quantile = gr.Number(
label='Clamp Quantile',
value=1,
interactive=True,
)
min_diff = gr.Number(
label='Minimum difference',
value=0.01,
interactive=True,
)
with gr.Row():
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
device = gr.Dropdown(
label='Device',
choices=[
@ -182,7 +204,10 @@ def gradio_extract_lora_tab(headless=False):
save_precision,
dim,
v2,
sdxl,
conv_dim,
clamp_quantile,
min_diff,
device,
],
show_progress=False,

View File

@ -12,10 +12,8 @@ import library.model_util as model_util
import library.sdxl_model_util as sdxl_model_util
import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-4
# CLAMP_QUANTILE = 1
# MIN_DIFF = 1e-2
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
@ -91,9 +89,9 @@ def svd(args):
diff = module_t.weight - module_o.weight
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > args.min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {args.min_diff}")
diff = diff.float()
diffs[lora_name] = diff
@ -149,7 +147,7 @@ def svd(args):
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, args.clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
@ -243,6 +241,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=1,
help="Quantile clamping value, float, (0-1). Default = 1",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
)
return parser