diff --git a/train.ps1 b/train.ps1 index c166c4e..ba7313e 100644 --- a/train.ps1 +++ b/train.ps1 @@ -1,5 +1,4 @@ # LoRA train script by @Akegarasu -$Env:HF_HOME = "huggingface" # Train data path | 设置训练用模型、图片 $pretrained_model = "./sd-models/final-prune.ckpt" # base model path | 底模路径 @@ -13,6 +12,8 @@ $max_train_epoches = 10 # max train epoches | 最大训练 epoch $save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次 $network_dim = 32 # network dim $clip_skip = 2 +$train_unet_only = 1 # train U-Net only | 仅训练 U-Net +$train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器 # Learning rate | 学习率 $lr = "1e-4" @@ -24,6 +25,17 @@ $lr_scheduler = "cosine_with_restarts" # "linear", "cosine", "cosine_with_restar $output_name = "aki" # output model name | 模型保存名称 $save_model_as = "safetensors" # model save ext | 模型保存格式 +$Env:HF_HOME = "huggingface" +$ext_args = [System.Collections.ArrayList]::new() + +if ($train_unet_only) { + [void]$ext_args.Add("--network_train_unet_only") +} + +if ($train_text_encoder_only) { + [void]$ext_args.Add("--network_train_text_encoder_only") +} + # run train accelerate launch --num_cpu_threads_per_process=8 "train_network.py" ` --enable_bucket ` @@ -52,4 +64,4 @@ accelerate launch --num_cpu_threads_per_process=8 "train_network.py" ` --max_token_length=225 ` --caption_extension=".txt" ` --save_model_as=$save_model_as ` - --xformers --shuffle_caption --use_8bit_adam \ No newline at end of file + --xformers --shuffle_caption --use_8bit_adam $ext_args \ No newline at end of file