update train

pull/6/head
akiba 2023-01-23 11:26:14 +08:00
parent 25bf2a9b10
commit 29cc66928c
No known key found for this signature in database
GPG Key ID: 4CE9A24A15E72161
1 changed files with 14 additions and 2 deletions

View File

@ -1,5 +1,4 @@
# LoRA train script by @Akegarasu
$Env:HF_HOME = "huggingface"
# Train data path | 设置训练用模型、图片
$pretrained_model = "./sd-models/final-prune.ckpt" # base model path | 底模路径
@ -13,6 +12,8 @@ $max_train_epoches = 10 # max train epoches | 最大训练 epoch
$save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次
$network_dim = 32 # network dim
$clip_skip = 2
$train_unet_only = 1 # train U-Net only | 仅训练 U-Net
$train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器
# Learning rate | 学习率
$lr = "1e-4"
@ -24,6 +25,17 @@ $lr_scheduler = "cosine_with_restarts" # "linear", "cosine", "cosine_with_restar
$output_name = "aki" # output model name | 模型保存名称
$save_model_as = "safetensors" # model save ext | 模型保存格式
$Env:HF_HOME = "huggingface"
$ext_args = [System.Collections.ArrayList]::new()
if ($train_unet_only) {
[void]$ext_args.Add("--network_train_unet_only")
}
if ($train_text_encoder_only) {
[void]$ext_args.Add("--network_train_text_encoder_only")
}
# run train
accelerate launch --num_cpu_threads_per_process=8 "train_network.py" `
--enable_bucket `
@ -52,4 +64,4 @@ accelerate launch --num_cpu_threads_per_process=8 "train_network.py" `
--max_token_length=225 `
--caption_extension=".txt" `
--save_model_as=$save_model_as `
--xformers --shuffle_caption --use_8bit_adam
--xformers --shuffle_caption --use_8bit_adam $ext_args