# sd_dreambooth_extension/dreambooth/dataclasses/db_config.py

import datetime
import json
import logging
import os
import traceback
from pathlib import Path
from typing import List, Dict
from pydantic import BaseModel
from dreambooth import shared # noqa
from dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dataclasses.ss_model_spec import build_metadata
from dreambooth.utils.image_utils import get_scheduler_names # noqa
from dreambooth.utils.utils import list_attention, select_precision, select_attention
# Keys to persist when saving a config from the UI (see save_config);
# populated elsewhere at runtime — it starts empty here.
save_keys = []
# Keys to return to the UI when "Load Settings" is clicked; also populated elsewhere.
ui_keys = []
def sanitize_name(name):
    """Return *name* with every character removed except alphanumerics and ``._- ``.

    Used to make model names safe to use as directory names.
    """
    safe_punctuation = "._- "
    return "".join(ch for ch in name if ch.isalnum() or ch in safe_punctuation)
class DreamboothConfig(BaseModel):
    """Serializable training configuration for a single Dreambooth model.

    Field values are persisted to ``db_config.json`` inside the model
    directory (see :meth:`save`) and reloaded via :meth:`refresh` /
    module-level ``from_file``.
    """
    # These properties MUST be sorted alphabetically.
    # NOTE(review): a few fields below are out of order; they are kept in
    # place to preserve the existing field/serialization order.
    weight_decay: float = 0.01
    attention: str = "xformers"
    cache_latents: bool = True
    clip_skip: int = 1
    concepts_list: List[Dict] = []
    concepts_path: str = ""
    custom_model_name: str = ""
    deterministic: bool = False
    disable_class_matching: bool = False
    disable_logging: bool = False
    ema_predict: bool = False
    epoch: int = 0
    epoch_pause_frequency: int = 0
    epoch_pause_time: int = 0
    freeze_clip_normalization: bool = False
    full_mixed_precision: bool = True
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    gradient_set_to_none: bool = True
    graph_smoothing: int = 50
    half_model: bool = False
    has_ema: bool = False
    hflip: bool = False
    infer_ema: bool = False
    initial_revision: int = 0
    input_pertubation: bool = True
    learning_rate: float = 2e-6
    learning_rate_min: float = 1e-6
    lifetime_revision: int = 0
    lora_learning_rate: float = 1e-4
    lora_model_name: str = ""
    lora_txt_learning_rate: float = 5e-5
    lora_txt_rank: int = 4
    lora_unet_rank: int = 4
    lora_weight: float = 0.8
    lora_use_buggy_requires_grad: bool = False
    lr_cycles: int = 1
    lr_factor: float = 0.5
    lr_power: float = 1.0
    lr_scale_pos: float = 0.5
    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 500
    max_token_length: int = 75
    min_snr_gamma: float = 0.0
    use_dream: bool = False
    dream_detail_preservation: float = 0.5
    freeze_spectral_norm: bool = False
    mixed_precision: str = "fp16"
    model_dir: str = ""
    model_name: str = ""
    model_path: str = ""
    model_type: str = "v1x"
    noise_scheduler: str = "DDPM"
    num_train_epochs: int = 100
    offset_noise: float = 0
    optimizer: str = "8bit AdamW"
    pad_tokens: bool = True
    pretrained_model_name_or_path: str = ""
    pretrained_vae_name_or_path: str = ""
    prior_loss_scale: bool = False
    prior_loss_target: int = 100
    prior_loss_weight: float = 0.75
    prior_loss_weight_min: float = 0.1
    resolution: int = 512
    revision: int = 0
    sample_batch_size: int = 1
    sanity_prompt: str = ""
    sanity_seed: int = 420420
    save_ckpt_after: bool = True
    save_ckpt_cancel: bool = False
    save_ckpt_during: bool = True
    save_ema: bool = True
    save_embedding_every: int = 25
    save_lora_after: bool = True
    save_lora_cancel: bool = False
    save_lora_during: bool = True
    save_lora_for_extra_net: bool = True
    save_preview_every: int = 5
    save_safetensors: bool = True
    save_state_after: bool = False
    save_state_cancel: bool = False
    save_state_during: bool = False
    scheduler: str = "ddim"
    shared_diffusers_path: str = ""
    shuffle_tags: bool = True
    snapshot: str = ""
    split_loss: bool = True
    src: str = ""
    stop_text_encoder: float = 1.0
    strict_tokens: bool = False
    dynamic_img_norm: bool = False
    tenc_weight_decay: float = 0.01
    tenc_grad_clip_norm: float = 0.00
    tomesd: float = 0
    train_batch_size: int = 1
    train_imagic: bool = False
    train_unet: bool = True
    train_unfrozen: bool = True
    txt_learning_rate: float = 5e-6
    use_concepts: bool = False
    use_ema: bool = True
    use_lora: bool = False
    # BUGFIX: removed stray trailing comma, which made the default value the
    # tuple (False,) instead of the boolean False.
    use_shared_src: bool = False
    use_subdir: bool = False
    v2: bool = False

    def __init__(
            self,
            model_name: str = "",
            model_dir: str = "",
            v2: bool = False,
            src: str = "",
            resolution: int = 512,
            **kwargs
    ):
        """Create a config for *model_name* and ensure its working dir exists.

        Args:
            model_name: Display name; sanitized and used as the directory name.
            model_dir: Ignored as input — recomputed from the models path.
            v2: Whether the base model is an SD v2 model.
            src: Source checkpoint the model was extracted from.
            resolution: Training resolution (square).
            **kwargs: Remaining field values, passed through to pydantic.
                A ``models_path`` entry, if present, overrides the models root.
        """
        super().__init__(**kwargs)
        model_name = sanitize_name(model_name)
        # Pick sensible attention/precision defaults only when not supplied.
        if "attention" not in kwargs:
            self.attention = select_attention()
        if "mixed_precision" not in kwargs:
            self.mixed_precision = select_precision()
        if "models_path" in kwargs:
            models_path = kwargs["models_path"]
            print(f"Using models path: {models_path}")
        else:
            models_path = shared.dreambooth_models_path
            if models_path == "" or models_path is None:
                models_path = os.path.join(shared.models_path, "dreambooth")
            # If we're using the new UI, this should be populated, so load models from here.
            if len(shared.paths):
                models_path = os.path.join(shared.paths["models"], "dreambooth")
        if not self.use_lora:
            self.lora_model_name = ""
        model_dir = os.path.join(models_path, model_name)
        working_dir = os.path.join(model_dir, "working")
        # exist_ok avoids a race if the directory appears between check and create.
        os.makedirs(working_dir, exist_ok=True)
        self.model_name = model_name
        self.model_dir = model_dir
        self.pretrained_model_name_or_path = working_dir
        self.resolution = resolution
        self.src = src
        self.scheduler = "ddim"
        self.v2 = v2

    # Actually save as a file
    def save(self, backup=False):
        """Write this config to ``db_config.json`` in the model directory.

        Args:
            backup: When True, write a revision-stamped copy into
                ``backups/`` instead of overwriting the main config file.
        """
        models_path = self.model_dir
        logger = logging.getLogger(__name__)
        logger.debug("Saving to %s", models_path)
        # Normalize path separators to the current OS before persisting.
        if os.name == 'nt' and '/' in models_path:
            # replace linux path separators with windows path separators
            models_path = models_path.replace('/', '\\')
        elif os.name == 'posix' and '\\' in models_path:
            # replace windows path separators with linux path separators
            models_path = models_path.replace('\\', '/')
        self.model_dir = models_path
        config_file = os.path.join(models_path, "db_config.json")
        if backup:
            backup_dir = os.path.join(models_path, "backups")
            os.makedirs(backup_dir, exist_ok=True)
            config_file = os.path.join(backup_dir, f"db_config_{self.revision}.json")
        with open(config_file, "w") as outfile:
            json.dump(self.__dict__, outfile, indent=4)

    def load_params(self, params_dict):
        """Apply a dict of (possibly ``db_``-prefixed) params onto this config.

        Legacy keys/values are migrated via :meth:`validate_param`; if the
        scheduler name had to be remapped the config is saved immediately.
        """
        sched_swap = False
        for key, value in params_dict.items():
            # UI keys are prefixed with "db_"; strip it to match field names.
            if "db_" in key:
                key = key.replace("db_", "")
            if key == "attention" and value == "flash_attention":
                value = list_attention()[-1]
                print(f"Replacing flash attention in config to {value}")
            if key == "scheduler":
                schedulers = get_scheduler_names()
                if value not in schedulers:
                    sched_swap = True
                    for scheduler in schedulers:
                        if value.lower() in scheduler.lower():
                            print(f"Updating scheduler name to: {scheduler}")
                            value = scheduler
                            break
            if hasattr(self, key):
                key, value = self.validate_param(key, value)
                setattr(self, key, value)
        if sched_swap:
            self.save()

    @staticmethod
    def validate_param(key, value):
        """Migrate legacy param names/values to their current equivalents.

        Returns the (possibly renamed) key and (possibly remapped) value.
        """
        replaced_params = {
            # "old_key" : {
            #   "new_key": "...",
            #   "values": [{
            #       "old": ["...", "..."]
            #       "new": "..."
            #   }]
            # }
            "weight_decay": {
                "new_key": "weight_decay",
            },
            "deis_train_scheduler": {
                "new_key": "noise_scheduler",
                "values": [{
                    "old": [True],
                    "new": "DDPM"
                }],
            },
            "optimizer": {
                "values": [{
                    "old": ["8Bit Adam"],
                    "new": "8bit AdamW"
                }],
            },
            "save_safetensors": {
                "values": [{
                    "old": [False],
                    "new": True
                }],
            }
        }
        if key in replaced_params.keys():
            replacement = replaced_params[key]
            if "new_key" in replacement:
                key = replacement["new_key"]
            if "values" in replacement:
                for _value in replacement["values"]:
                    if value in _value["old"]:
                        value = _value["new"]
        return key, value

    # Pass a dict and return a list of Concept objects
    def concepts(self, required: int = -1):
        """Return the configured concepts as Concept objects.

        Args:
            required: Pad the result with empty concepts up to this count;
                -1 means "as many as are stored" (and allows loading from a
                concepts file when one is configured).
        """
        concepts = []
        c_idx = 0
        # If using a file for concepts and not requesting from UI, load from file
        if self.use_concepts and self.concepts_path and required == -1:
            concepts_list = concepts_from_file(self.concepts_path)
        # Otherwise, use 'stored' list
        else:
            concepts_list = self.concepts_list
        if required == -1:
            required = len(concepts_list)
        for concept_dict in concepts_list:
            concept = Concept(input_dict=concept_dict)
            if concept.is_valid:
                if concept.class_data_dir == "" or concept.class_data_dir is None:
                    concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
            concepts.append(concept)
            c_idx += 1
        # BUGFIX: was "len(concepts) - required", which padded only when the
        # list was already longer than required and never when it was short.
        missing = required - len(concepts)
        if missing > 0:
            concepts.extend([Concept(None)] * missing)
        return concepts

    def refresh(self):
        """
        Reload self from file
        """
        models_path = shared.dreambooth_models_path
        if models_path == "" or models_path is None:
            models_path = os.path.join(shared.models_path, "dreambooth")
        config_file = os.path.join(models_path, self.model_name, "db_config.json")
        try:
            with open(config_file, 'r') as openfile:
                config_dict = json.load(openfile)
            self.load_params(config_dict)
            shared.db_model_config = self
        except Exception as e:
            print(f"Exception loading config: {e}")
            traceback.print_exc()
            return None

    def get_pretrained_model_name_or_path(self):
        """Return the diffusers model path to load from.

        Prefers shared_diffusers_path (LoRA only), then the stored path,
        falling back to the model's "working" directory.
        """
        if self.shared_diffusers_path != "" and not self.use_lora:
            raise Exception(f"shared_diffusers_path is \"{self.shared_diffusers_path}\" but use_lora is false")
        if self.shared_diffusers_path != "":
            return self.shared_diffusers_path
        if not self.pretrained_model_name_or_path or self.pretrained_model_name_or_path == "":
            return os.path.join(self.model_dir, "working")
        return self.pretrained_model_name_or_path

    def export_ss_metadata(self, state_dict=None):
        """Build kohya-ss style metadata for this model.

        Merges token/bucket counts from the model dir (when present) with the
        base metadata from build_metadata, then maps selected config fields
        onto their ``ss_*`` keys. All values are stringified.
        """
        params = {}
        token_counts_path = os.path.join(self.model_dir, "token_counts.json")
        bucket_json_file = os.path.join(self.model_dir, "bucket_counts.json")
        bucket_counts = {}
        tags = None
        if os.path.exists(token_counts_path):
            with open(token_counts_path, "r") as f:
                tags = json.load(f)
        if os.path.exists(bucket_json_file):
            with open(bucket_json_file, "r") as f:
                bucket_counts = json.load(f)
        base_meta = build_metadata(
            state_dict=state_dict,
            v2="v2x" in self.model_type,
            v_parameterization=self.model_type == "v2x",
            sdxl=self.model_type == "SDXL",
            lora=self.use_lora,
            textual_inversion=False,
            timestamp=datetime.datetime.now().timestamp(),
            reso=(self.resolution, self.resolution),
            tags=tags,
            buckets=bucket_counts,
            clip_skip=self.clip_skip
        )
        mappings = {
            'cache_latents': 'ss_cache_latents',
            'clip_skip': 'ss_clip_skip',
            'epoch': 'ss_epoch',
            'gradient_accumulation_steps': 'ss_gradient_accumulation_steps',
            'gradient_checkpointing': 'ss_gradient_checkpointing',
            'learning_rate': 'ss_learning_rate',
            'lr_scheduler': 'ss_lr_scheduler',
            'lr_warmup_steps': 'ss_lr_warmup_steps',
            'max_token_length': 'ss_max_token_length',
            'min_snr_gamma': 'ss_min_snr_gamma',
            'mixed_precision': 'ss_mixed_precision',
            'optimizer': 'ss_optimizer',
            'prior_loss_weight': 'ss_prior_loss_weight',
            'resolution': 'ss_resolution',
            'src': 'ss_sd_model_name',
            'shuffle_tags': 'ss_shuffle_captions'
        }
        for key, value in mappings.items():
            if hasattr(self, key):
                if value == "ss_resolution":
                    res = getattr(self, key)
                    params[value] = f"({res}, {res})"
                    params["modelspec.resolution"] = f"{res}x{res}"
                elif value == "ss_sd_model_name":
                    model_name = getattr(self, key)
                    # Ensure model_name is only the model name, not the full path
                    model_name = os.path.basename(model_name)
                    params[value] = model_name
                else:
                    # BUGFIX: this assignment previously ran unconditionally,
                    # clobbering the formatted resolution and the
                    # basename-sanitized model name set above.
                    params[value] = getattr(self, key)
        # Enumerate all params convert each one to a string
        for key, value in params.items():
            # If the value is not a string, convert it to one
            if not isinstance(value, str):
                value = str(value)
            if key not in base_meta:
                base_meta[key] = value
        for key, value in base_meta.items():
            if isinstance(value, Dict):
                base_meta[key] = json.dumps(value)
            elif not isinstance(value, str):
                base_meta[key] = str(value)
        ss_sd_model_name = base_meta.get("ss_sd_model_name", "")
        if ss_sd_model_name != "":
            # Make SURE we don't have a full path here
            ss_sd_model_name = os.path.basename(ss_sd_model_name)
            if "/" in ss_sd_model_name:
                ss_sd_model_name = ss_sd_model_name.split("/")[-1]
            if "\\" in ss_sd_model_name:
                ss_sd_model_name = ss_sd_model_name.split("\\")[-1]
            base_meta["ss_sd_model_name"] = ss_sd_model_name
        if self.model_type == "SDXL":
            base_meta["sd_version"] = "SDXL"
        if "v2x" in self.model_type:
            base_meta["sd_version"] = "V2"
        if "v1x" in self.model_type:
            base_meta["sd_version"] = "V1"
        return base_meta
def concepts_from_file(concepts_path: str):
    """Parse concepts from a JSON file path or a raw JSON string.

    Args:
        concepts_path: Path to a JSON file, or the JSON text itself.

    Returns: list of valid concept dicts; [] on any read/parse failure.
    """
    concepts = []
    # BUGFIX: concepts_str was previously left unbound when the file open
    # failed, producing a misleading NameError swallowed by the parse except.
    concepts_str = None
    if os.path.exists(concepts_path) and os.path.isfile(concepts_path):
        try:
            with open(concepts_path, "r") as concepts_file:
                concepts_str = concepts_file.read()
        except Exception as e:
            print(f"Exception opening concepts file: {e}")
    else:
        # Not a file on disk — treat the argument as the JSON string itself.
        concepts_str = concepts_path
    if concepts_str is not None:
        try:
            concepts_data = json.loads(concepts_str)
            # Folder the JSON file resides in, used to resolve relative paths
            # (loop-invariant, so computed once).
            concepts_path_dir = Path(concepts_path).parent
            for concept_data in concepts_data:
                instance_data_dir = concept_data.get("instance_data_dir")
                # BUGFIX: guard against a missing instance_data_dir — isabs(None)
                # raised a TypeError that aborted loading of ALL concepts.
                if instance_data_dir and not os.path.isabs(instance_data_dir):
                    print(f"Rebuilding portable concepts path: {concepts_path_dir} + {instance_data_dir}")
                    concept_data["instance_data_dir"] = os.path.join(concepts_path_dir, instance_data_dir)
                concept = Concept(input_dict=concept_data)
                if concept.is_valid:
                    concepts.append(concept.__dict__)
        except Exception as e:
            print(f"Exception parsing concepts: {e}")
    print(f"Loaded concepts: {concepts}")
    return concepts
def save_config(*args):
    """Persist settings coming from the UI for the currently selected model.

    Positional *args are zipped against the module-level ``save_keys`` list,
    concept fields (c1_..c4_ prefixes) are gathered into a concepts list, and
    the resulting params are applied to (and saved on) the model's config.
    """
    params_dict = dict(zip(save_keys, list(args)))
    # If using a concepts file/string, keep concepts_list empty.
    if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
        params_dict["concepts_list"] = []
    else:
        gathered = []
        for prefix in ("c1_", "c2_", "c3_", "c4_"):
            # Collect this concept's fields, stripping the slot prefix.
            fields = {
                key.replace(prefix, ""): param
                for key, param in params_dict.items()
                if prefix in key and param is not None
            }
            candidate = Concept(fields)
            if candidate.is_valid:
                gathered.append(candidate.__dict__)
        existing = params_dict.get("concepts_list", [])
        # Only adopt the gathered concepts when none were already supplied.
        if len(gathered) and not len(existing):
            params_dict["concepts_list"] = gathered
    model_name = params_dict["db_model_name"]
    if model_name is None or model_name == "":
        print("Invalid model name.")
        return
    config = from_file(model_name)
    if config is None:
        config = DreamboothConfig(model_name)
    config.load_params(params_dict)
    shared.db_model_config = config
    config.save()
def from_file(model_name, model_dir=None):
    """Load a stored config from disk by model name.

    Args:
        model_name: Name of the model to load. A non-empty list is accepted,
            in which case its first entry is used.
        model_dir: Optional override for the models directory; when given it
            is also stored as the new shared dreambooth models path.

    Returns: DreamboothConfig | None
    """
    if isinstance(model_name, list) and len(model_name) > 0:
        model_name = model_name[0]
    if model_name == "" or model_name is None:
        return None
    if model_dir:
        models_path = model_dir
        shared.dreambooth_models_path = models_path
    else:
        models_path = shared.dreambooth_models_path
        if models_path == "" or models_path is None:
            models_path = os.path.join(shared.models_path, "dreambooth")
    config_file = os.path.join(models_path, model_name, "db_config.json")
    try:
        with open(config_file, 'r') as fh:
            stored_params = json.load(fh)
        config = DreamboothConfig(model_name)
        config.load_params(stored_params)
        shared.db_model_config = config
        return config
    except Exception as e:
        print(f"Exception loading config: {e}")
        traceback.print_exc()
        return None