update train scripts

pull/569/head
Akegarasu 2024-10-31 15:16:59 +08:00
parent 1836aadeba
commit 1df9685d09
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
6 changed files with 19 additions and 12 deletions

View File

@ -18,7 +18,7 @@ if ($v2) {
}
# run interrogate
accelerate launch --num_cpu_threads_per_process=8 "./scripts/networks/lora_interrogator.py" `
accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/networks/lora_interrogator.py" `
--sd_model=$sd_model `
--model=$model `
--batch_size=$batch_size `

View File

@ -31,7 +31,7 @@ if ($new_conv_rank) {
}
# run svd_merge
accelerate launch --num_cpu_threads_per_process=8 "./scripts/networks/svd_merge_lora.py" `
accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/networks/svd_merge_lora.py" `
--save_precision=$save_precision `
--precision=$precision `
--new_rank=$new_rank `

View File

@ -63,7 +63,7 @@ if ($frequency_tags) {
}
# run tagger
accelerate launch --num_cpu_threads_per_process=8 "./scripts/finetune/tag_images_by_wd14_tagger.py" `
accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/finetune/tag_images_by_wd14_tagger.py" `
$train_data_dir `
--thresh=$thresh `
--caption_extension .txt `

View File

@ -63,7 +63,7 @@ fi
# run tagger
accelerate launch --num_cpu_threads_per_process=8 "./scripts/finetune/tag_images_by_wd14_tagger.py" \
accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/finetune/tag_images_by_wd14_tagger.py" \
$train_data_dir \
--thresh=$thresh \
--caption_extension .txt \

View File

@ -2,7 +2,7 @@
# Train data path | 设置训练用模型、图片
$pretrained_model = "./sd-models/model.ckpt" # base model path | 底模路径
$model_type = "sd1.5" # sd1.5 sd2.0 sdxl model | 可选 sd1.5 sd2.0 sdxl。SD2.0模型 2.0模型下 clip_skip 默认无效
$model_type = "sd1.5" # sd1.5 sd2.0 sdxl flux model | 可选 sd1.5 sd2.0 sdxl flux。SD2.0模型下 clip_skip 默认无效
$parameterization = 0 # parameterization | 参数化 本参数需要在 model_type 为 sd2.0 时才可启用
$train_data_dir = "./train/aki" # train dataset path | 训练数据集路径
@ -75,14 +75,19 @@ $Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
$ext_args = [System.Collections.ArrayList]::new()
$launch_args = [System.Collections.ArrayList]::new()
$trainer_file = "./scripts/train_network.py"
$trainer_file = "./scripts/stable/train_network.py"
if ($model_type -eq "sd1.5") {
[void]$ext_args.Add("--clip_skip=$clip_skip")
} elseif ($model_type -eq "sd2.0") {
}
elseif ($model_type -eq "sd2.0") {
[void]$ext_args.Add("--v2")
} elseif ($model_type -eq "sdxl") {
$trainer_file = "./scripts/sdxl_train_network.py"
}
elseif ($model_type -eq "sdxl") {
$trainer_file = "./scripts/stable/sdxl_train_network.py"
}
elseif ($model_type -eq "flux") {
$trainer_file = "./scripts/dev/flux_train_network.py"
}
if ($multi_gpu) {

View File

@ -3,7 +3,7 @@
# Train data path | 设置训练用模型、图片
pretrained_model="./sd-models/model.ckpt" # base model path | 底模路径
model_type="sd1.5" # option: sd1.5 sd2.0 sdxl | 可选 sd1.5 sd2.0 sdxl。SD2.0模型 2.0模型下 clip_skip 默认无效
model_type="sd1.5" # option: sd1.5 sd2.0 sdxl flux | 可选 sd1.5 sd2.0 sdxl flux。SD2.0模型下 clip_skip 默认无效
parameterization=0 # parameterization | 参数化 本参数需要在 model_type 为 sd2.0 时才可启用
train_data_dir="./train/aki" # train dataset path | 训练数据集路径
@ -74,14 +74,16 @@ export TF_CPP_MIN_LOG_LEVEL=3
extArgs=()
launchArgs=()
trainer_file="./scripts/train_network.py"
trainer_file="./scripts/stable/train_network.py"
if [ $model_type == "sd1.5" ]; then
ext_args+=("--clip_skip=$clip_skip")
elif [ $model_type == "sd2.0" ]; then
ext_args+=("--v2")
elif [ $model_type == "sdxl" ]; then
trainer_file="./scripts/sdxl_train_network.py"
trainer_file="./scripts/stable/sdxl_train_network.py"
elif [ $model_type == "flux" ]; then
trainer_file="./scripts/dev/flux_train_network.py"
fi
if [[ $multi_gpu == 1 ]]; then