Add AUTO-compatible metadata, UI improvements

pull/1356/head
d8ahazard 2023-09-21 10:45:13 -05:00
parent bd162d689d
commit 8030472b60
10 changed files with 527 additions and 67 deletions

View File

@ -1,3 +1,4 @@
import datetime
import json
import logging
import os
@ -9,8 +10,9 @@ from pydantic import BaseModel
from dreambooth import shared # noqa
from dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dataclasses.ss_model_spec import build_metadata
from dreambooth.utils.image_utils import get_scheduler_names # noqa
from dreambooth.utils.utils import list_attention
from dreambooth.utils.utils import list_attention, select_precision, select_attention
# Keys to save, replacing our dumb __init__ method
save_keys = []
@ -51,23 +53,22 @@ class DreamboothConfig(BaseModel):
infer_ema: bool = False
initial_revision: int = 0
input_pertubation: bool = True
learning_rate: float = 5e-6
learning_rate: float = 2e-6
learning_rate_min: float = 1e-6
lifetime_revision: int = 0
lora_learning_rate: float = 1e-4
lora_model_name: str = ""
lora_txt_learning_rate: float = 5e-5
lora_txt_rank: int = 4
lora_txt_weight: float = 1.0
lora_unet_rank: int = 4
lora_weight: float = 1.0
lora_weight: float = 0.8
lora_use_buggy_requires_grad: bool = False
lr_cycles: int = 1
lr_factor: float = 0.5
lr_power: float = 1.0
lr_scale_pos: float = 0.5
lr_scheduler: str = "constant_with_warmup"
lr_warmup_steps: int = 0
lr_warmup_steps: int = 500
max_token_length: int = 75
min_snr_gamma: float = 0.0
mixed_precision: str = "fp16"
@ -143,6 +144,12 @@ class DreamboothConfig(BaseModel):
super().__init__(**kwargs)
model_name = sanitize_name(model_name)
if "attention" not in kwargs:
self.attention = select_attention()
if "mixed_precision" not in kwargs:
self.mixed_precision = select_precision()
if "models_path" in kwargs:
models_path = kwargs["models_path"]
print(f"Using models path: {models_path}")
@ -157,7 +164,6 @@ class DreamboothConfig(BaseModel):
if not self.use_lora:
self.lora_model_name = ""
model_dir = os.path.join(models_path, model_name)
# print(f"Model dir set to: {model_dir}")
working_dir = os.path.join(model_dir, "working")
@ -325,6 +331,85 @@ class DreamboothConfig(BaseModel):
return os.path.join(self.model_dir, "working")
return self.pretrained_model_name_or_path
def export_ss_metadata(self, state_dict=None):
params = {}
token_counts_path = os.path.join(self.model_dir, "token_counts.json")
tags = None
if os.path.exists(token_counts_path):
with open(token_counts_path, "r") as f:
tags = json.load(f)
base_meta = build_metadata(
state_dict=state_dict,
v2 = "v2x" in self.model_type,
v_parameterization = self.model_type == "v2x",
sdxl = self.model_type == "SDXL",
lora=self.use_lora,
textual_inversion=False,
timestamp=datetime.datetime.now().timestamp(),
reso=(self.resolution, self.resolution),
tags=tags,
clip_skip=self.clip_skip
)
mappings = {
'cache_latents': 'ss_cache_latents',
'clip_skip': 'ss_clip_skip',
'epoch': 'ss_epoch',
'gradient_accumulation_steps': 'ss_gradient_accumulation_steps',
'gradient_checkpointing': 'ss_gradient_checkpointing',
'learning_rate': 'ss_learning_rate',
'lr_scheduler': 'ss_lr_scheduler',
'lr_warmup_steps': 'ss_lr_warmup_steps',
'max_token_length': 'ss_max_token_length',
'min_snr_gamma': 'ss_min_snr_gamma',
'mixed_precision': 'ss_mixed_precision',
'optimizer': 'ss_optimizer',
'prior_loss_weight': 'ss_prior_loss_weight',
'resolution': 'ss_resolution',
'src': 'ss_sd_model_name',
'shuffle_tags': 'ss_shuffle_captions'
}
for key, value in mappings.items():
if hasattr(self, key):
if value == "ss_resolution":
res = getattr(self, key)
params[value] = f"({res}, {res})"
params["modelspec.resolution"] = f"{res}x{res}"
elif value == "ss_sd_model_name":
model_name = getattr(self, key)
# Ensure model_name is only the model name, not the full path
model_name = os.path.basename(model_name)
params[value] = model_name
params[value] = getattr(self, key)
# Enumerate all params convert each one to a string
for key, value in params.items():
# If the value is not a string, convert it to one
if not isinstance(value, str):
value = str(value)
if key not in base_meta:
base_meta[key] = value
for key, value in base_meta.items():
if isinstance(value, Dict):
base_meta[key] = json.dumps(value)
elif not isinstance(value, str):
base_meta[key] = str(value)
ss_sd_model_name = base_meta.get("ss_sd_model_name", "")
if ss_sd_model_name != "":
# Make SURE we don't have a full path here
ss_sd_model_name = os.path.basename(ss_sd_model_name)
if "/" in ss_sd_model_name:
ss_sd_model_name = ss_sd_model_name.split("/")[-1]
if "\\" in ss_sd_model_name:
ss_sd_model_name = ss_sd_model_name.split("\\")[-1]
base_meta["ss_sd_model_name"] = ss_sd_model_name
if self.model_type == "SDXL":
base_meta["sd_version"] = "SDXL"
if "v2x" in self.model_type:
base_meta["sd_version"] = "V2"
if "v1x" in self.model_type:
base_meta["sd_version"] = "V1"
return base_meta
def concepts_from_file(concepts_path: str):
concepts = []

View File

@ -0,0 +1,216 @@
import datetime
import hashlib
from io import BytesIO
from typing import Optional, Union, Tuple
import safetensors.torch
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.license": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
}
# 別に使うやつだけ定義
MODELSPEC_TITLE = "modelspec.title"
ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
return b.read()
def precalculate_safetensors_hashes(state_dict):
# calculate each tensor one by one to reduce memory usage
hash_sha256 = hashlib.sha256()
for tensor in state_dict.values():
single_tensor_sd = {"tensor": tensor}
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
hash_sha256.update(bytes_for_tensor)
return f"0x{hash_sha256.hexdigest()}"
def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError
def build_metadata(
state_dict: Optional[dict],
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
is_stable_diffusion_ckpt: Optional[bool] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
):
# if state_dict is None, hash is not calculated
metadata = {}
metadata.update(BASE_METADATA)
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
metadata["ss_base_model_version"] = arch
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
if title is None:
if lora:
title = "LoRA"
elif textual_inversion:
title = "TextualInversion"
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title
if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]
if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]
if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]
if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]
if tags is not None:
metadata["modelspec.tags"] = tags
metadata["ss_tag_frequency"] = tags
else:
del metadata["modelspec.tags"]
# remove microsecond from time
int_ts = int(timestamp)
# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date
if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl:
reso = 1024
elif v2 and v_parameterization:
reso = 768
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
if timesteps is not None:
if isinstance(timesteps, str) or isinstance(timesteps, int):
timesteps = (timesteps, timesteps)
if len(timesteps) == 1:
timesteps = (timesteps[0], timesteps[0])
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
else:
del metadata["modelspec.timestep_range"]
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
print(f"Internal error: some metadata values are None: {metadata}")
return metadata

View File

@ -1,3 +1,6 @@
import json
import os
import safetensors.torch
import torch
@ -86,7 +89,8 @@ secondary_keys = [
]
def convert_diffusers_to_kohya_lora(model_dict, path):
def convert_diffusers_to_kohya_lora(path, metadata, alpha=0.8):
model_dict = safetensors.torch.load_file(path)
new_model_dict = {}
alpha_keys = []
# Replace the things
@ -109,6 +113,9 @@ def convert_diffusers_to_kohya_lora(model_dict, path):
# Add missing alpha keys
for k in alpha_keys:
if k not in new_model_dict:
print(f"Adding missing alpha key {k}")
new_model_dict[k] = torch.tensor(0.8)
safetensors.torch.save_file(new_model_dict, path)
new_model_dict[k] = torch.tensor(alpha)
conv_path = path.replace(".safetensors", "_auto.safetensors")
safetensors.torch.save_file(new_model_dict, conv_path, metadata=metadata)
# Delete the file at path, move the new file to path
os.remove(path)
os.rename(conv_path, path)

View File

@ -506,7 +506,8 @@ def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_model
printi(f"Saving checkpoint to {checkpoint_path}...", log=log)
if save_safetensors:
safe_dict, json_dict = split_dict(state_dict, pbar)
safetensors.torch.save_file(safe_dict, checkpoint_path, json_dict)
meta = config.export_ss_metadata()
safetensors.torch.save_file(safe_dict, checkpoint_path, meta)
else:
torch.save(state_dict, checkpoint_path)
cfg_file = None

View File

@ -375,7 +375,8 @@ def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_model
printi(f"Saving checkpoint to {checkpoint_path}...", log=log)
if save_safetensors:
save_file(state_dict, checkpoint_path)
meta = config.export_ss_metadata()
save_file(state_dict, checkpoint_path, meta)
else:
state_dict = {"state_dict": state_dict}
torch.save(state_dict, checkpoint_path)

View File

@ -14,6 +14,7 @@ from contextlib import ExitStack
from decimal import Decimal
from pathlib import Path
import safetensors.torch
import tomesd
import torch
import torch.backends.cuda
@ -40,7 +41,7 @@ from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataclasses.train_result import TrainResult
from dreambooth.dataset.bucket_sampler import BucketSampler
from dreambooth.dataset.sample_dataset import SampleDataset
from dreambooth.deis_velocity import get_velocity, compute_snr
from dreambooth.deis_velocity import get_velocity
from dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
@ -55,9 +56,9 @@ from dreambooth.utils.model_utils import (
disable_safe_unpickle,
enable_safe_unpickle,
xformerify,
torch2ify, unet_attn_processors_state_dict,
torch2ify, unet_attn_processors_state_dict
)
from dreambooth.utils.text_utils import encode_hidden_state
from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
patch_accelerator_for_fp16_training)
from dreambooth.webhook import send_training_update
@ -278,6 +279,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
args, class_gen_method=class_gen_method, accelerator=accelerator, ui=False, pbar=pbar2
)
save_token_counts(args, instance_prompts, 10)
if status.interrupted:
result.msg = "Training interrupted."
stop_profiler(profiler)
@ -1050,7 +1053,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
else:
model_dir = shared.models_path
loras_dir = os.path.join(model_dir, "Lora")
delete_tmp_lora = False
# Update the temp path if we just need to save an image
if save_image:
logger.debug("Save image is set.")
@ -1058,9 +1061,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if not save_lora:
logger.debug("Saving lora weights instead of checkpoint, using temp dir.")
save_lora = True
delete_tmp_lora = True
save_checkpoint = False
save_diffusers = False
loras_dir = f"{loras_dir}_temp"
os.makedirs(loras_dir, exist_ok=True)
elif not save_diffusers:
logger.debug("Saving checkpoint, using temp dir.")
@ -1081,6 +1084,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors")
with accelerator.autocast(), torch.inference_mode():
def lora_save_function(weights, filename):
metadata = args.export_ss_metadata()
logger.debug(f"Saving lora to {filename}")
safetensors.torch.save_file(weights, filename, metadata=metadata)
if save_lora:
# TODO: Add a version for the lora model?
pbar2.reset(1)
@ -1102,7 +1111,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
weight_name=lora_save_file,
safe_serialization=True,
save_function=convert_diffusers_to_kohya_lora
save_function=lora_save_function
)
scheduler_args = {}
@ -1123,8 +1132,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
weight_name=lora_save_file,
safe_serialization=True,
save_function=convert_diffusers_to_kohya_lora
safe_serialization=True
)
s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
s_pipeline.scheduler.config)
@ -1221,6 +1229,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False)
if args.use_lora:
s_pipeline.load_lora_weights(lora_save_file)
try:
s_pipeline.enable_vae_tiling()
s_pipeline.enable_vae_slicing()
@ -1342,6 +1351,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
update_status({"images": last_samples, "prompts": last_prompts})
pbar2.update()
if args.cache_latents:
printm("Unloading vae.")
del vae
@ -1362,6 +1372,15 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda")
cleanup()
# Save the lora weights if we are saving the model
if os.path.isfile(lora_save_file) and not delete_tmp_lora:
meta = args.export_ss_metadata()
convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight)
else:
if os.path.isfile(lora_save_file):
os.remove(lora_save_file)
printm("Completed saving weights.")
pbar2.reset()

View File

@ -950,7 +950,7 @@ def create_model(
res = 512
elif model_type == "v2x":
res = 768
elif model_type == "sdxl":
elif model_type == "SDXL":
res = 1024
sh = None
try:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import collections
import json
import logging
import os
import re
@ -294,4 +295,42 @@ def torch2ify(unet):
return unet
def is_xformers_available():
pass
pass
def read_metadata_from_safetensors(filename):
with open(filename, mode="rb") as file:
# Read metadata length
metadata_len = int.from_bytes(file.read(8), "little")
# Read the metadata based on its length
json_data = file.read(metadata_len).decode('utf-8')
res = {}
# Check if it's a valid JSON string
try:
json_obj = json.loads(json_data)
except json.JSONDecodeError:
return res
# Extract metadata
metadata = json_obj.get("__metadata__", {})
if not isinstance(metadata, dict):
return res
# Process the metadata to handle nested JSON strings
for k, v in metadata.items():
# if not isinstance(v, str):
# raise ValueError("All values in __metadata__ must be strings")
# If the string value looks like a JSON string, attempt to parse it
if v.startswith('{'):
try:
res[k] = json.loads(v)
except Exception:
res[k] = v
else:
res[k] = v
return res

View File

@ -230,6 +230,10 @@ def list_attention():
else:
return ["default"]
def select_attention():
attentions = list_attention()
# Return the last element
return attentions[-1]
def list_precisions():
precisions = ["no", "fp16"]
@ -241,6 +245,11 @@ def list_precisions():
return precisions
def select_precision():
precisions = list_precisions()
# Return the last element
return precisions[-1]
def list_schedulers():
return [

View File

@ -1,4 +1,5 @@
import importlib
import json
import time
from typing import List
@ -45,7 +46,7 @@ from dreambooth.utils.utils import (
wrap_gpu_call,
printm,
list_optimizer,
list_schedulers,
list_schedulers, select_precision, select_attention,
)
from dreambooth.webhook import save_and_test_webhook
from helpers.log_parser import LogParser
@ -60,6 +61,45 @@ delete_symbol = "\U0001F5D1" # 🗑️
update_symbol = "\U0001F51D" # 🠝
log_parser = LogParser()
def read_metadata_from_safetensors(filename):
with open(filename, mode="rb") as file:
# Read metadata length
metadata_len = int.from_bytes(file.read(8), "little")
# Read the metadata based on its length
json_data = file.read(metadata_len).decode('utf-8')
res = {}
# Check if it's a valid JSON string
try:
json_obj = json.loads(json_data)
except json.JSONDecodeError:
return res
# Extract metadata
metadata = json_obj.get("__metadata__", {})
if not isinstance(metadata, dict):
return res
# Process the metadata to handle nested JSON strings
for k, v in metadata.items():
# if not isinstance(v, str):
# raise ValueError("All values in __metadata__ must be strings")
# If the string value looks like a JSON string, attempt to parse it
if v.startswith('{'):
try:
res[k] = json.loads(v)
except Exception:
res[k] = v
else:
res[k] = v
return res
def get_sd_models():
sd_models.list_models()
@ -354,7 +394,8 @@ def on_ui_tabs():
db_new_model_shared_src = gr.Dropdown(
label="EXPERIMENTAL: LoRA Shared Diffusers Source",
choices=sorted(get_shared_models()),
value=""
value="",
visible=False
)
create_refresh_button(
db_new_model_shared_src,
@ -388,7 +429,7 @@ def on_ui_tabs():
value=False,
visible=False,
)
db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False)
db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False)
db_train_inpainting = gr.Checkbox(
label="Train Inpainting Model",
value=False,
@ -471,7 +512,7 @@ def on_ui_tabs():
label="Learning Rate", value=2e-6
)
db_txt_learning_rate = gr.Number(
label="Text Encoder Learning Rate", value=2e-6
label="Text Encoder Learning Rate", value=1e-6
)
db_lr_scheduler = gr.Dropdown(
@ -510,11 +551,35 @@ def on_ui_tabs():
)
db_lr_warmup_steps = gr.Slider(
label="Learning Rate Warmup Steps",
value=0,
value=500,
step=5,
maximum=1000,
)
with gr.Column(visible=False) as lora_rank_col:
gr.HTML("Lora")
db_lora_unet_rank = gr.Slider(
label="Lora UNET Rank",
value=4,
minimum=2,
maximum=128,
step=2,
)
db_lora_txt_rank = gr.Slider(
label="Lora Text Encoder Rank",
value=4,
minimum=2,
maximum=128,
step=2,
)
db_lora_weight = gr.Slider(
label="Lora Weight (Alpha)",
value=0.8,
minimum=0.1,
maximum=1,
step=0.1,
)
with gr.Column():
gr.HTML(value="Image Processing")
db_resolution = gr.Slider(
@ -544,12 +609,15 @@ def on_ui_tabs():
)
db_mixed_precision = gr.Dropdown(
label="Mixed Precision",
value="no",
value=select_precision(),
choices=list_precisions(),
)
db_full_mixed_precision = gr.Checkbox(
label="Full Mixed Precision", value=True
)
db_attention = gr.Dropdown(
label="Memory Attention",
value="default",
value=select_attention(),
choices=list_attention(),
)
db_cache_latents = gr.Checkbox(
@ -563,7 +631,7 @@ def on_ui_tabs():
minimum=0,
maximum=1,
step=0.05,
value=0,
value=1.0,
visible=True,
)
db_offset_noise = gr.Slider(
@ -580,7 +648,7 @@ def on_ui_tabs():
)
db_clip_skip = gr.Slider(
label="Clip Skip",
value=1,
value=2,
minimum=1,
maximum=12,
step=1,
@ -843,35 +911,6 @@ def on_ui_tabs():
label="Generate a .ckpt file when training is canceled."
)
with gr.Column(visible=False) as lora_save_col:
gr.HTML("Lora")
db_lora_unet_rank = gr.Slider(
label="Lora UNET Rank",
value=4,
minimum=2,
maximum=128,
step=2,
)
db_lora_txt_rank = gr.Slider(
label="Lora Text Encoder Rank",
value=4,
minimum=2,
maximum=768,
step=2,
)
db_lora_weight = gr.Slider(
label="Lora Weight",
value=1,
minimum=0.1,
maximum=1,
step=0.1,
)
db_lora_txt_weight = gr.Slider(
label="Lora Text Weight",
value=1,
minimum=0.1,
maximum=1,
step=0.1,
)
db_save_lora_during = gr.Checkbox(
label="Generate lora weights when saving during training."
)
@ -1000,9 +1039,6 @@ def on_ui_tabs():
)
with gr.Tab("Testing", elem_id="TabDebug"):
gr.HTML(value="Experimental Settings")
db_full_mixed_precision = gr.Checkbox(
label="Full Mixed Precision", value=False
)
db_tomesd = gr.Slider(
value=0,
label="Token Merging (ToMe)",
@ -1059,8 +1095,8 @@ def on_ui_tabs():
db_status = gr.HTML(elem_id="db_status", value="")
db_progressbar = gr.HTML(elem_id="db_progressbar")
db_gallery = gr.Gallery(
label="Output", show_label=False, elem_id="db_gallery"
).style(grid=4)
label="Output", show_label=False, elem_id="db_gallery", columns=4
)
db_preview = gr.Image(elem_id="db_preview", visible=False)
db_prompt_list = gr.HTML(
elem_id="db_prompt_list", value="", visible=False
@ -1084,9 +1120,13 @@ def on_ui_tabs():
show_ema,
use_lora_extended,
lora_save,
lora_rank,
lora_lr,
standard_lr,
lora_model,
_,
_,
_
) = disable_lora(use_lora)
(
lr_power,
@ -1103,6 +1143,7 @@ def on_ui_tabs():
show_ema,
use_lora_extended,
lora_save,
lora_rank,
lora_lr,
lora_model,
scheduler,
@ -1144,6 +1185,7 @@ def on_ui_tabs():
db_use_ema,
db_use_lora_extended,
lora_save_col,
lora_rank_col,
lora_lr_row,
lora_model_row,
db_scheduler,
@ -1184,6 +1226,30 @@ def on_ui_tabs():
outputs=[db_stop_text_encoder],
)
def toggle_full_mixed_precision(full_mixed_precision):
if full_mixed_precision != "fp16":
return gr.update(visible=False)
else:
return gr.update(visible=True)
db_mixed_precision.change(
fn=toggle_full_mixed_precision,
inputs=[db_mixed_precision],
outputs=[db_full_mixed_precision],
)
def update_model_options(model_type):
if model_type == "SDXL":
return gr.update(value=1024)
else:
return gr.update(value=512)
db_model_type_select.change(
fn=update_model_options,
inputs=[db_model_type_select],
outputs=[db_resolution]
)
db_clear_secret.click(fn=clear_secret, inputs=[], outputs=[db_secret])
# Elements to update when progress changes
@ -1279,7 +1345,6 @@ def on_ui_tabs():
db_lora_model_name,
db_lora_txt_learning_rate,
db_lora_txt_rank,
db_lora_txt_weight,
db_lora_unet_rank,
db_lora_use_buggy_requires_grad,
db_lora_weight,
@ -1484,18 +1549,31 @@ def on_ui_tabs():
def disable_lora(x):
use_ema = gr.update(interactive=not x)
use_lora_extended = gr.update(visible=x)
use_lora_extended = gr.update(visible=False)
lora_save = gr.update(visible=x)
lora_rank = gr.update(visible=x)
lora_lr = gr.update(visible=x)
standard_lr = gr.update(visible=not x)
lora_model = gr.update(visible=x)
if x:
save_during =gr.update(label="Save LORA during training")
save_after = gr.update(label="Save LORA after training")
save_cancel = gr.update(label="Save LORA on cancel")
else:
save_during = gr.update(label="Save .safetensors during training")
save_after = gr.update(label="Save .safetensors after training")
save_cancel = gr.update(label="Save .safetensors on cancel")
return (
use_ema,
use_lora_extended,
lora_save,
lora_rank,
lora_lr,
standard_lr,
lora_model,
save_during,
save_after,
save_cancel
)
def lr_scheduler_changed(sched):
@ -1545,9 +1623,13 @@ def on_ui_tabs():
db_use_ema,
db_use_lora_extended,
lora_save_col,
lora_rank_col,
lora_lr_row,
standard_lr_row,
lora_model_row,
db_save_ckpt_during,
db_save_ckpt_after,
db_save_ckpt_cancel
],
)
@ -1743,6 +1825,7 @@ def on_ui_tabs():
outputs=[db_gallery, db_status],
)
db_cancel.click(
fn=lambda: status.interrupt(),
inputs=[],