Fix training start

pull/1241/head
d8ahazard 2023-05-11 16:24:32 -05:00
parent 76e0b97e18
commit b738ddd42a
2 changed files with 10 additions and 15 deletions

View File

@ -310,7 +310,7 @@ function loadDbListeners() {
});
$("#db_save_config").click(function () {
let selected = dreamSelect.getModel();
let selected = $("#dreamModelSelect").modelSelect().getModel();
if (selected === undefined) {
alert("Please select a model first!");
} else {
@ -566,7 +566,7 @@ function removeConcept() {
function getSettings() {
let settings = {};
settings["model"] = dreamSelect.getModel();
settings["model"] = $("#dreamModelSelect").modelSelect().getModel();
// Just create one concept if advanced is disabled
let concepts_list = [];

View File

@ -41,7 +41,7 @@ class DreamboothModule(BaseModule):
return scripts.api.dreambooth_api(None, app)
def _initialize_websocket(self, handler: SocketHandler):
handler.register("train_dreambooth", _start_training)
handler.register("train_dreambooth", _train_dreambooth)
handler.register("create_dreambooth", _create_model)
handler.register("get_db_config", _get_model_config)
handler.register("save_db_config", _set_model_config)
@ -68,18 +68,11 @@ async def _get_db_vars(request):
}
async def _start_training(request):
async def _train_dreambooth(request):
user = request["user"] if "user" in request else None
target = request["target"] if "target" in request else None
config = await _set_model_config(request, True)
asyncio.create_task(_train_dreambooth(config, user, target))
return {"status": "Training started."}
async def _train_dreambooth(config: DreamboothConfig, user: str = None, target: str = None):
logger.debug(f"Updated config: {config.__dict__}")
mh = ModelHandler(user_name=user)
sh = StatusHandler(user_name=user, target=target)
sh = StatusHandler(user_name=user, target="dreamProgress")
mh.to_cpu()
shared.db_model_config = config
try:
@ -93,9 +86,11 @@ async def _train_dreambooth(config: DreamboothConfig, user: str = None, target:
try:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as pool:
await loop.run_in_executor(pool, lambda: main(user=user))
sh.end("Training complete.")
await loop.run_in_executor(pool, lambda: (
sh.start(0, "Starting Dreambooth Training..."),
main(user=user),
sh.end("Training complete.")
))
except Exception as e:
logger.error(f"Error in training: {e}")
traceback.print_exc()