Work sync

pull/1241/head
d8ahazard 2023-05-23 08:27:11 -05:00
parent d45317e23a
commit 7dfd69a5f9
7 changed files with 151 additions and 48 deletions

View File

@ -291,30 +291,6 @@ class DreamboothConfig(BaseModel):
concepts.extend([Concept(None)] * missing)
return concepts
# Set default values
def check_defaults(self):
    """Normalize model fields and ensure the model's working directory exists.

    - Coerces empty/None ``revision`` and ``epoch`` to 0.
    - Strips ``model_name`` down to filesystem-safe characters.
    - Resolves the models path (preferring the new-UI ModelHandler when
      importable), creates ``<models_path>/<model_name>/working``, and records
      the results on ``self.model_dir`` / ``self.pretrained_model_name_or_path``.
    """
    if not self.model_name:
        return
    # "" and None both mean "unset": counters start at 0.
    if self.revision == "" or self.revision is None:
        self.revision = 0
    if self.epoch == "" or self.epoch is None:
        self.epoch = 0
    # Keep only characters that are safe in a directory name.
    self.model_name = "".join(
        c for c in self.model_name if c.isalnum() or c in "._- "
    )
    models_path = shared.dreambooth_models_path
    try:
        # Optional new-UI integration: it exposes its own models path.
        from core.handlers.models import ModelHandler
        models_path = ModelHandler().models_path
    except Exception:
        # Best-effort: fall back to the legacy shared path below.
        pass
    if models_path == "" or models_path is None:
        models_path = os.path.join(shared.models_path, "dreambooth")
    model_dir = os.path.join(models_path, self.model_name)
    working_dir = os.path.join(model_dir, "working")
    # exist_ok avoids a race between an existence check and creation.
    os.makedirs(working_dir, exist_ok=True)
    self.model_dir = model_dir
    self.pretrained_model_name_or_path = working_dir
def refresh(self):
"""
Reload self from file

View File

@ -5,7 +5,14 @@ from dreambooth import shared
# Resolve where the dreambooth secret file lives. Default: legacy models path.
db_path = os.path.join(shared.models_path, "dreambooth")
secret_file = os.path.join(db_path, "secret.txt")
try:
    # When running under the new UI, prefer its protected directory instead.
    from core.handlers.config import DirectoryHandler
    dh = DirectoryHandler()
    protected_path = dh.protected_path
    db_path = os.path.join(protected_path, "dreambooth")
    secret_file = os.path.join(db_path, "secret.txt")
except Exception:
    # Best-effort: the new UI isn't present, keep the legacy paths above.
    pass
# exist_ok avoids a race between an existence check and creation.
os.makedirs(db_path, exist_ok=True)

View File

@ -90,18 +90,6 @@ except:
# Defaults used when the new-UI config/model handlers are unavailable.
export_diffusers = False
user_model_dir = ""
try:
    # Optional new-UI integration: read the export flag from the shared
    # config and derive the diffusers output dir from the model handler.
    from core.handlers.config import ConfigHandler
    from core.handlers.models import ModelHandler
    ch = ConfigHandler()
    mh = ModelHandler()
    export_diffusers = ch.get_item("export_diffusers", "dreambooth", True)
    user_model_dir = os.path.join(mh.models_path[0], "diffusers")
    logger.debug(f"Export diffusers: {export_diffusers}, diffusers dir: {user_model_dir}")
except Exception:
    # Best-effort: outside the new UI the defaults above remain in effect.
    pass
def dadapt(optimizer):
@ -156,9 +144,17 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
args = shared.db_model_config
status_handler = None
logging_dir = Path(args.model_dir, "logging")
global export_diffusers, user_model_dir
try:
from core.handlers.status import StatusHandler
from core.handlers.config import ConfigHandler
from core.handlers.models import ModelHandler
mh = ModelHandler(user_name=user)
status_handler = StatusHandler(user_name=user, target="dreamProgress")
export_diffusers = True
user_model_dir = mh.user_path
logger.debug(f"Export diffusers: {export_diffusers}, diffusers dir: {user_model_dir}")
shared.status_handler = status_handler
logger.debug(f"Loaded config: {args.__dict__}")
except:
@ -687,7 +683,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}"
)
if os.path.exists(new_hotness):
accelerator.logger.debug(f"Resuming from checkpoint {new_hotness}")
logger.debug(f"Resuming from checkpoint {new_hotness}")
try:
import modules.shared
@ -864,7 +860,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
requires_safety_checker=None,
)
with accelerator.autocast():
# Is inference_mode() needed here to prevent issues when saving?
with accelerator.autocast(), torch.inference_mode():
if save_model:
# We are saving weights, we need to ensure revision is saved
args.save()

View File

@ -19,7 +19,14 @@ db_path = os.path.join(shared.models_path, "dreambooth")
url_file = os.path.join(db_path, "webhook.txt")
hook_url = None
if not os.path.exists(db_path):
new_ui = False
try:
from core.dataclasses import status_data
new_ui = True
except:
pass
if not os.path.exists(db_path) and not new_ui:
os.makedirs(db_path)

View File

@ -15,7 +15,7 @@ let modelLoaded = false;
const mdBreakpoint = 990;
// Register the module with the UI. Icon is from boxicons by default.
const dbModule = new Module("Dreambooth", "moduleDreambooth", "moon", false, 2, initDreambooth);
const dbModule = new Module("Dreambooth", "moduleDreambooth", "moon", false, 2, initDreambooth, refreshDreambooth);
function initDreambooth() {
sendMessage("get_db_vars", {}, true).then(function (response) {
@ -103,7 +103,7 @@ function initDreambooth() {
}
});
// utility function to convert a string to Title Case
// utility function to convert a string to Title Case
String.prototype.toTitleCase = function () {
return this.replace(/\w\S*/g, function (txt) {
return txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase();
@ -176,6 +176,19 @@ function initDreambooth() {
}
}
function refreshDreambooth() {
    // Pull the latest module config and re-apply the advanced/basic view state.
    dreamConfig = dbModule.systemConfig;
    showAdvanced = dreamConfig["show_advanced"];
    if (!showAdvanced) {
        // Basic mode: show only the simplified controls.
        $(".db-advanced").hide();
        $(".db-basic").show();
    } else {
        // Advanced mode: full controls; use the local model row, not the hub row.
        $(".db-advanced").show();
        $(".db-basic").hide();
        $("#hub_row").hide();
        $("#local_row").show();
    }
}
function onDbEnd() {
$(".dbTrainBtn").addClass("hide");
$(".dbSettingBtn").removeClass("hide");

View File

@ -12,7 +12,8 @@ import torch
from fastapi import FastAPI
import scripts.api
from core.handlers.models import ModelHandler
from core.handlers.config import ConfigHandler
from core.handlers.models import ModelHandler, ModelManager
from core.handlers.status import StatusHandler
from core.handlers.websocket import SocketHandler
from core.modules.base.module_base import BaseModule
@ -36,6 +37,10 @@ class DreamboothModule(BaseModule):
def initialize(self, app: FastAPI, handler: SocketHandler):
    """Wire up the module's API routes and websocket handlers, then register
    the bundled default dreambooth model settings (if the template exists)."""
    self._initialize_api(app)
    self._initialize_websocket(handler)
    defaults_base_file = os.path.join(os.path.dirname(__file__), "templates", "db_config.json")
    if os.path.exists(defaults_base_file):
        ch = ConfigHandler()
        # Context manager ensures the template file handle is closed promptly
        # (the original `json.load(open(...))` leaked the handle).
        with open(defaults_base_file, "r") as f:
            ch.set_config_protected(json.load(f), "dreambooth_model_defaults")
def _initialize_api(self, app: FastAPI):
    # Register the dreambooth REST endpoints on the shared FastAPI app.
    api_result = scripts.api.dreambooth_api(None, app)
    return api_result
@ -76,8 +81,9 @@ async def _train_dreambooth(request):
user = request["user"] if "user" in request else None
config = await _set_model_config(request, True)
mh = ModelHandler(user_name=user)
mm = ModelManager()
sh = StatusHandler(user_name=user, target="dreamProgress")
mh.to_cpu()
mm.to_cpu()
shared.db_model_config = config
try:
torch.cuda.empty_cache()
@ -92,20 +98,17 @@ async def _train_dreambooth(request):
with ThreadPoolExecutor() as pool:
await loop.run_in_executor(pool, lambda: (
sh.start(0, "Starting Dreambooth Training..."),
main(user=user),
sh.end("Training complete.")
main(user=user)
))
except Exception as e:
logger.error(f"Error in training: {e}")
traceback.print_exc()
result = {"message": f"Error in training: {e}"}
sh.end(f"Error in training: {e}")
try:
gc.collect()
torch.cuda.empty_cache()
except:
pass
mh.to_gpu()
sh.end(result["message"])
return result
@ -161,10 +164,20 @@ async def copy_model(model_name: str, src: str, is_512: bool, mh: ModelHandler,
dest_dir = os.path.join(model_dir, "dreambooth", model_name, "working")
if os.path.exists(dest_dir):
shutil.rmtree(dest_dir, True)
ch = ConfigHandler(user_name=mh.user_name)
base = ch.get_config_protected("dreambooth_model_defaults")
user_base = ch.get_config_user("dreambooth_model_defaults")
logger.debug(f"User base: {user_base}")
if user_base is not None:
base = {**base, **user_base}
else:
logger.debug("Setting user config")
ch.set_config_user(base, "dreambooth_model_defaults")
if not os.path.exists(dest_dir):
logger.debug(f"Copying model from {src} to {dest_dir}")
await copy_directory(src, dest_dir, sh)
cfg = DreamboothConfig(model_name=model_name, src=src, resolution=512 if is_512 else 768, models_path=dreambooth_models_path)
cfg.load_params(base)
cfg.save()
else:

90
templates/db_config.json Normal file
View File

@ -0,0 +1,90 @@
{
"weight_decay": 0.01,
"cache_latents": true,
"clip_skip": 2,
"concepts_list": [],
"concepts_path": "",
"custom_model_name": "",
"deterministic": false,
"disable_class_matching": false,
"disable_logging": false,
"ema_predict": false,
"epoch_pause_frequency": 0,
"epoch_pause_time": 0,
"freeze_clip_normalization": true,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"gradient_set_to_none": true,
"graph_smoothing": 50,
"half_model": false,
"hflip": false,
"infer_ema": false,
"learning_rate": 2e-06,
"learning_rate_min": 1e-06,
"lora_learning_rate": 0.0001,
"lora_txt_learning_rate": 5e-05,
"lora_txt_rank": 4,
"lora_txt_weight": 1.0,
"lora_unet_rank": 4,
"lora_weight": 1.0,
"lora_use_buggy_requires_grad": false,
"lr_cycles": 1,
"lr_factor": 0.5,
"lr_power": 1.0,
"lr_scale_pos": 0.5,
"lr_scheduler": "constant_with_warmup",
"lr_warmup_steps": 500,
"max_token_length": 75,
"mixed_precision": "fp16",
"noise_scheduler": "DDPM",
"num_train_epochs": 200,
"offset_noise": 0,
"optimizer": "8bit AdamW",
"pad_tokens": true,
"pretrained_vae_name_or_path": "",
"prior_loss_scale": false,
"prior_loss_target": 100,
"prior_loss_weight": 0.75,
"prior_loss_weight_min": 0.1,
"resolution": 512,
"revision": 21910,
"sample_batch_size": 1,
"sanity_prompt": "",
"sanity_seed": 420420,
"save_ckpt_after": true,
"save_ckpt_cancel": false,
"save_ckpt_during": true,
"save_ema": true,
"save_embedding_every": 25,
"save_lora_after": true,
"save_lora_cancel": false,
"save_lora_during": true,
"save_lora_for_extra_net": true,
"save_preview_every": 5,
"save_safetensors": true,
"save_state_after": false,
"save_state_cancel": false,
"save_state_during": false,
"scheduler": "UniPCMultistep",
"shuffle_tags": true,
"split_loss": true,
"stop_text_encoder": 0.75,
"strict_tokens": true,
"dynamic_img_norm": false,
"tenc_weight_decay": 0.01,
"tenc_grad_clip_norm": 6,
"tomesd": 0,
"train_batch_size": 1,
"train_imagic": false,
"train_unet": true,
"train_unfrozen": true,
"txt_learning_rate": 1e-06,
"use_concepts": true,
"use_ema": false,
"use_lora": false,
"use_lora_extended": false,
"use_shared_src": [
false
],
"use_subdir": false
}