Add AUTO-compatible metadata, UI improvements
parent
bd162d689d
commit
8030472b60
|
|
@ -1,3 +1,4 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -9,8 +10,9 @@ from pydantic import BaseModel
|
|||
|
||||
from dreambooth import shared # noqa
|
||||
from dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dataclasses.ss_model_spec import build_metadata
|
||||
from dreambooth.utils.image_utils import get_scheduler_names # noqa
|
||||
from dreambooth.utils.utils import list_attention
|
||||
from dreambooth.utils.utils import list_attention, select_precision, select_attention
|
||||
|
||||
# Keys to save, replacing our dumb __init__ method
|
||||
save_keys = []
|
||||
|
|
@ -51,23 +53,22 @@ class DreamboothConfig(BaseModel):
|
|||
infer_ema: bool = False
|
||||
initial_revision: int = 0
|
||||
input_pertubation: bool = True
|
||||
learning_rate: float = 5e-6
|
||||
learning_rate: float = 2e-6
|
||||
learning_rate_min: float = 1e-6
|
||||
lifetime_revision: int = 0
|
||||
lora_learning_rate: float = 1e-4
|
||||
lora_model_name: str = ""
|
||||
lora_txt_learning_rate: float = 5e-5
|
||||
lora_txt_rank: int = 4
|
||||
lora_txt_weight: float = 1.0
|
||||
lora_unet_rank: int = 4
|
||||
lora_weight: float = 1.0
|
||||
lora_weight: float = 0.8
|
||||
lora_use_buggy_requires_grad: bool = False
|
||||
lr_cycles: int = 1
|
||||
lr_factor: float = 0.5
|
||||
lr_power: float = 1.0
|
||||
lr_scale_pos: float = 0.5
|
||||
lr_scheduler: str = "constant_with_warmup"
|
||||
lr_warmup_steps: int = 0
|
||||
lr_warmup_steps: int = 500
|
||||
max_token_length: int = 75
|
||||
min_snr_gamma: float = 0.0
|
||||
mixed_precision: str = "fp16"
|
||||
|
|
@ -143,6 +144,12 @@ class DreamboothConfig(BaseModel):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
model_name = sanitize_name(model_name)
|
||||
if "attention" not in kwargs:
|
||||
self.attention = select_attention()
|
||||
|
||||
if "mixed_precision" not in kwargs:
|
||||
self.mixed_precision = select_precision()
|
||||
|
||||
if "models_path" in kwargs:
|
||||
models_path = kwargs["models_path"]
|
||||
print(f"Using models path: {models_path}")
|
||||
|
|
@ -157,7 +164,6 @@ class DreamboothConfig(BaseModel):
|
|||
|
||||
if not self.use_lora:
|
||||
self.lora_model_name = ""
|
||||
|
||||
model_dir = os.path.join(models_path, model_name)
|
||||
# print(f"Model dir set to: {model_dir}")
|
||||
working_dir = os.path.join(model_dir, "working")
|
||||
|
|
@ -325,6 +331,85 @@ class DreamboothConfig(BaseModel):
|
|||
return os.path.join(self.model_dir, "working")
|
||||
return self.pretrained_model_name_or_path
|
||||
|
||||
def export_ss_metadata(self, state_dict=None):
|
||||
params = {}
|
||||
token_counts_path = os.path.join(self.model_dir, "token_counts.json")
|
||||
tags = None
|
||||
if os.path.exists(token_counts_path):
|
||||
with open(token_counts_path, "r") as f:
|
||||
tags = json.load(f)
|
||||
base_meta = build_metadata(
|
||||
state_dict=state_dict,
|
||||
v2 = "v2x" in self.model_type,
|
||||
v_parameterization = self.model_type == "v2x",
|
||||
sdxl = self.model_type == "SDXL",
|
||||
lora=self.use_lora,
|
||||
textual_inversion=False,
|
||||
timestamp=datetime.datetime.now().timestamp(),
|
||||
reso=(self.resolution, self.resolution),
|
||||
tags=tags,
|
||||
clip_skip=self.clip_skip
|
||||
)
|
||||
mappings = {
|
||||
'cache_latents': 'ss_cache_latents',
|
||||
'clip_skip': 'ss_clip_skip',
|
||||
'epoch': 'ss_epoch',
|
||||
'gradient_accumulation_steps': 'ss_gradient_accumulation_steps',
|
||||
'gradient_checkpointing': 'ss_gradient_checkpointing',
|
||||
'learning_rate': 'ss_learning_rate',
|
||||
'lr_scheduler': 'ss_lr_scheduler',
|
||||
'lr_warmup_steps': 'ss_lr_warmup_steps',
|
||||
'max_token_length': 'ss_max_token_length',
|
||||
'min_snr_gamma': 'ss_min_snr_gamma',
|
||||
'mixed_precision': 'ss_mixed_precision',
|
||||
'optimizer': 'ss_optimizer',
|
||||
'prior_loss_weight': 'ss_prior_loss_weight',
|
||||
'resolution': 'ss_resolution',
|
||||
'src': 'ss_sd_model_name',
|
||||
'shuffle_tags': 'ss_shuffle_captions'
|
||||
}
|
||||
for key, value in mappings.items():
|
||||
if hasattr(self, key):
|
||||
if value == "ss_resolution":
|
||||
res = getattr(self, key)
|
||||
params[value] = f"({res}, {res})"
|
||||
params["modelspec.resolution"] = f"{res}x{res}"
|
||||
elif value == "ss_sd_model_name":
|
||||
model_name = getattr(self, key)
|
||||
# Ensure model_name is only the model name, not the full path
|
||||
model_name = os.path.basename(model_name)
|
||||
params[value] = model_name
|
||||
params[value] = getattr(self, key)
|
||||
|
||||
# Enumerate all params convert each one to a string
|
||||
for key, value in params.items():
|
||||
# If the value is not a string, convert it to one
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
if key not in base_meta:
|
||||
base_meta[key] = value
|
||||
for key, value in base_meta.items():
|
||||
if isinstance(value, Dict):
|
||||
base_meta[key] = json.dumps(value)
|
||||
elif not isinstance(value, str):
|
||||
base_meta[key] = str(value)
|
||||
ss_sd_model_name = base_meta.get("ss_sd_model_name", "")
|
||||
if ss_sd_model_name != "":
|
||||
# Make SURE we don't have a full path here
|
||||
ss_sd_model_name = os.path.basename(ss_sd_model_name)
|
||||
if "/" in ss_sd_model_name:
|
||||
ss_sd_model_name = ss_sd_model_name.split("/")[-1]
|
||||
if "\\" in ss_sd_model_name:
|
||||
ss_sd_model_name = ss_sd_model_name.split("\\")[-1]
|
||||
base_meta["ss_sd_model_name"] = ss_sd_model_name
|
||||
if self.model_type == "SDXL":
|
||||
base_meta["sd_version"] = "SDXL"
|
||||
if "v2x" in self.model_type:
|
||||
base_meta["sd_version"] = "V2"
|
||||
if "v1x" in self.model_type:
|
||||
base_meta["sd_version"] = "V1"
|
||||
return base_meta
|
||||
|
||||
|
||||
def concepts_from_file(concepts_path: str):
|
||||
concepts = []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,216 @@
|
|||
import datetime
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
|
||||
BASE_METADATA = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
"modelspec.architecture": None,
|
||||
"modelspec.implementation": None,
|
||||
"modelspec.title": None,
|
||||
"modelspec.resolution": None,
|
||||
# === Should ===
|
||||
"modelspec.description": None,
|
||||
"modelspec.author": None,
|
||||
"modelspec.date": None,
|
||||
# === Can ===
|
||||
"modelspec.license": None,
|
||||
"modelspec.tags": None,
|
||||
"modelspec.merged_from": None,
|
||||
"modelspec.prediction_type": None,
|
||||
"modelspec.timestep_range": None,
|
||||
"modelspec.encoder_layer": None,
|
||||
}
|
||||
|
||||
# 別に使うやつだけ定義
|
||||
MODELSPEC_TITLE = "modelspec.title"
|
||||
|
||||
ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
|
||||
|
||||
def load_bytes_in_safetensors(tensors):
|
||||
bytes = safetensors.torch.save(tensors)
|
||||
b = BytesIO(bytes)
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
|
||||
return b.read()
|
||||
|
||||
|
||||
def precalculate_safetensors_hashes(state_dict):
|
||||
# calculate each tensor one by one to reduce memory usage
|
||||
hash_sha256 = hashlib.sha256()
|
||||
for tensor in state_dict.values():
|
||||
single_tensor_sd = {"tensor": tensor}
|
||||
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
||||
hash_sha256.update(bytes_for_tensor)
|
||||
|
||||
return f"0x{hash_sha256.hexdigest()}"
|
||||
|
||||
|
||||
def update_hash_sha256(metadata: dict, state_dict: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def build_metadata(
|
||||
state_dict: Optional[dict],
|
||||
v2: bool,
|
||||
v_parameterization: bool,
|
||||
sdxl: bool,
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
timestamp: float,
|
||||
title: Optional[str] = None,
|
||||
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None,
|
||||
author: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
license: Optional[str] = None,
|
||||
tags: Optional[str] = None,
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
metadata = {}
|
||||
metadata.update(BASE_METADATA)
|
||||
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
arch = ARCH_SD_V2_768_V
|
||||
else:
|
||||
arch = ARCH_SD_V2_512
|
||||
else:
|
||||
arch = ARCH_SD_V1
|
||||
|
||||
metadata["ss_base_model_version"] = arch
|
||||
if lora:
|
||||
arch += f"/{ADAPTER_LORA}"
|
||||
elif textual_inversion:
|
||||
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
||||
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
# v1/v2 LoRA or Diffusers
|
||||
impl = IMPL_DIFFUSERS
|
||||
metadata["modelspec.implementation"] = impl
|
||||
|
||||
|
||||
if title is None:
|
||||
if lora:
|
||||
title = "LoRA"
|
||||
elif textual_inversion:
|
||||
title = "TextualInversion"
|
||||
else:
|
||||
title = "Checkpoint"
|
||||
title += f"@{timestamp}"
|
||||
metadata[MODELSPEC_TITLE] = title
|
||||
|
||||
if author is not None:
|
||||
metadata["modelspec.author"] = author
|
||||
else:
|
||||
del metadata["modelspec.author"]
|
||||
|
||||
if description is not None:
|
||||
metadata["modelspec.description"] = description
|
||||
else:
|
||||
del metadata["modelspec.description"]
|
||||
|
||||
if merged_from is not None:
|
||||
metadata["modelspec.merged_from"] = merged_from
|
||||
else:
|
||||
del metadata["modelspec.merged_from"]
|
||||
|
||||
if license is not None:
|
||||
metadata["modelspec.license"] = license
|
||||
else:
|
||||
del metadata["modelspec.license"]
|
||||
|
||||
if tags is not None:
|
||||
metadata["modelspec.tags"] = tags
|
||||
metadata["ss_tag_frequency"] = tags
|
||||
else:
|
||||
del metadata["modelspec.tags"]
|
||||
|
||||
# remove microsecond from time
|
||||
int_ts = int(timestamp)
|
||||
|
||||
# time to iso-8601 compliant date
|
||||
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
||||
metadata["modelspec.date"] = date
|
||||
|
||||
if reso is not None:
|
||||
# comma separated to tuple
|
||||
if isinstance(reso, str):
|
||||
reso = tuple(map(int, reso.split(",")))
|
||||
if len(reso) == 1:
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# resolution is defined in dataset, so use default
|
||||
if sdxl:
|
||||
reso = 1024
|
||||
elif v2 and v_parameterization:
|
||||
reso = 768
|
||||
else:
|
||||
reso = 512
|
||||
if isinstance(reso, int):
|
||||
reso = (reso, reso)
|
||||
|
||||
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
||||
|
||||
if v_parameterization:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
||||
else:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
||||
|
||||
if timesteps is not None:
|
||||
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
||||
timesteps = (timesteps, timesteps)
|
||||
if len(timesteps) == 1:
|
||||
timesteps = (timesteps[0], timesteps[0])
|
||||
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
||||
else:
|
||||
del metadata["modelspec.timestep_range"]
|
||||
|
||||
if clip_skip is not None:
|
||||
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
||||
else:
|
||||
del metadata["modelspec.encoder_layer"]
|
||||
|
||||
# # assert all values are filled
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
print(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
return metadata
|
||||
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
|
|
@ -86,7 +89,8 @@ secondary_keys = [
|
|||
]
|
||||
|
||||
|
||||
def convert_diffusers_to_kohya_lora(model_dict, path):
|
||||
def convert_diffusers_to_kohya_lora(path, metadata, alpha=0.8):
|
||||
model_dict = safetensors.torch.load_file(path)
|
||||
new_model_dict = {}
|
||||
alpha_keys = []
|
||||
# Replace the things
|
||||
|
|
@ -109,6 +113,9 @@ def convert_diffusers_to_kohya_lora(model_dict, path):
|
|||
# Add missing alpha keys
|
||||
for k in alpha_keys:
|
||||
if k not in new_model_dict:
|
||||
print(f"Adding missing alpha key {k}")
|
||||
new_model_dict[k] = torch.tensor(0.8)
|
||||
safetensors.torch.save_file(new_model_dict, path)
|
||||
new_model_dict[k] = torch.tensor(alpha)
|
||||
conv_path = path.replace(".safetensors", "_auto.safetensors")
|
||||
safetensors.torch.save_file(new_model_dict, conv_path, metadata=metadata)
|
||||
# Delete the file at path, move the new file to path
|
||||
os.remove(path)
|
||||
os.rename(conv_path, path)
|
||||
|
|
|
|||
|
|
@ -506,7 +506,8 @@ def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_model
|
|||
printi(f"Saving checkpoint to {checkpoint_path}...", log=log)
|
||||
if save_safetensors:
|
||||
safe_dict, json_dict = split_dict(state_dict, pbar)
|
||||
safetensors.torch.save_file(safe_dict, checkpoint_path, json_dict)
|
||||
meta = config.export_ss_metadata()
|
||||
safetensors.torch.save_file(safe_dict, checkpoint_path, meta)
|
||||
else:
|
||||
torch.save(state_dict, checkpoint_path)
|
||||
cfg_file = None
|
||||
|
|
|
|||
|
|
@ -375,7 +375,8 @@ def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_model
|
|||
|
||||
printi(f"Saving checkpoint to {checkpoint_path}...", log=log)
|
||||
if save_safetensors:
|
||||
save_file(state_dict, checkpoint_path)
|
||||
meta = config.export_ss_metadata()
|
||||
save_file(state_dict, checkpoint_path, meta)
|
||||
else:
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, checkpoint_path)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from contextlib import ExitStack
|
|||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors.torch
|
||||
import tomesd
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
|
|
@ -40,7 +41,7 @@ from dreambooth.dataclasses.prompt_data import PromptData
|
|||
from dreambooth.dataclasses.train_result import TrainResult
|
||||
from dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from dreambooth.deis_velocity import get_velocity, compute_snr
|
||||
from dreambooth.deis_velocity import get_velocity
|
||||
from dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
|
||||
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
|
||||
|
|
@ -55,9 +56,9 @@ from dreambooth.utils.model_utils import (
|
|||
disable_safe_unpickle,
|
||||
enable_safe_unpickle,
|
||||
xformerify,
|
||||
torch2ify, unet_attn_processors_state_dict,
|
||||
torch2ify, unet_attn_processors_state_dict
|
||||
)
|
||||
from dreambooth.utils.text_utils import encode_hidden_state
|
||||
from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
|
||||
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
|
||||
patch_accelerator_for_fp16_training)
|
||||
from dreambooth.webhook import send_training_update
|
||||
|
|
@ -278,6 +279,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
args, class_gen_method=class_gen_method, accelerator=accelerator, ui=False, pbar=pbar2
|
||||
)
|
||||
|
||||
save_token_counts(args, instance_prompts, 10)
|
||||
|
||||
if status.interrupted:
|
||||
result.msg = "Training interrupted."
|
||||
stop_profiler(profiler)
|
||||
|
|
@ -1050,7 +1053,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
else:
|
||||
model_dir = shared.models_path
|
||||
loras_dir = os.path.join(model_dir, "Lora")
|
||||
|
||||
delete_tmp_lora = False
|
||||
# Update the temp path if we just need to save an image
|
||||
if save_image:
|
||||
logger.debug("Save image is set.")
|
||||
|
|
@ -1058,9 +1061,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if not save_lora:
|
||||
logger.debug("Saving lora weights instead of checkpoint, using temp dir.")
|
||||
save_lora = True
|
||||
delete_tmp_lora = True
|
||||
save_checkpoint = False
|
||||
save_diffusers = False
|
||||
loras_dir = f"{loras_dir}_temp"
|
||||
os.makedirs(loras_dir, exist_ok=True)
|
||||
elif not save_diffusers:
|
||||
logger.debug("Saving checkpoint, using temp dir.")
|
||||
|
|
@ -1081,6 +1084,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors")
|
||||
|
||||
with accelerator.autocast(), torch.inference_mode():
|
||||
|
||||
def lora_save_function(weights, filename):
|
||||
metadata = args.export_ss_metadata()
|
||||
logger.debug(f"Saving lora to {filename}")
|
||||
safetensors.torch.save_file(weights, filename, metadata=metadata)
|
||||
|
||||
if save_lora:
|
||||
# TODO: Add a version for the lora model?
|
||||
pbar2.reset(1)
|
||||
|
|
@ -1102,7 +1111,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
weight_name=lora_save_file,
|
||||
safe_serialization=True,
|
||||
save_function=convert_diffusers_to_kohya_lora
|
||||
save_function=lora_save_function
|
||||
)
|
||||
scheduler_args = {}
|
||||
|
||||
|
|
@ -1123,8 +1132,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
weight_name=lora_save_file,
|
||||
safe_serialization=True,
|
||||
save_function=convert_diffusers_to_kohya_lora
|
||||
safe_serialization=True
|
||||
)
|
||||
s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
|
||||
s_pipeline.scheduler.config)
|
||||
|
|
@ -1221,6 +1229,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False)
|
||||
if args.use_lora:
|
||||
s_pipeline.load_lora_weights(lora_save_file)
|
||||
|
||||
try:
|
||||
s_pipeline.enable_vae_tiling()
|
||||
s_pipeline.enable_vae_slicing()
|
||||
|
|
@ -1342,6 +1351,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
update_status({"images": last_samples, "prompts": last_prompts})
|
||||
pbar2.update()
|
||||
|
||||
|
||||
if args.cache_latents:
|
||||
printm("Unloading vae.")
|
||||
del vae
|
||||
|
|
@ -1362,6 +1372,15 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda")
|
||||
|
||||
cleanup()
|
||||
|
||||
# Save the lora weights if we are saving the model
|
||||
if os.path.isfile(lora_save_file) and not delete_tmp_lora:
|
||||
meta = args.export_ss_metadata()
|
||||
convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight)
|
||||
else:
|
||||
if os.path.isfile(lora_save_file):
|
||||
os.remove(lora_save_file)
|
||||
|
||||
printm("Completed saving weights.")
|
||||
pbar2.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -950,7 +950,7 @@ def create_model(
|
|||
res = 512
|
||||
elif model_type == "v2x":
|
||||
res = 768
|
||||
elif model_type == "sdxl":
|
||||
elif model_type == "SDXL":
|
||||
res = 1024
|
||||
sh = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
|
@ -294,4 +295,42 @@ def torch2ify(unet):
|
|||
return unet
|
||||
|
||||
def is_xformers_available():
|
||||
pass
|
||||
pass
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
|
||||
with open(filename, mode="rb") as file:
|
||||
# Read metadata length
|
||||
metadata_len = int.from_bytes(file.read(8), "little")
|
||||
|
||||
# Read the metadata based on its length
|
||||
json_data = file.read(metadata_len).decode('utf-8')
|
||||
|
||||
res = {}
|
||||
|
||||
# Check if it's a valid JSON string
|
||||
try:
|
||||
json_obj = json.loads(json_data)
|
||||
except json.JSONDecodeError:
|
||||
return res
|
||||
|
||||
# Extract metadata
|
||||
metadata = json_obj.get("__metadata__", {})
|
||||
if not isinstance(metadata, dict):
|
||||
return res
|
||||
|
||||
# Process the metadata to handle nested JSON strings
|
||||
for k, v in metadata.items():
|
||||
# if not isinstance(v, str):
|
||||
# raise ValueError("All values in __metadata__ must be strings")
|
||||
|
||||
# If the string value looks like a JSON string, attempt to parse it
|
||||
if v.startswith('{'):
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
res[k] = v
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -230,6 +230,10 @@ def list_attention():
|
|||
else:
|
||||
return ["default"]
|
||||
|
||||
def select_attention():
|
||||
attentions = list_attention()
|
||||
# Return the last element
|
||||
return attentions[-1]
|
||||
|
||||
def list_precisions():
|
||||
precisions = ["no", "fp16"]
|
||||
|
|
@ -241,6 +245,11 @@ def list_precisions():
|
|||
|
||||
return precisions
|
||||
|
||||
def select_precision():
|
||||
precisions = list_precisions()
|
||||
# Return the last element
|
||||
return precisions[-1]
|
||||
|
||||
|
||||
def list_schedulers():
|
||||
return [
|
||||
|
|
|
|||
173
scripts/main.py
173
scripts/main.py
|
|
@ -1,4 +1,5 @@
|
|||
import importlib
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
|
@ -45,7 +46,7 @@ from dreambooth.utils.utils import (
|
|||
wrap_gpu_call,
|
||||
printm,
|
||||
list_optimizer,
|
||||
list_schedulers,
|
||||
list_schedulers, select_precision, select_attention,
|
||||
)
|
||||
from dreambooth.webhook import save_and_test_webhook
|
||||
from helpers.log_parser import LogParser
|
||||
|
|
@ -60,6 +61,45 @@ delete_symbol = "\U0001F5D1" # 🗑️
|
|||
update_symbol = "\U0001F51D" # 🠝
|
||||
log_parser = LogParser()
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
|
||||
with open(filename, mode="rb") as file:
|
||||
# Read metadata length
|
||||
metadata_len = int.from_bytes(file.read(8), "little")
|
||||
|
||||
# Read the metadata based on its length
|
||||
json_data = file.read(metadata_len).decode('utf-8')
|
||||
|
||||
res = {}
|
||||
|
||||
# Check if it's a valid JSON string
|
||||
try:
|
||||
json_obj = json.loads(json_data)
|
||||
except json.JSONDecodeError:
|
||||
return res
|
||||
|
||||
# Extract metadata
|
||||
metadata = json_obj.get("__metadata__", {})
|
||||
if not isinstance(metadata, dict):
|
||||
return res
|
||||
|
||||
# Process the metadata to handle nested JSON strings
|
||||
for k, v in metadata.items():
|
||||
# if not isinstance(v, str):
|
||||
# raise ValueError("All values in __metadata__ must be strings")
|
||||
|
||||
# If the string value looks like a JSON string, attempt to parse it
|
||||
if v.startswith('{'):
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
res[k] = v
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
def get_sd_models():
|
||||
sd_models.list_models()
|
||||
|
|
@ -354,7 +394,8 @@ def on_ui_tabs():
|
|||
db_new_model_shared_src = gr.Dropdown(
|
||||
label="EXPERIMENTAL: LoRA Shared Diffusers Source",
|
||||
choices=sorted(get_shared_models()),
|
||||
value=""
|
||||
value="",
|
||||
visible=False
|
||||
)
|
||||
create_refresh_button(
|
||||
db_new_model_shared_src,
|
||||
|
|
@ -388,7 +429,7 @@ def on_ui_tabs():
|
|||
value=False,
|
||||
visible=False,
|
||||
)
|
||||
db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False)
|
||||
db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False)
|
||||
db_train_inpainting = gr.Checkbox(
|
||||
label="Train Inpainting Model",
|
||||
value=False,
|
||||
|
|
@ -471,7 +512,7 @@ def on_ui_tabs():
|
|||
label="Learning Rate", value=2e-6
|
||||
)
|
||||
db_txt_learning_rate = gr.Number(
|
||||
label="Text Encoder Learning Rate", value=2e-6
|
||||
label="Text Encoder Learning Rate", value=1e-6
|
||||
)
|
||||
|
||||
db_lr_scheduler = gr.Dropdown(
|
||||
|
|
@ -510,11 +551,35 @@ def on_ui_tabs():
|
|||
)
|
||||
db_lr_warmup_steps = gr.Slider(
|
||||
label="Learning Rate Warmup Steps",
|
||||
value=0,
|
||||
value=500,
|
||||
step=5,
|
||||
maximum=1000,
|
||||
)
|
||||
|
||||
with gr.Column(visible=False) as lora_rank_col:
|
||||
gr.HTML("Lora")
|
||||
db_lora_unet_rank = gr.Slider(
|
||||
label="Lora UNET Rank",
|
||||
value=4,
|
||||
minimum=2,
|
||||
maximum=128,
|
||||
step=2,
|
||||
)
|
||||
db_lora_txt_rank = gr.Slider(
|
||||
label="Lora Text Encoder Rank",
|
||||
value=4,
|
||||
minimum=2,
|
||||
maximum=128,
|
||||
step=2,
|
||||
)
|
||||
db_lora_weight = gr.Slider(
|
||||
label="Lora Weight (Alpha)",
|
||||
value=0.8,
|
||||
minimum=0.1,
|
||||
maximum=1,
|
||||
step=0.1,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
gr.HTML(value="Image Processing")
|
||||
db_resolution = gr.Slider(
|
||||
|
|
@ -544,12 +609,15 @@ def on_ui_tabs():
|
|||
)
|
||||
db_mixed_precision = gr.Dropdown(
|
||||
label="Mixed Precision",
|
||||
value="no",
|
||||
value=select_precision(),
|
||||
choices=list_precisions(),
|
||||
)
|
||||
db_full_mixed_precision = gr.Checkbox(
|
||||
label="Full Mixed Precision", value=True
|
||||
)
|
||||
db_attention = gr.Dropdown(
|
||||
label="Memory Attention",
|
||||
value="default",
|
||||
value=select_attention(),
|
||||
choices=list_attention(),
|
||||
)
|
||||
db_cache_latents = gr.Checkbox(
|
||||
|
|
@ -563,7 +631,7 @@ def on_ui_tabs():
|
|||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.05,
|
||||
value=0,
|
||||
value=1.0,
|
||||
visible=True,
|
||||
)
|
||||
db_offset_noise = gr.Slider(
|
||||
|
|
@ -580,7 +648,7 @@ def on_ui_tabs():
|
|||
)
|
||||
db_clip_skip = gr.Slider(
|
||||
label="Clip Skip",
|
||||
value=1,
|
||||
value=2,
|
||||
minimum=1,
|
||||
maximum=12,
|
||||
step=1,
|
||||
|
|
@ -843,35 +911,6 @@ def on_ui_tabs():
|
|||
label="Generate a .ckpt file when training is canceled."
|
||||
)
|
||||
with gr.Column(visible=False) as lora_save_col:
|
||||
gr.HTML("Lora")
|
||||
db_lora_unet_rank = gr.Slider(
|
||||
label="Lora UNET Rank",
|
||||
value=4,
|
||||
minimum=2,
|
||||
maximum=128,
|
||||
step=2,
|
||||
)
|
||||
db_lora_txt_rank = gr.Slider(
|
||||
label="Lora Text Encoder Rank",
|
||||
value=4,
|
||||
minimum=2,
|
||||
maximum=768,
|
||||
step=2,
|
||||
)
|
||||
db_lora_weight = gr.Slider(
|
||||
label="Lora Weight",
|
||||
value=1,
|
||||
minimum=0.1,
|
||||
maximum=1,
|
||||
step=0.1,
|
||||
)
|
||||
db_lora_txt_weight = gr.Slider(
|
||||
label="Lora Text Weight",
|
||||
value=1,
|
||||
minimum=0.1,
|
||||
maximum=1,
|
||||
step=0.1,
|
||||
)
|
||||
db_save_lora_during = gr.Checkbox(
|
||||
label="Generate lora weights when saving during training."
|
||||
)
|
||||
|
|
@ -1000,9 +1039,6 @@ def on_ui_tabs():
|
|||
)
|
||||
with gr.Tab("Testing", elem_id="TabDebug"):
|
||||
gr.HTML(value="Experimental Settings")
|
||||
db_full_mixed_precision = gr.Checkbox(
|
||||
label="Full Mixed Precision", value=False
|
||||
)
|
||||
db_tomesd = gr.Slider(
|
||||
value=0,
|
||||
label="Token Merging (ToMe)",
|
||||
|
|
@ -1059,8 +1095,8 @@ def on_ui_tabs():
|
|||
db_status = gr.HTML(elem_id="db_status", value="")
|
||||
db_progressbar = gr.HTML(elem_id="db_progressbar")
|
||||
db_gallery = gr.Gallery(
|
||||
label="Output", show_label=False, elem_id="db_gallery"
|
||||
).style(grid=4)
|
||||
label="Output", show_label=False, elem_id="db_gallery", columns=4
|
||||
)
|
||||
db_preview = gr.Image(elem_id="db_preview", visible=False)
|
||||
db_prompt_list = gr.HTML(
|
||||
elem_id="db_prompt_list", value="", visible=False
|
||||
|
|
@ -1084,9 +1120,13 @@ def on_ui_tabs():
|
|||
show_ema,
|
||||
use_lora_extended,
|
||||
lora_save,
|
||||
lora_rank,
|
||||
lora_lr,
|
||||
standard_lr,
|
||||
lora_model,
|
||||
_,
|
||||
_,
|
||||
_
|
||||
) = disable_lora(use_lora)
|
||||
(
|
||||
lr_power,
|
||||
|
|
@ -1103,6 +1143,7 @@ def on_ui_tabs():
|
|||
show_ema,
|
||||
use_lora_extended,
|
||||
lora_save,
|
||||
lora_rank,
|
||||
lora_lr,
|
||||
lora_model,
|
||||
scheduler,
|
||||
|
|
@ -1144,6 +1185,7 @@ def on_ui_tabs():
|
|||
db_use_ema,
|
||||
db_use_lora_extended,
|
||||
lora_save_col,
|
||||
lora_rank_col,
|
||||
lora_lr_row,
|
||||
lora_model_row,
|
||||
db_scheduler,
|
||||
|
|
@ -1184,6 +1226,30 @@ def on_ui_tabs():
|
|||
outputs=[db_stop_text_encoder],
|
||||
)
|
||||
|
||||
def toggle_full_mixed_precision(full_mixed_precision):
|
||||
if full_mixed_precision != "fp16":
|
||||
return gr.update(visible=False)
|
||||
else:
|
||||
return gr.update(visible=True)
|
||||
|
||||
db_mixed_precision.change(
|
||||
fn=toggle_full_mixed_precision,
|
||||
inputs=[db_mixed_precision],
|
||||
outputs=[db_full_mixed_precision],
|
||||
)
|
||||
|
||||
def update_model_options(model_type):
|
||||
if model_type == "SDXL":
|
||||
return gr.update(value=1024)
|
||||
else:
|
||||
return gr.update(value=512)
|
||||
|
||||
db_model_type_select.change(
|
||||
fn=update_model_options,
|
||||
inputs=[db_model_type_select],
|
||||
outputs=[db_resolution]
|
||||
)
|
||||
|
||||
db_clear_secret.click(fn=clear_secret, inputs=[], outputs=[db_secret])
|
||||
|
||||
# Elements to update when progress changes
|
||||
|
|
@ -1279,7 +1345,6 @@ def on_ui_tabs():
|
|||
db_lora_model_name,
|
||||
db_lora_txt_learning_rate,
|
||||
db_lora_txt_rank,
|
||||
db_lora_txt_weight,
|
||||
db_lora_unet_rank,
|
||||
db_lora_use_buggy_requires_grad,
|
||||
db_lora_weight,
|
||||
|
|
@ -1484,18 +1549,31 @@ def on_ui_tabs():
|
|||
|
||||
def disable_lora(x):
|
||||
use_ema = gr.update(interactive=not x)
|
||||
use_lora_extended = gr.update(visible=x)
|
||||
use_lora_extended = gr.update(visible=False)
|
||||
lora_save = gr.update(visible=x)
|
||||
lora_rank = gr.update(visible=x)
|
||||
lora_lr = gr.update(visible=x)
|
||||
standard_lr = gr.update(visible=not x)
|
||||
lora_model = gr.update(visible=x)
|
||||
if x:
|
||||
save_during =gr.update(label="Save LORA during training")
|
||||
save_after = gr.update(label="Save LORA after training")
|
||||
save_cancel = gr.update(label="Save LORA on cancel")
|
||||
else:
|
||||
save_during = gr.update(label="Save .safetensors during training")
|
||||
save_after = gr.update(label="Save .safetensors after training")
|
||||
save_cancel = gr.update(label="Save .safetensors on cancel")
|
||||
return (
|
||||
use_ema,
|
||||
use_lora_extended,
|
||||
lora_save,
|
||||
lora_rank,
|
||||
lora_lr,
|
||||
standard_lr,
|
||||
lora_model,
|
||||
save_during,
|
||||
save_after,
|
||||
save_cancel
|
||||
)
|
||||
|
||||
def lr_scheduler_changed(sched):
|
||||
|
|
@ -1545,9 +1623,13 @@ def on_ui_tabs():
|
|||
db_use_ema,
|
||||
db_use_lora_extended,
|
||||
lora_save_col,
|
||||
lora_rank_col,
|
||||
lora_lr_row,
|
||||
standard_lr_row,
|
||||
lora_model_row,
|
||||
db_save_ckpt_during,
|
||||
db_save_ckpt_after,
|
||||
db_save_ckpt_cancel
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -1743,6 +1825,7 @@ def on_ui_tabs():
|
|||
outputs=[db_gallery, db_status],
|
||||
)
|
||||
|
||||
|
||||
db_cancel.click(
|
||||
fn=lambda: status.interrupt(),
|
||||
inputs=[],
|
||||
|
|
|
|||
Loading…
Reference in New Issue