Fix training start
parent
76e0b97e18
commit
b738ddd42a
|
|
@ -310,7 +310,7 @@ function loadDbListeners() {
|
|||
});
|
||||
|
||||
$("#db_save_config").click(function () {
|
||||
let selected = dreamSelect.getModel();
|
||||
let selected = $("#dreamModelSelect").modelSelect().getModel();
|
||||
if (selected === undefined) {
|
||||
alert("Please select a model first!");
|
||||
} else {
|
||||
|
|
@ -566,7 +566,7 @@ function removeConcept() {
|
|||
|
||||
function getSettings() {
|
||||
let settings = {};
|
||||
settings["model"] = dreamSelect.getModel();
|
||||
settings["model"] = $("#dreamModelSelect").modelSelect().getModel();
|
||||
|
||||
// Just create one concept if advanced is disabled
|
||||
let concepts_list = [];
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class DreamboothModule(BaseModule):
|
|||
return scripts.api.dreambooth_api(None, app)
|
||||
|
||||
def _initialize_websocket(self, handler: SocketHandler):
|
||||
handler.register("train_dreambooth", _start_training)
|
||||
handler.register("train_dreambooth", _train_dreambooth)
|
||||
handler.register("create_dreambooth", _create_model)
|
||||
handler.register("get_db_config", _get_model_config)
|
||||
handler.register("save_db_config", _set_model_config)
|
||||
|
|
@ -68,18 +68,11 @@ async def _get_db_vars(request):
|
|||
}
|
||||
|
||||
|
||||
async def _start_training(request):
|
||||
async def _train_dreambooth(request):
|
||||
user = request["user"] if "user" in request else None
|
||||
target = request["target"] if "target" in request else None
|
||||
config = await _set_model_config(request, True)
|
||||
asyncio.create_task(_train_dreambooth(config, user, target))
|
||||
return {"status": "Training started."}
|
||||
|
||||
|
||||
async def _train_dreambooth(config: DreamboothConfig, user: str = None, target: str = None):
|
||||
logger.debug(f"Updated config: {config.__dict__}")
|
||||
mh = ModelHandler(user_name=user)
|
||||
sh = StatusHandler(user_name=user, target=target)
|
||||
sh = StatusHandler(user_name=user, target="dreamProgress")
|
||||
mh.to_cpu()
|
||||
shared.db_model_config = config
|
||||
try:
|
||||
|
|
@ -93,9 +86,11 @@ async def _train_dreambooth(config: DreamboothConfig, user: str = None, target:
|
|||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
with ThreadPoolExecutor() as pool:
|
||||
await loop.run_in_executor(pool, lambda: main(user=user))
|
||||
sh.end("Training complete.")
|
||||
|
||||
await loop.run_in_executor(pool, lambda: (
|
||||
sh.start(0, "Starting Dreambooth Training..."),
|
||||
main(user=user),
|
||||
sh.end("Training complete.")
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
Loading…
Reference in New Issue