lora-scripts/train.ipynb

100 lines
3.7 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# --- Model and training images ---\n",
"# Base model path.\n",
"pretrained_model = \"./sd-models/model.ckpt\"\n",
"# Training dataset path.\n",
"train_data_dir = \"./train/aki\"\n",
"\n",
"# --- Training parameters ---\n",
"# Image resolution as \"w,h\". Non-square is supported, but each side must be a multiple of 64.\n",
"resolution = \"512,512\"\n",
"# Batch size.\n",
"batch_size = 1\n",
"# Maximum number of training epochs.\n",
"max_train_epoches = 10\n",
"# Save every N epochs.\n",
"save_every_n_epochs = 2\n",
"# Network dim. Commonly 4~128; bigger is not always better.\n",
"network_dim = 32\n",
"# Network alpha. Commonly the same value as network_dim, or a smaller one (e.g. half of\n",
"# network_dim) to prevent underflow. The default is 1; a smaller alpha needs a higher learning rate.\n",
"network_alpha = 32\n",
"# Clip skip. No hard rule; 2 is the usual choice.\n",
"clip_skip = 2\n",
"# Train the U-Net only. Enabling this trades quality for much lower VRAM use; can be enabled with 6 GB of VRAM.\n",
"train_unet_only = 0\n",
"# Train the text encoder only.\n",
"train_text_encoder_only = 0\n",
"\n",
"# --- Learning rate ---\n",
"lr = \"1e-4\"\n",
"unet_lr = \"1e-4\"\n",
"text_encoder_lr = \"1e-5\"\n",
"# Listed choices: \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n",
"lr_scheduler = \"cosine_with_restarts\"\n",
"\n",
"# --- Output settings ---\n",
"# Name under which the trained model is saved.\n",
"output_name = \"aki\"\n",
"# Model save format: ckpt, pt, safetensors.\n",
"save_model_as = \"safetensors\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Turn the train_unet_only / train_text_encoder_only switches from the config cell into\n",
"# optional command-line flags; otherwise they are never passed to the trainer.\n",
"ext_flags = []\n",
"if train_unet_only:\n",
"    ext_flags.append(\"--network_train_unet_only\")\n",
"if train_text_encoder_only:\n",
"    ext_flags.append(\"--network_train_text_encoder_only\")\n",
"ext_args = \" \".join(ext_flags)  # empty string when neither switch is set\n",
"\n",
"# Launch training. IPython replaces each $name below with the Python variable of that name.\n",
"# The two path arguments are quoted so that paths containing spaces stay a single argument.\n",
"!accelerate launch --num_cpu_threads_per_process=8 \"./scripts/train_network.py\" \\\n",
"  --enable_bucket \\\n",
"  --pretrained_model_name_or_path=\"$pretrained_model\" \\\n",
"  --train_data_dir=\"$train_data_dir\" \\\n",
"  --output_dir=\"./output\" \\\n",
"  --logging_dir=\"./logs\" \\\n",
"  --resolution=$resolution \\\n",
"  --network_module=networks.lora \\\n",
"  --max_train_epochs=$max_train_epoches \\\n",
"  --learning_rate=$lr \\\n",
"  --unet_lr=$unet_lr \\\n",
"  --text_encoder_lr=$text_encoder_lr \\\n",
"  --network_dim=$network_dim \\\n",
"  --network_alpha=$network_alpha \\\n",
"  --output_name=$output_name \\\n",
"  --lr_scheduler=$lr_scheduler \\\n",
"  --train_batch_size=$batch_size \\\n",
"  --save_every_n_epochs=$save_every_n_epochs \\\n",
"  --mixed_precision=\"fp16\" \\\n",
"  --save_precision=\"fp16\" \\\n",
"  --seed=\"1337\" \\\n",
"  --cache_latents \\\n",
"  --clip_skip=$clip_skip \\\n",
"  --prior_loss_weight=1 \\\n",
"  --max_token_length=225 \\\n",
"  --caption_extension=\".txt\" \\\n",
"  --save_model_as=$save_model_as \\\n",
"  --xformers --shuffle_caption --use_8bit_adam $ext_args"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.7 (tags/v3.10.7:6cc6b13, Sep 5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "675b13e958f0d0236d13cdfe08a1df3882cae564fa23a2e7e5eb1f2c6c632b02"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}