diff --git a/js/module_dreambooth.js b/js/module_dreambooth.js index 3174ac2..b0c89b1 100644 --- a/js/module_dreambooth.js +++ b/js/module_dreambooth.js @@ -310,7 +310,7 @@ function loadDbListeners() { }); $("#db_save_config").click(function () { - let selected = dreamSelect.getModel(); + let selected = $("#dreamModelSelect").modelSelect().getModel(); if (selected === undefined) { alert("Please select a model first!"); } else { @@ -566,7 +566,7 @@ function removeConcept() { function getSettings() { let settings = {}; - settings["model"] = dreamSelect.getModel(); + settings["model"] = $("#dreamModelSelect").modelSelect().getModel(); // Just create one concept if advanced is disabled let concepts_list = []; diff --git a/module_dreambooth.py b/module_dreambooth.py index d9b6103..04f363d 100644 --- a/module_dreambooth.py +++ b/module_dreambooth.py @@ -41,7 +41,7 @@ class DreamboothModule(BaseModule): return scripts.api.dreambooth_api(None, app) def _initialize_websocket(self, handler: SocketHandler): - handler.register("train_dreambooth", _start_training) + handler.register("train_dreambooth", _train_dreambooth) handler.register("create_dreambooth", _create_model) handler.register("get_db_config", _get_model_config) handler.register("save_db_config", _set_model_config) @@ -68,18 +68,11 @@ async def _get_db_vars(request): } -async def _start_training(request): +async def _train_dreambooth(request): user = request["user"] if "user" in request else None - target = request["target"] if "target" in request else None config = await _set_model_config(request, True) - asyncio.create_task(_train_dreambooth(config, user, target)) - return {"status": "Training started."} - - -async def _train_dreambooth(config: DreamboothConfig, user: str = None, target: str = None): - logger.debug(f"Updated config: {config.__dict__}") mh = ModelHandler(user_name=user) - sh = StatusHandler(user_name=user, target=target) + sh = StatusHandler(user_name=user, target="dreamProgress") mh.to_cpu() shared.db_model_config = config try: @@ -93,9 +86,11 @@ async def _train_dreambooth(config: DreamboothConfig, user: str = None, target: try: loop = asyncio.get_event_loop() with ThreadPoolExecutor() as pool: - await loop.run_in_executor(pool, lambda: main(user=user)) - sh.end("Training complete.") - + await loop.run_in_executor(pool, lambda: ( + sh.start(0, "Starting Dreambooth Training..."), + main(user=user), + sh.end("Training complete.") + )) except Exception as e: logger.error(f"Error in training: {e}") traceback.print_exc()