pull/608/head
Akegarasu 2024-12-01 20:07:15 +08:00
parent 293a9dc17e
commit 950d4fdbda
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
4 changed files with 30 additions and 20 deletions

View File

@ -80,7 +80,15 @@ Schema.intersect([
SHARED_SCHEMAS.OTHER,
// 速度优化选项
SHARED_SCHEMAS.PRECISION_CACHE_BATCH,
Schema.object(
UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, {
fp8_base: Schema.boolean().default(true).description("对基础模型使用 FP8 精度"),
fp8_base_unet: Schema.boolean().description("仅对 U-Net 使用 FP8 精度CLIP-L不使用"),
sdpa: Schema.boolean().default(true).description("启用 sdpa"),
cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
}, ["xformers"])
).description("速度优化选项"),
// 分布式训练
SHARED_SCHEMAS.DISTRIBUTED_TRAINING

View File

@ -86,7 +86,7 @@ Schema.intersect([
SHARED_SCHEMAS.OTHER,
// 速度优化选项
SHARED_SCHEMAS.PRECISION_CACHE_BATCH,
Schema.object(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH).description("速度优化选项"),
// 分布式训练
SHARED_SCHEMAS.DISTRIBUTED_TRAINING

View File

@ -25,6 +25,21 @@
caption_dropout_rate: Schema.number().min(0).step(0.01).description("丢弃全部标签的概率,对一个图片概率不使用 caption 或 class token"),
caption_dropout_every_n_epochs: Schema.number().min(0).max(100).step(1).description("每 N 个 epoch 丢弃全部标签"),
caption_tag_dropout_rate: Schema.number().min(0).step(0.01).description("按逗号分隔的标签来随机丢弃 tag 的概率"),
},
PRECISION_CACHE_BATCH: {
mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("bf16").description("训练混合精度, RTX30系列以后也可以指定`bf16`"),
full_fp16: Schema.boolean().description("完全使用 FP16 精度"),
full_bf16: Schema.boolean().description("完全使用 BF16 精度"),
no_half_vae: Schema.boolean().description("不使用半精度 VAE"),
xformers: Schema.boolean().default(true).description("启用 xformers"),
sdpa: Schema.boolean().description("启用 sdpa"),
lowram: Schema.boolean().default(false).description("低内存模式 该模式下会将 U-net、文本编码器、VAE 直接加载到显存中"),
cache_latents: Schema.boolean().default(true).description("缓存图像 latent, 缓存 VAE 输出以减少 VRAM 使用"),
cache_latents_to_disk: Schema.boolean().default(true).description("缓存图像 latent 到磁盘"),
cache_text_encoder_outputs: Schema.boolean().description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
cache_text_encoder_outputs_to_disk: Schema.boolean().description("缓存文本编码器的输出到磁盘"),
persistent_data_loader_workers: Schema.boolean().default(true).description("保留加载训练集的worker减少每个 epoch 之间的停顿。"),
vae_batch_size: Schema.number().min(1).description("vae 编码批量大小"),
}
},
@ -190,23 +205,6 @@
ui_custom_params: Schema.string().role('textarea').description("**危险** 自定义参数,请输入 TOML 格式,将会直接覆盖当前界面内任何参数。实时更新,推荐写完后再粘贴过来"),
}).description("其他设置"),
PRECISION_CACHE_BATCH: Schema.object({
mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("bf16").description("训练混合精度, RTX30系列以后也可以指定`bf16`"),
full_fp16: Schema.boolean().description("完全使用 FP16 精度"),
full_bf16: Schema.boolean().description("完全使用 BF16 精度"),
fp8_base: Schema.boolean().default(true).description("对基础模型使用 FP8 精度"),
fp8_base_unet: Schema.boolean().description("仅对 U-Net 使用 FP8 精度CLIP-L不使用"),
no_half_vae: Schema.boolean().description("不使用半精度 VAE"),
sdpa: Schema.boolean().default(true).description("启用 sdpa"),
lowram: Schema.boolean().default(false).description("低内存模式 该模式下会将 U-net、文本编码器、VAE 直接加载到显存中"),
cache_latents: Schema.boolean().default(true).description("缓存图像 latent, 缓存 VAE 输出以减少 VRAM 使用"),
cache_latents_to_disk: Schema.boolean().default(true).description("缓存图像 latent 到磁盘"),
cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
persistent_data_loader_workers: Schema.boolean().default(true).description("保留加载训练集的worker减少每个 epoch 之间的停顿。"),
vae_batch_size: Schema.number().min(1).description("vae 编码批量大小"),
}).description("速度优化选项"),
DISTRIBUTED_TRAINING: Schema.object({
ddp_timeout: Schema.number().min(0).description("分布式训练超时时间"),
ddp_gradient_as_bucket_view: Schema.boolean(),

View File

@ -97,10 +97,14 @@ def validate_data_dir(path):
subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
if len(subdirs) == 0:
log.warn(f"No subdir found in data dir")
log.warning(f"No subdir found in data dir")
ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]
if len(ok_dir) > 0:
log.info(f"Found {len(ok_dir)} legal dataset")
return True
if len(ok_dir) == 0:
log.warning(f"No leagal dataset found. Try find avaliable images")
imgs = get_total_images(path, False)