添加 `save_last_n_epochs_state` 参数 (#611)

* fix: typo

* feat: add save_last_n_epochs_state option

* fix: Schema
pull/669/head
undefined 2025-01-13 17:57:57 +08:00 committed by GitHub
parent c18e95208a
commit 7159f93ab0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 10 deletions

View File

@ -148,7 +148,7 @@ async def create_toml_file(request: Request):
f.write(sample_prompts_arg)
config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote promopts to file {sample_prompts_file}")
log.info(f"Wrote prompts to file {sample_prompts_file}")
except ValueError as e:
return APIResponseFail(message=str(e))

View File

@ -91,15 +91,23 @@
Schema.object({}),
]),
SAVE_SETTINGS: Schema.object({
output_name: Schema.string().default("aki").description("模型保存名称"),
output_dir: Schema.string().role('filepicker', { type: "folder" }).default("./output").description("模型保存文件夹"),
save_model_as: Schema.union(["safetensors", "pt", "ckpt"]).default("safetensors").description("模型保存格式"),
save_precision: Schema.union(["fp16", "float", "bf16"]).default("fp16").description("模型保存精度"),
save_every_n_epochs: Schema.number().default(2).description("每 N epoch自动保存一次模型"),
save_state: Schema.boolean().description("保存训练状态 配合 `resume` 参数可以继续从某个状态训练"),
}).description("保存设置"),
SAVE_SETTINGS: Schema.intersect([
Schema.object({
output_name: Schema.string().default("aki").description("模型保存名称"),
output_dir: Schema.string().role('filepicker', { type: "folder" }).default("./output").description("模型保存文件夹"),
save_model_as: Schema.union(["safetensors", "pt", "ckpt"]).default("safetensors").description("模型保存格式"),
save_precision: Schema.union(["fp16", "float", "bf16"]).default("fp16").description("模型保存精度"),
save_every_n_epochs: Schema.number().default(2).description("每 N epoch自动保存一次模型"),
save_state: Schema.boolean().default(false).description("保存训练状态 配合 `resume` 参数可以继续从某个状态训练"),
}),
Schema.union([
Schema.object({
save_state: Schema.const(true).required(),
save_last_n_epochs_state: Schema.number().min(1).description("仅保存最后 n epoch 的训练状态"),
}),
Schema.object({})
])
]).description("保存设置"),
LR_OPTIMIZER: Schema.intersect([
Schema.object({
learning_rate: Schema.string().default("1e-4").description("总学习率, 在分开设置 U-Net 与文本编码器学习率后这个值失效。"),