From 3ec04aeb9d8a5b6ebbac6da416d88134d12726ed Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 6 Feb 2025 21:47:27 +0800 Subject: [PATCH] sd3 support --- mikazuki/app/api.py | 10 ++--- mikazuki/schema/sd3-lora.ts | 90 +++++++++++++++++++++++++++++++++++++ mikazuki/schema/shared.ts | 1 + 3 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 mikazuki/schema/sd3-lora.ts diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index 68c887e..16b333f 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -94,11 +94,11 @@ def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]: positive_prompts = config.pop('positive_prompts', None) negative_prompts = config.pop('negative_prompts', '') - sample_width = config.pop('sample_width', 512) # 默认宽度 512 - sample_height = config.pop('sample_height', 512) # 默认高度 512 - sample_cfg = config.pop('sample_cfg', 7) # 默认 CFG 值 7.5 - sample_seed = config.pop('sample_seed', 2333) # 默认随机种子 42 - sample_steps = config.pop('sample_steps', 24) # 默认步数 50 + sample_width = config.pop('sample_width', 512) + sample_height = config.pop('sample_height', 512) + sample_cfg = config.pop('sample_cfg', 7) + sample_seed = config.pop('sample_seed', 2333) + sample_steps = config.pop('sample_steps', 24) randomly_choice_prompt = config.pop('randomly_choice_prompt', False) if randomly_choice_prompt: diff --git a/mikazuki/schema/sd3-lora.ts b/mikazuki/schema/sd3-lora.ts new file mode 100644 index 0000000..aa033b9 --- /dev/null +++ b/mikazuki/schema/sd3-lora.ts @@ -0,0 +1,90 @@ +Schema.intersect([ + Schema.object({ + model_train_type: Schema.string().default("sd3-lora").disabled().description("训练种类"), + pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("SD3 模型路径"), + clip_l: Schema.string().role('filepicker', { type: "model-file" }).description("clip_l 模型文件路径"), + clip_g: Schema.string().role('filepicker', { type: "model-file" }).description("clip_g 模型文件路径"), + t5xxl: Schema.string().role('filepicker', { type: "model-file" }).description("t5xxl 模型文件路径"), + resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"), + }).description("训练用模型"), + + Schema.object({ + t5xxl_max_token_length: Schema.number().step(1).description("T5XXL 最大 token 长度(不填写使用自动)"), + train_t5xxl: Schema.boolean().default(false).description("训练 T5XXL(不推荐)"), + }).description("SD3 专用参数"), + + Schema.object( + UpdateSchema(SHARED_SCHEMAS.RAW.DATASET_SETTINGS, { + resolution: Schema.string().default("768,768").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"), + enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"), + min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"), + max_bucket_reso: Schema.number().default(2048).description("arb 桶最大分辨率"), + bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位"), + }) + ).description("数据集设置"), + + // 保存设置 + SHARED_SCHEMAS.SAVE_SETTINGS, + + Schema.object({ + max_train_epochs: Schema.number().min(1).default(20).description("最大训练 epoch(轮数)"), + train_batch_size: Schema.number().min(1).default(1).description("批量大小, 越高显存占用越高"), + gradient_checkpointing: Schema.boolean().default(true).description("梯度检查点"), + gradient_accumulation_steps: Schema.number().min(1).default(1).description("梯度累加步数"), + network_train_unet_only: Schema.boolean().default(true).description("仅训练 U-Net"), + network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"), + }).description("训练相关参数"), + + // 学习率&优化器设置 + SHARED_SCHEMAS.LR_OPTIMIZER, + + Schema.intersect([ + Schema.object({ + network_module: Schema.union(["networks.lora_sd3", "lycoris.kohya"]).default("networks.lora_sd3").description("训练网络模块"), + network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"), + network_dim: Schema.number().min(1).default(4).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"), + network_alpha: Schema.number().min(1).default(1).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"), + network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'), + enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'), + }).description("网络设置"), + + // lycoris 参数 + SHARED_SCHEMAS.LYCORIS_MAIN, + SHARED_SCHEMAS.LYCORIS_LOKR, + + SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT, + ]), + + // 预览图设置 + SHARED_SCHEMAS.PREVIEW_IMAGE, + + // 日志设置 + SHARED_SCHEMAS.LOG_SETTINGS, + + // caption 选项 + // FLUX 去除 max_token_length + Schema.object(UpdateSchema(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS, {}, ["max_token_length"])).description("caption(Tag)选项"), + + // 噪声设置 + SHARED_SCHEMAS.NOISE_SETTINGS, + + // 数据增强 + SHARED_SCHEMAS.DATA_ENCHANCEMENT, + + // 其他选项 + SHARED_SCHEMAS.OTHER, + + // 速度优化选项 + Schema.object( + UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, { + fp8_base: Schema.boolean().default(true).description("对基础模型使用 FP8 精度"), + fp8_base_unet: Schema.boolean().description("仅对 U-Net 使用 FP8 精度(CLIP-L不使用)"), + sdpa: Schema.boolean().default(true).description("启用 sdpa"), + cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"), + cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"), + }, ["xformers"]) + ).description("速度优化选项"), + + // 分布式训练 + SHARED_SCHEMAS.DISTRIBUTED_TRAINING +]); diff --git a/mikazuki/schema/shared.ts b/mikazuki/schema/shared.ts index e6aca1e..bc719f2 100644 --- a/mikazuki/schema/shared.ts +++ b/mikazuki/schema/shared.ts @@ -123,6 +123,7 @@ "constant_with_warmup", ]).default("cosine_with_restarts").description("学习率调度器设置"), lr_warmup_steps: Schema.number().default(0).description('学习率预热步数'), + loss_type: Schema.union(["l1", "l2", "huber", "smooth_l1"]).description("损失函数类型"), }).description("学习率与优化器设置"), Schema.union([