408 lines
14 KiB
Python
408 lines
14 KiB
Python
import json
|
|
import os
|
|
import traceback
|
|
from typing import List, Dict
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from dreambooth import shared # noqa
|
|
from dreambooth.dataclasses.db_concept import Concept # noqa
|
|
from dreambooth.utils.image_utils import get_scheduler_names # noqa
|
|
from dreambooth.utils.utils import list_attention
|
|
|
|
# Ordered parameter names used by save_config(): the positional arguments it
# receives are zipped against this list to build the settings dict.
save_keys: List[str] = list()

# Keys to return to the UI when "Load Settings" is clicked.
ui_keys: List[str] = list()
|
|
|
|
|
|
def sanitize_name(name):
    """Strip *name* down to alphanumerics plus '.', '_', '-' and space.

    Every other character is dropped; the order of kept characters is
    preserved.
    """
    allowed_punctuation = "._- "
    kept_chars = [ch for ch in name if ch.isalnum() or ch in allowed_punctuation]
    return "".join(kept_chars)
|
|
|
|
|
|
class DreamboothConfig(BaseModel):
    """Persisted Dreambooth training settings for a single model.

    Fields are declared with their defaults below. ``save()`` writes them to
    ``<model_dir>/db_config.json`` and ``load_params()`` applies a dict of
    (optionally ``db_``-prefixed) values onto an instance.
    """

    # These properties MUST be sorted alphabetically
    # NOTE(review): tenc_grad_clip_norm is declared after tenc_weight_decay,
    # which breaks the rule above. Left in place because callers may depend on
    # the declared field order -- confirm before reordering.
    adamw_weight_decay: float = 0.01
    attention: str = "xformers"
    cache_latents: bool = True
    clip_skip: int = 1
    concepts_list: List[Dict] = []
    concepts_path: str = ""
    custom_model_name: str = ""
    deterministic: bool = False
    disable_logging: bool = False
    ema_predict: bool = False
    enable_tomesd: bool = False
    epoch: int = 0
    epoch_pause_frequency: int = 0
    epoch_pause_time: int = 0
    freeze_clip_normalization: bool = False
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    gradient_set_to_none: bool = True
    graph_smoothing: int = 50
    half_model: bool = False
    has_ema: bool = False
    hflip: bool = False
    infer_ema: bool = False
    initial_revision: int = 0
    learning_rate: float = 5e-6
    learning_rate_min: float = 1e-6
    lifetime_revision: int = 0
    lora_learning_rate: float = 1e-4
    lora_model_name: str = ""
    lora_txt_learning_rate: float = 5e-5
    lora_txt_rank: int = 4
    lora_txt_weight: float = 1.0
    lora_unet_rank: int = 4
    lora_weight: float = 1.0
    lr_cycles: int = 1
    lr_factor: float = 0.5
    lr_power: float = 1.0
    lr_scale_pos: float = 0.5
    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 0
    max_token_length: int = 75
    mixed_precision: str = "fp16"
    model_dir: str = ""
    model_name: str = ""
    model_path: str = ""
    noise_scheduler: str = "DDPM"
    num_train_epochs: int = 100
    offset_noise: float = 0
    optimizer: str = "8bit AdamW"
    pad_tokens: bool = True
    pretrained_model_name_or_path: str = ""
    pretrained_vae_name_or_path: str = ""
    prior_loss_scale: bool = False
    prior_loss_target: int = 100
    prior_loss_weight: float = 0.75
    prior_loss_weight_min: float = 0.1
    resolution: int = 512
    revision: int = 0
    sample_batch_size: int = 1
    sanity_prompt: str = ""
    sanity_seed: int = 420420
    save_ckpt_after: bool = True
    save_ckpt_cancel: bool = False
    save_ckpt_during: bool = True
    save_ema: bool = True
    save_embedding_every: int = 25
    save_lora_after: bool = True
    save_lora_cancel: bool = False
    save_lora_during: bool = True
    save_lora_for_extra_net: bool = True
    save_preview_every: int = 5
    save_safetensors: bool = True
    save_state_after: bool = False
    save_state_cancel: bool = False
    save_state_during: bool = False
    scheduler: str = "ddim"
    shared_diffusers_path: str = ""
    shuffle_tags: bool = True
    snapshot: str = ""
    split_loss: bool = True
    src: str = ""
    stop_text_encoder: float = 1.0
    strict_tokens: bool = False
    tenc_weight_decay: float = 0.01
    tenc_grad_clip_norm: float = 0.00
    tf32_enable: bool = False
    train_batch_size: int = 1
    train_imagic: bool = False
    train_unet: bool = True
    train_unfrozen: bool = True
    use_concepts: bool = False
    use_ema: bool = True
    use_lora: bool = False
    use_lora_extended: bool = False
    # No trailing comma: `False,` would make the default the truthy tuple (False,).
    use_shared_src: bool = False
    use_subdir: bool = False
    v2: bool = False

    def __init__(
            self,
            model_name: str = "",
            model_dir: str = "",
            v2: bool = False,
            src: str = "",
            resolution: int = 512,
            **kwargs
    ):
        """Create a config for *model_name* and ensure its working directory exists.

        Args:
            model_name: Model name; passed through ``sanitize_name`` before use.
            model_dir: Accepted for compatibility but ignored -- the directory
                is always derived from the models root and the sanitized name.
            v2: Stored on ``self.v2``.
            src: Stored on ``self.src``.
            resolution: Stored on ``self.resolution``.
            **kwargs: Forwarded to ``BaseModel.__init__`` to set other fields.

        Side effects:
            Creates ``<models_root>/<model_name>/working`` if it is missing.
        """
        super().__init__(**kwargs)
        model_name = sanitize_name(model_name)
        models_path = shared.dreambooth_models_path
        if models_path == "" or models_path is None:
            models_path = os.path.join(shared.models_path, "dreambooth")

        # If we're using the new UI, this should be populated, so load models from here.
        if len(shared.paths):
            models_path = os.path.join(shared.paths["models"], "dreambooth")

        if not self.use_lora:
            self.lora_model_name = ""

        model_dir = os.path.join(models_path, model_name)
        working_dir = os.path.join(model_dir, "working")
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(working_dir, exist_ok=True)

        self.model_name = model_name
        self.model_dir = model_dir
        self.pretrained_model_name_or_path = working_dir
        self.resolution = resolution
        self.src = src
        # Always reset to "ddim", overriding any scheduler passed via kwargs.
        self.scheduler = "ddim"
        self.v2 = v2

    def save(self, backup=False):
        """Write this config to disk as JSON.

        Args:
            backup: When True, write to
                ``<model_dir>/backups/db_config_<revision>.json`` (creating the
                ``backups`` directory if needed) instead of
                ``<model_dir>/db_config.json``.
        """
        models_path = self.model_dir
        config_file = os.path.join(models_path, "db_config.json")
        if backup:
            backup_dir = os.path.join(models_path, "backups")
            os.makedirs(backup_dir, exist_ok=True)
            config_file = os.path.join(backup_dir, f"db_config_{self.revision}.json")
        with open(config_file, "w") as outfile:
            json.dump(self.__dict__, outfile, indent=4)

    def load_params(self, params_dict):
        """Apply values from *params_dict* onto this config.

        - Every occurrence of ``db_`` is removed from a key that contains it.
        - A legacy ``"flash_attention"`` value for ``attention`` is replaced by
          the last entry of ``list_attention()``.
        - An unknown ``scheduler`` name is replaced by the first known
          scheduler whose name contains it (case-insensitive). If any such swap
          was attempted, the config is saved afterwards.
        - Keys that are not attributes of this config are ignored. The rest
          pass through ``validate_param`` before being set.
        """
        sched_swap = False
        for key, value in params_dict.items():
            if "db_" in key:
                key = key.replace("db_", "")
            if key == "attention" and value == "flash_attention":
                value = list_attention()[-1]
                print(f"Replacing flash attention in config to {value}")

            if key == "scheduler":
                schedulers = get_scheduler_names()
                if value not in schedulers:
                    sched_swap = True
                    for scheduler in schedulers:
                        if value.lower() in scheduler.lower():
                            print(f"Updating scheduler name to: {scheduler}")
                            value = scheduler
                            break

            if hasattr(self, key):
                key, value = self.validate_param(key, value)
                setattr(self, key, value)
        if sched_swap:
            self.save()

    @staticmethod
    def validate_param(key, value):
        """Migrate a legacy ``(key, value)`` pair to its current form.

        ``replaced_params`` maps an old key to an optional ``new_key`` and an
        optional list of ``values`` rewrites. Any value found in ``old``
        becomes ``new``.

        Returns:
            tuple: the (possibly renamed) key and (possibly replaced) value.
        """
        replaced_params = {
            # "old_key" : {
            #   "new_key": "...",
            #   "values": [{
            #       "old": ["...", "..."]
            #       "new": "..."
            #   }]
            # }
            "deis_train_scheduler": {
                "new_key": "noise_scheduler",
                "values": [{
                    "old": [True],
                    "new": "DDPM"
                }],
            },
            "optimizer": {
                "values": [{
                    "old": ["8Bit Adam"],
                    "new": "8bit AdamW"
                }],
            },
            "save_safetensors": {
                "values": [{
                    "old": [False],
                    "new": True
                }],
            }
        }

        if key in replaced_params:
            replacement = replaced_params[key]
            # `replacement` is a plain dict, so test key membership
            # (hasattr() would look for attributes and never match).
            if "new_key" in replacement:
                key = replacement["new_key"]
            for _value in replacement.get("values", []):
                if value in _value["old"]:
                    value = _value["new"]
        return key, value

    def concepts(self, required: int = -1):
        """Build the list of ``Concept`` objects for this model.

        Args:
            required: Minimum number of concepts to return. When it exceeds the
                number of valid concepts, the list is padded with distinct
                ``Concept(None)`` entries up to that count. The default ``-1``
                means "no padding". It also allows loading from
                ``concepts_path`` when ``use_concepts`` is enabled.

        Returns:
            list: valid ``Concept`` objects (a missing ``class_data_dir`` is set
            to ``<model_dir>/classifiers_<n>``), followed by any padding.
        """
        concepts = []
        c_idx = 0
        # If using a file for concepts and not requesting from UI, load from file
        if self.use_concepts and self.concepts_path and required == -1:
            concepts_list = concepts_from_file(self.concepts_path)
        # Otherwise, use 'stored' list
        else:
            concepts_list = self.concepts_list

        for concept_dict in concepts_list:
            concept = Concept(input_dict=concept_dict)
            if concept.is_valid:
                if concept.class_data_dir == "" or concept.class_data_dir is None:
                    concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
                concepts.append(concept)
                c_idx += 1

        # Pad up to `required` entries; never positive when required == -1.
        missing = required - len(concepts)
        if missing > 0:
            # Separate instances so padding entries never alias each other.
            concepts.extend(Concept(None) for _ in range(missing))
        return concepts

    def check_defaults(self):
        """Fill missing revision/epoch values and re-derive the model paths.

        Only acts when ``model_name`` is set. It sanitizes the model name and
        resolves the models root, preferring ``ModelHandler().models_path`` when
        it can be imported and constructed. It then ensures
        ``<model_dir>/working`` exists and points
        ``pretrained_model_name_or_path`` at it.
        """
        if self.model_name:
            if self.revision == "" or self.revision is None:
                self.revision = 0
            if self.epoch == "" or self.epoch is None:
                self.epoch = 0
            self.model_name = sanitize_name(self.model_name)
            models_path = shared.dreambooth_models_path
            try:
                from core.handlers.models import ModelHandler
                mh = ModelHandler()
                models_path = mh.models_path
            except Exception:
                # Best-effort: ModelHandler is optional (presumably only present
                # with the new UI -- confirm); fall back to the configured path.
                pass
            if models_path == "" or models_path is None:
                models_path = os.path.join(shared.models_path, "dreambooth")
            model_dir = os.path.join(models_path, self.model_name)
            working_dir = os.path.join(model_dir, "working")
            os.makedirs(working_dir, exist_ok=True)
            self.model_dir = model_dir
            self.pretrained_model_name_or_path = working_dir

    def refresh(self):
        """
        Reload this config's values from
        ``<models_path>/<model_name>/db_config.json``.

        On success, also sets ``shared.db_model_config`` to ``self``. On any
        failure, prints the error and traceback. Returns None in all cases.
        """
        models_path = shared.dreambooth_models_path
        if models_path == "" or models_path is None:
            models_path = os.path.join(shared.models_path, "dreambooth")
        config_file = os.path.join(models_path, self.model_name, "db_config.json")
        try:
            with open(config_file, 'r') as openfile:
                config_dict = json.load(openfile)

            self.load_params(config_dict)
            shared.db_model_config = self
        except Exception as e:
            print(f"Exception loading config: {e}")
            traceback.print_exc()
            return None

    def get_pretrained_model_name_or_path(self):
        """Return the model source path to load from.

        Returns ``shared_diffusers_path`` when it is set, otherwise
        ``pretrained_model_name_or_path``.

        Raises:
            Exception: if ``shared_diffusers_path`` is set while ``use_lora``
                is False.
        """
        if self.shared_diffusers_path != "" and not self.use_lora:
            raise Exception(f"shared_diffusers_path is \"{self.shared_diffusers_path}\" but use_lora is false")
        return self.shared_diffusers_path if self.shared_diffusers_path != "" else self.pretrained_model_name_or_path
|
|
|
|
|
|
def concepts_from_file(concepts_path: str):
    """Parse a concepts definition into a list of concept dicts.

    Args:
        concepts_path: Either a path to an existing JSON file, or (when no such
            file exists) the JSON text itself.

    Returns:
        list of dict: ``Concept.__dict__`` for every valid concept. If the file
        cannot be read or the JSON cannot be parsed, the error is printed and
        an empty list is returned.
    """
    concepts = []
    if os.path.isfile(concepts_path):
        try:
            # JSON is UTF-8 by specification; don't depend on the platform locale.
            with open(concepts_path, "r", encoding="utf-8") as concepts_file:
                concepts_str = concepts_file.read()
        except Exception as e:
            print(f"Exception opening concepts file: {e}")
            # Nothing was read, so there is nothing to parse.
            return concepts
    else:
        concepts_str = concepts_path

    try:
        concepts_data = json.loads(concepts_str)
        for concept_data in concepts_data:
            concept = Concept(input_dict=concept_data)
            if concept.is_valid:
                concepts.append(concept.__dict__)
    except Exception as e:
        print(f"Exception parsing concepts: {e}")
    return concepts
|
|
|
|
|
|
def save_config(*args):
    """Persist a full set of positional settings for one model.

    The positional ``args`` (presumably the UI's settings inputs -- confirm)
    are zipped against ``save_keys``. The model name is read from position 38.
    Unless a concepts file/path is in use, up to four inline concepts are
    collected from keys containing ``c1_`` .. ``c4_``. The settings are then
    applied onto the existing config (or a new one), stored in
    ``shared.db_model_config``, and saved.
    """
    params = list(args)
    model_name = params[38]
    if model_name is None or model_name == "":
        print("Invalid model name.")
        return

    params_dict = dict(zip(save_keys, params))

    if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
        # Concepts come from the file/string, so store no inline list.
        params_dict["concepts_list"] = []
    else:
        gathered = []
        for prefix in ("c1_", "c2_", "c3_", "c4_"):
            fields = {
                name.replace(prefix, ""): value
                for name, value in params_dict.items()
                if prefix in name and value is not None
            }
            candidate = Concept(fields)
            if candidate.is_valid:
                gathered.append(candidate.__dict__)
        # Only use the inline concepts when no concepts_list was supplied.
        if gathered and not len(params_dict.get("concepts_list", [])):
            params_dict["concepts_list"] = gathered

    config = from_file(model_name)
    if config is None:
        config = DreamboothConfig(model_name)
    config.load_params(params_dict)
    shared.db_model_config = config
    config.save()
|
|
|
|
|
|
def from_file(model_name):
    """
    Load a model's saved config from ``<models_path>/<model_name>/db_config.json``.

    Args:
        model_name: Name of the model whose config to load. A list is also
            accepted: its first element is used, and an empty list is treated
            as "no name".

    Returns: DreamboothConfig | None -- the loaded config (also stored in
        ``shared.db_model_config``), or None when no name was given or loading
        failed (the error and traceback are printed).
    """
    if isinstance(model_name, list):
        # An empty list previously fell through and crashed in os.path.join.
        model_name = model_name[0] if model_name else None

    if model_name == "" or model_name is None:
        return None

    models_path = shared.dreambooth_models_path
    if models_path == "" or models_path is None:
        models_path = os.path.join(shared.models_path, "dreambooth")
    config_file = os.path.join(models_path, model_name, "db_config.json")
    try:
        with open(config_file, 'r') as openfile:
            config_dict = json.load(openfile)

        config = DreamboothConfig(model_name)
        config.load_params(config_dict)
        shared.db_model_config = config
        return config
    except Exception as e:
        print(f"Exception loading config: {e}")
        traceback.print_exc()
        return None
|