Work sync
parent
d45317e23a
commit
7dfd69a5f9
|
|
@ -291,30 +291,6 @@ class DreamboothConfig(BaseModel):
|
|||
concepts.extend([Concept(None)] * missing)
|
||||
return concepts
|
||||
|
||||
# Set default values
|
||||
def check_defaults(self):
|
||||
if self.model_name:
|
||||
if self.revision == "" or self.revision is None:
|
||||
self.revision = 0
|
||||
if self.epoch == "" or self.epoch is None:
|
||||
self.epoch = 0
|
||||
self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- "))
|
||||
models_path = shared.dreambooth_models_path
|
||||
try:
|
||||
from core.handlers.models import ModelHandler
|
||||
mh = ModelHandler()
|
||||
models_path = mh.models_path
|
||||
except:
|
||||
pass
|
||||
if models_path == "" or models_path is None:
|
||||
models_path = os.path.join(shared.models_path, "dreambooth")
|
||||
model_dir = os.path.join(models_path, self.model_name)
|
||||
working_dir = os.path.join(model_dir, "working")
|
||||
if not os.path.exists(working_dir):
|
||||
os.makedirs(working_dir)
|
||||
self.model_dir = model_dir
|
||||
self.pretrained_model_name_or_path = working_dir
|
||||
|
||||
def refresh(self):
|
||||
"""
|
||||
Reload self from file
|
||||
|
|
|
|||
|
|
@ -5,7 +5,14 @@ from dreambooth import shared
|
|||
|
||||
db_path = os.path.join(shared.models_path, "dreambooth")
|
||||
secret_file = os.path.join(db_path, "secret.txt")
|
||||
|
||||
try:
|
||||
from core.handlers.config import DirectoryHandler
|
||||
dh = DirectoryHandler()
|
||||
protected_path = dh.protected_path
|
||||
db_path = os.path.join(protected_path, "dreambooth")
|
||||
secret_file = os.path.join(db_path, "secret.txt")
|
||||
except:
|
||||
pass
|
||||
if not os.path.exists(db_path):
|
||||
os.makedirs(db_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,18 +90,6 @@ except:
|
|||
|
||||
export_diffusers = False
|
||||
user_model_dir = ""
|
||||
try:
|
||||
from core.handlers.config import ConfigHandler
|
||||
from core.handlers.models import ModelHandler
|
||||
|
||||
ch = ConfigHandler()
|
||||
mh = ModelHandler()
|
||||
export_diffusers = ch.get_item("export_diffusers", "dreambooth", True)
|
||||
user_model_dir = os.path.join(mh.models_path[0], "diffusers")
|
||||
logger.debug(f"Export diffusers: {export_diffusers}, diffusers dir: {user_model_dir}")
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def dadapt(optimizer):
|
||||
|
|
@ -156,9 +144,17 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
args = shared.db_model_config
|
||||
status_handler = None
|
||||
logging_dir = Path(args.model_dir, "logging")
|
||||
global export_diffusers, user_model_dir
|
||||
try:
|
||||
from core.handlers.status import StatusHandler
|
||||
from core.handlers.config import ConfigHandler
|
||||
from core.handlers.models import ModelHandler
|
||||
|
||||
mh = ModelHandler(user_name=user)
|
||||
status_handler = StatusHandler(user_name=user, target="dreamProgress")
|
||||
export_diffusers = True
|
||||
user_model_dir = mh.user_path
|
||||
logger.debug(f"Export diffusers: {export_diffusers}, diffusers dir: {user_model_dir}")
|
||||
shared.status_handler = status_handler
|
||||
logger.debug(f"Loaded config: {args.__dict__}")
|
||||
except:
|
||||
|
|
@ -687,7 +683,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}"
|
||||
)
|
||||
if os.path.exists(new_hotness):
|
||||
accelerator.logger.debug(f"Resuming from checkpoint {new_hotness}")
|
||||
logger.debug(f"Resuming from checkpoint {new_hotness}")
|
||||
|
||||
try:
|
||||
import modules.shared
|
||||
|
|
@ -864,7 +860,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
requires_safety_checker=None,
|
||||
)
|
||||
|
||||
with accelerator.autocast():
|
||||
# Is inference_mode() needed here to prevent issues when saving?
|
||||
with accelerator.autocast(), torch.inference_mode():
|
||||
if save_model:
|
||||
# We are saving weights, we need to ensure revision is saved
|
||||
args.save()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,14 @@ db_path = os.path.join(shared.models_path, "dreambooth")
|
|||
url_file = os.path.join(db_path, "webhook.txt")
|
||||
hook_url = None
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
new_ui = False
|
||||
try:
|
||||
from core.dataclasses import status_data
|
||||
new_ui = True
|
||||
except:
|
||||
pass
|
||||
|
||||
if not os.path.exists(db_path) and not new_ui:
|
||||
os.makedirs(db_path)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ let modelLoaded = false;
|
|||
const mdBreakpoint = 990;
|
||||
|
||||
// Register the module with the UI. Icon is from boxicons by default.
|
||||
const dbModule = new Module("Dreambooth", "moduleDreambooth", "moon", false, 2, initDreambooth);
|
||||
const dbModule = new Module("Dreambooth", "moduleDreambooth", "moon", false, 2, initDreambooth, refreshDreambooth);
|
||||
|
||||
function initDreambooth() {
|
||||
sendMessage("get_db_vars", {}, true).then(function (response) {
|
||||
|
|
@ -103,7 +103,7 @@ function initDreambooth() {
|
|||
}
|
||||
});
|
||||
|
||||
// utility function to convert a string to Title Case
|
||||
// utility function to convert a string to Title Case
|
||||
String.prototype.toTitleCase = function () {
|
||||
return this.replace(/\w\S*/g, function (txt) {
|
||||
return txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase();
|
||||
|
|
@ -176,6 +176,19 @@ function initDreambooth() {
|
|||
}
|
||||
}
|
||||
|
||||
function refreshDreambooth() {
|
||||
dreamConfig = dbModule.systemConfig;
|
||||
showAdvanced = dreamConfig["show_advanced"];
|
||||
if (showAdvanced) {
|
||||
$(".db-advanced").show();
|
||||
$(".db-basic").hide();
|
||||
$("#hub_row").hide();
|
||||
$("#local_row").show();
|
||||
} else {
|
||||
$(".db-advanced").hide();
|
||||
$(".db-basic").show();
|
||||
}
|
||||
}
|
||||
function onDbEnd() {
|
||||
$(".dbTrainBtn").addClass("hide");
|
||||
$(".dbSettingBtn").removeClass("hide");
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ import torch
|
|||
from fastapi import FastAPI
|
||||
|
||||
import scripts.api
|
||||
from core.handlers.models import ModelHandler
|
||||
from core.handlers.config import ConfigHandler
|
||||
from core.handlers.models import ModelHandler, ModelManager
|
||||
from core.handlers.status import StatusHandler
|
||||
from core.handlers.websocket import SocketHandler
|
||||
from core.modules.base.module_base import BaseModule
|
||||
|
|
@ -36,6 +37,10 @@ class DreamboothModule(BaseModule):
|
|||
def initialize(self, app: FastAPI, handler: SocketHandler):
|
||||
self._initialize_api(app)
|
||||
self._initialize_websocket(handler)
|
||||
defaults_base_file = os.path.join(os.path.dirname(__file__), "templates", "db_config.json")
|
||||
if os.path.exists(defaults_base_file):
|
||||
ch = ConfigHandler()
|
||||
ch.set_config_protected(json.load(open(defaults_base_file, "r")), "dreambooth_model_defaults")
|
||||
|
||||
def _initialize_api(self, app: FastAPI):
|
||||
return scripts.api.dreambooth_api(None, app)
|
||||
|
|
@ -76,8 +81,9 @@ async def _train_dreambooth(request):
|
|||
user = request["user"] if "user" in request else None
|
||||
config = await _set_model_config(request, True)
|
||||
mh = ModelHandler(user_name=user)
|
||||
mm = ModelManager()
|
||||
sh = StatusHandler(user_name=user, target="dreamProgress")
|
||||
mh.to_cpu()
|
||||
mm.to_cpu()
|
||||
shared.db_model_config = config
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -92,20 +98,17 @@ async def _train_dreambooth(request):
|
|||
with ThreadPoolExecutor() as pool:
|
||||
await loop.run_in_executor(pool, lambda: (
|
||||
sh.start(0, "Starting Dreambooth Training..."),
|
||||
main(user=user),
|
||||
sh.end("Training complete.")
|
||||
main(user=user)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training: {e}")
|
||||
traceback.print_exc()
|
||||
result = {"message": f"Error in training: {e}"}
|
||||
sh.end(f"Error in training: {e}")
|
||||
try:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
mh.to_gpu()
|
||||
sh.end(result["message"])
|
||||
return result
|
||||
|
||||
|
|
@ -161,10 +164,20 @@ async def copy_model(model_name: str, src: str, is_512: bool, mh: ModelHandler,
|
|||
dest_dir = os.path.join(model_dir, "dreambooth", model_name, "working")
|
||||
if os.path.exists(dest_dir):
|
||||
shutil.rmtree(dest_dir, True)
|
||||
ch = ConfigHandler(user_name=mh.user_name)
|
||||
base = ch.get_config_protected("dreambooth_model_defaults")
|
||||
user_base = ch.get_config_user("dreambooth_model_defaults")
|
||||
logger.debug(f"User base: {user_base}")
|
||||
if user_base is not None:
|
||||
base = {**base, **user_base}
|
||||
else:
|
||||
logger.debug("Setting user config")
|
||||
ch.set_config_user(base, "dreambooth_model_defaults")
|
||||
if not os.path.exists(dest_dir):
|
||||
logger.debug(f"Copying model from {src} to {dest_dir}")
|
||||
await copy_directory(src, dest_dir, sh)
|
||||
cfg = DreamboothConfig(model_name=model_name, src=src, resolution=512 if is_512 else 768, models_path=dreambooth_models_path)
|
||||
cfg.load_params(base)
|
||||
cfg.save()
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
{
|
||||
"weight_decay": 0.01,
|
||||
"cache_latents": true,
|
||||
"clip_skip": 2,
|
||||
"concepts_list": [],
|
||||
"concepts_path": "",
|
||||
"custom_model_name": "",
|
||||
"deterministic": false,
|
||||
"disable_class_matching": false,
|
||||
"disable_logging": false,
|
||||
"ema_predict": false,
|
||||
"epoch_pause_frequency": 0,
|
||||
"epoch_pause_time": 0,
|
||||
"freeze_clip_normalization": true,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": true,
|
||||
"gradient_set_to_none": true,
|
||||
"graph_smoothing": 50,
|
||||
"half_model": false,
|
||||
"hflip": false,
|
||||
"infer_ema": false,
|
||||
"learning_rate": 2e-06,
|
||||
"learning_rate_min": 1e-06,
|
||||
"lora_learning_rate": 0.0001,
|
||||
"lora_txt_learning_rate": 5e-05,
|
||||
"lora_txt_rank": 4,
|
||||
"lora_txt_weight": 1.0,
|
||||
"lora_unet_rank": 4,
|
||||
"lora_weight": 1.0,
|
||||
"lora_use_buggy_requires_grad": false,
|
||||
"lr_cycles": 1,
|
||||
"lr_factor": 0.5,
|
||||
"lr_power": 1.0,
|
||||
"lr_scale_pos": 0.5,
|
||||
"lr_scheduler": "constant_with_warmup",
|
||||
"lr_warmup_steps": 500,
|
||||
"max_token_length": 75,
|
||||
"mixed_precision": "fp16",
|
||||
"noise_scheduler": "DDPM",
|
||||
"num_train_epochs": 200,
|
||||
"offset_noise": 0,
|
||||
"optimizer": "8bit AdamW",
|
||||
"pad_tokens": true,
|
||||
"pretrained_vae_name_or_path": "",
|
||||
"prior_loss_scale": false,
|
||||
"prior_loss_target": 100,
|
||||
"prior_loss_weight": 0.75,
|
||||
"prior_loss_weight_min": 0.1,
|
||||
"resolution": 512,
|
||||
"revision": 21910,
|
||||
"sample_batch_size": 1,
|
||||
"sanity_prompt": "",
|
||||
"sanity_seed": 420420,
|
||||
"save_ckpt_after": true,
|
||||
"save_ckpt_cancel": false,
|
||||
"save_ckpt_during": true,
|
||||
"save_ema": true,
|
||||
"save_embedding_every": 25,
|
||||
"save_lora_after": true,
|
||||
"save_lora_cancel": false,
|
||||
"save_lora_during": true,
|
||||
"save_lora_for_extra_net": true,
|
||||
"save_preview_every": 5,
|
||||
"save_safetensors": true,
|
||||
"save_state_after": false,
|
||||
"save_state_cancel": false,
|
||||
"save_state_during": false,
|
||||
"scheduler": "UniPCMultistep",
|
||||
"shuffle_tags": true,
|
||||
"split_loss": true,
|
||||
"stop_text_encoder": 0.75,
|
||||
"strict_tokens": true,
|
||||
"dynamic_img_norm": false,
|
||||
"tenc_weight_decay": 0.01,
|
||||
"tenc_grad_clip_norm": 6,
|
||||
"tomesd": 0,
|
||||
"train_batch_size": 1,
|
||||
"train_imagic": false,
|
||||
"train_unet": true,
|
||||
"train_unfrozen": true,
|
||||
"txt_learning_rate": 1e-06,
|
||||
"use_concepts": true,
|
||||
"use_ema": false,
|
||||
"use_lora": false,
|
||||
"use_lora_extended": false,
|
||||
"use_shared_src": [
|
||||
false
|
||||
],
|
||||
"use_subdir": false
|
||||
}
|
||||
Loading…
Reference in New Issue