update train
parent
25bf2a9b10
commit
29cc66928c
16
train.ps1
16
train.ps1
|
|
@ -1,5 +1,4 @@
|
|||
# LoRA train script by @Akegarasu
|
||||
$Env:HF_HOME = "huggingface"
|
||||
|
||||
# Train data path | 设置训练用模型、图片
|
||||
$pretrained_model = "./sd-models/final-prune.ckpt" # base model path | 底模路径
|
||||
|
|
@ -13,6 +12,8 @@ $max_train_epoches = 10 # max train epoches | 最大训练 epoch
|
|||
$save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次
|
||||
$network_dim = 32 # network dim
|
||||
$clip_skip = 2
|
||||
$train_unet_only = 1 # train U-Net only | 仅训练 U-Net
|
||||
$train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器
|
||||
|
||||
# Learning rate | 学习率
|
||||
$lr = "1e-4"
|
||||
|
|
@ -24,6 +25,17 @@ $lr_scheduler = "cosine_with_restarts" # "linear", "cosine", "cosine_with_restar
|
|||
$output_name = "aki" # output model name | 模型保存名称
|
||||
$save_model_as = "safetensors" # model save ext | 模型保存格式
|
||||
|
||||
$Env:HF_HOME = "huggingface"
|
||||
$ext_args = [System.Collections.ArrayList]::new()
|
||||
|
||||
if ($train_unet_only) {
|
||||
[void]$ext_args.Add("--network_train_unet_only")
|
||||
}
|
||||
|
||||
if ($train_text_encoder_only) {
|
||||
[void]$ext_args.Add("--network_train_text_encoder_only")
|
||||
}
|
||||
|
||||
# run train
|
||||
accelerate launch --num_cpu_threads_per_process=8 "train_network.py" `
|
||||
--enable_bucket `
|
||||
|
|
@ -52,4 +64,4 @@ accelerate launch --num_cpu_threads_per_process=8 "train_network.py" `
|
|||
--max_token_length=225 `
|
||||
--caption_extension=".txt" `
|
||||
--save_model_as=$save_model_as `
|
||||
--xformers --shuffle_caption --use_8bit_adam
|
||||
--xformers --shuffle_caption --use_8bit_adam $ext_args
|
||||
Loading…
Reference in New Issue