update sd-scripts
parent
33a5ff3f6c
commit
8b86e38cb0
|
|
@ -0,0 +1 @@
|
|||
0778dd9b1df0d6aa33287ded3ce4195f3d03251b
|
||||
|
|
@ -36,6 +36,8 @@ Python 3.10.6およびGitが必要です。
|
|||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
|
||||
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
|
||||
|
||||
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
|
||||
(venvに限らずスクリプトの実行が可能になりますので注意してください。)
|
||||
|
||||
|
|
@ -45,7 +47,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
|
|||
|
||||
## Windows環境でのインストール
|
||||
|
||||
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。
|
||||
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
|
||||
|
||||
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
||||
|
||||
|
|
@ -67,10 +69,12 @@ accelerate config
|
|||
|
||||
コマンドプロンプトでも同一です。
|
||||
|
||||
注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
||||
注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
||||
|
||||
この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
|
||||
|
||||
PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
|
||||
|
||||
accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
|
||||
|
||||
```txt
|
||||
|
|
|
|||
|
|
@ -14,6 +14,13 @@ The command to install PyTorch is as follows:
|
|||
|
||||
### Recent Updates
|
||||
|
||||
Jan 25, 2025:
|
||||
|
||||
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
|
||||
- For details on how to set it up, please refer to the PR. The documentation will be updated as needed.
|
||||
- It will be added to other scripts as well.
|
||||
- As a current limitation, validation loss is not supported when `--block_to_swap` is specified, or when schedule-free optimizer is used.
|
||||
|
||||
Dec 15, 2024:
|
||||
|
||||
- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
|
||||
|
|
@ -754,7 +761,7 @@ This repository contains the scripts for:
|
|||
|
||||
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
|
||||
|
||||
The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work.
|
||||
The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers.
|
||||
|
||||
## Links to usage documentation
|
||||
|
||||
|
|
@ -781,6 +788,8 @@ Python 3.10.6 and Git:
|
|||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
|
||||
Python 3.10.x, 3.11.x, and 3.12.x will work but not tested.
|
||||
|
||||
Give unrestricted script access to powershell so venv can work:
|
||||
|
||||
- Open an administrator powershell window
|
||||
|
|
@ -807,10 +816,12 @@ accelerate config
|
|||
|
||||
If `python -m venv` shows only `python`, change `python` to `py`.
|
||||
|
||||
__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
||||
Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
||||
|
||||
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
|
||||
|
||||
If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version.
|
||||
|
||||
<!--
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
|
|
@ -871,12 +882,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||
|
||||
## Change History
|
||||
|
||||
### Working in progress
|
||||
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
||||
|
||||
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
|
||||
- If you encounter any issues, please report them.
|
||||
|
||||
- The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
|
||||
- The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
|
||||
- The following changes are included.
|
||||
|
||||
#### Changes
|
||||
|
||||
- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
|
||||
- Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
|
||||
|
||||
|
|
@ -917,7 +934,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
||||
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
||||
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1393
|
||||
- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
|
||||
|
||||
- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
|
||||
|
|
@ -987,6 +1003,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
|
|||
|
||||
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
|
||||
|
||||
#### 変更点
|
||||
|
||||
- devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきます。
|
||||
- マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください。
|
||||
- 以下の変更が含まれます。
|
||||
|
||||
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
|
||||
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
|
||||
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。
|
||||
|
|
|
|||
|
|
@ -91,9 +91,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -138,9 +138,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -126,9 +126,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import argparse
|
|||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
|
@ -36,8 +36,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
|
|
@ -80,6 +80,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||
args.blocks_to_swap = 18 # 18 is safe for most cases
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
|
@ -339,6 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
|
|
@ -375,7 +378,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
|
|
@ -420,7 +423,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet):
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
"""
|
||||
|
||||
return model_pred
|
||||
|
|
|
|||
|
|
@ -73,6 +73,8 @@ class BaseSubsetParams:
|
|||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -102,6 +104,8 @@ class BaseDatasetParams:
|
|||
resolution: Optional[Tuple[int, int]] = None
|
||||
network_multiplier: float = 1.0
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
|
|||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineTuningDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
|
|
@ -234,6 +237,8 @@ class ConfigSanitizer:
|
|||
"enable_bucket": bool,
|
||||
"max_bucket_reso": int,
|
||||
"min_bucket_reso": int,
|
||||
"validation_seed": int,
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
}
|
||||
|
|
@ -462,119 +467,136 @@ class BlueprintGenerator:
|
|||
|
||||
return default_value
|
||||
|
||||
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
|
||||
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
extra_dataset_params = {}
|
||||
|
||||
if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
# DreamBooth datasets support splitting training and validation datasets
|
||||
extra_dataset_params = {"is_training_dataset": True}
|
||||
else:
|
||||
subset_klass = FineTuningSubset
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
|
||||
datasets.append(dataset)
|
||||
|
||||
# print info
|
||||
info = ""
|
||||
for i, dataset in enumerate(datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(
|
||||
f"""\
|
||||
[Dataset {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
network_multiplier: {dataset.network_multiplier}
|
||||
"""
|
||||
)
|
||||
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
|
||||
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
|
||||
continue
|
||||
|
||||
if dataset.enable_bucket:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
min_bucket_reso: {dataset.min_bucket_reso}
|
||||
max_bucket_reso: {dataset.max_bucket_reso}
|
||||
bucket_reso_steps: {dataset.bucket_reso_steps}
|
||||
bucket_no_upscale: {dataset.bucket_no_upscale}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
# if the dataset isn't setting a validation split, there is no current validation dataset
|
||||
if dataset_blueprint.params.validation_split == 0.0:
|
||||
continue
|
||||
|
||||
extra_dataset_params = {}
|
||||
if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
# DreamBooth datasets support splitting training and validation datasets
|
||||
extra_dataset_params = {"is_training_dataset": False}
|
||||
else:
|
||||
info += "\n"
|
||||
subset_klass = FineTuningSubset
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
[Subset {j} of Dataset {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
caption_separator: {subset.caption_separator}
|
||||
secondary_separator: {subset.secondary_separator}
|
||||
enable_wildcard: {subset.enable_wildcard}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
color_aug: {subset.color_aug}
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min}
|
||||
token_warmup_step: {subset.token_warmup_step}
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
|
||||
val_datasets.append(dataset)
|
||||
|
||||
if is_dreambooth:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
is_reg: {subset.is_reg}
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
elif not is_controlnet:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
def print_info(_datasets, dataset_type: str):
|
||||
info = ""
|
||||
for i, dataset in enumerate(_datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(f"""\
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
logger.info(f"{info}")
|
||||
if dataset.enable_bucket:
|
||||
info += indent(dedent(f"""\
|
||||
min_bucket_reso: {dataset.min_bucket_reso}
|
||||
max_bucket_reso: {dataset.max_bucket_reso}
|
||||
bucket_reso_steps: {dataset.bucket_reso_steps}
|
||||
bucket_no_upscale: {dataset.bucket_no_upscale}
|
||||
\n"""), " ")
|
||||
else:
|
||||
info += "\n"
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(dedent(f"""\
|
||||
[Subset {j} of {dataset_type} {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
color_aug: {subset.color_aug}
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
info += indent(dedent(f"""\
|
||||
is_reg: {subset.is_reg}
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""), " ")
|
||||
elif not is_controlnet:
|
||||
info += indent(dedent(f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""), " ")
|
||||
|
||||
logger.info(info)
|
||||
|
||||
print_info(datasets, "Dataset")
|
||||
|
||||
if len(val_datasets) > 0:
|
||||
print_info(val_datasets, "Validation Dataset")
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
logger.info(f"[Prepare dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return DatasetGroup(datasets)
|
||||
for i, dataset in enumerate(val_datasets):
|
||||
logger.info(f"[Prepare validation dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return (
|
||||
DatasetGroup(datasets),
|
||||
DatasetGroup(val_datasets) if val_datasets else None
|
||||
)
|
||||
|
||||
|
||||
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
|
||||
|
|
@ -63,7 +65,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
|||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
|
|
@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
|||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
def get_snr_scale(timesteps, noise_scheduler):
|
||||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
|
|
@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler):
|
|||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
||||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
if v_prediction:
|
||||
|
|
@ -453,7 +455,7 @@ def get_weighted_text_embeddings(
|
|||
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
|
|
@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
|||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
|
||||
if noise_offset is None:
|
||||
return noise
|
||||
if adaptive_noise_scale is not None:
|
||||
|
|
@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
|||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
if "conditioning_images" in batch:
|
||||
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
|||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pathlib
|
|||
import re
|
||||
import shutil
|
||||
import time
|
||||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
|
@ -145,6 +146,37 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
|||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
|
||||
Shuffle the dataset based on the validation_seed or the current random seed.
|
||||
For example if the split of 0.2 of 100 images.
|
||||
[0:80] = 80 training images
|
||||
[80:] = 20 validation images
|
||||
"""
|
||||
if validation_seed is not None:
|
||||
logging.info(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
random.seed(validation_seed)
|
||||
random.shuffle(paths)
|
||||
random.setstate(prevstate)
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
# Split the dataset between training and validation
|
||||
if is_training_dataset:
|
||||
# Training dataset we split to the first part
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
# Validation dataset we split to the second part
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
|
|
@ -397,6 +429,8 @@ class BaseSubset:
|
|||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
|
|
@ -424,6 +458,9 @@ class BaseSubset:
|
|||
|
||||
self.img_count = 0
|
||||
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
|
|
@ -453,6 +490,8 @@ class DreamBoothSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
|
|
@ -478,6 +517,8 @@ class DreamBoothSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
|
|
@ -518,6 +559,8 @@ class FineTuningSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
|
|
@ -543,6 +586,8 @@ class FineTuningSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
|
|
@ -579,6 +624,8 @@ class ControlNetSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
|
|
@ -604,6 +651,8 @@ class ControlNetSubset(BaseSubset):
|
|||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
|
|
@ -1786,9 +1835,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# if is_training_dataset is True -> training dataset
|
||||
# if is_training_dataset is False -> validation dataset
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
is_training_dataset: bool,
|
||||
batch_size: int,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
|
|
@ -1799,6 +1852,8 @@ class DreamBoothDataset(BaseDataset):
|
|||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
|
|
@ -1808,6 +1863,9 @@ class DreamBoothDataset(BaseDataset):
|
|||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
self.latents_cache = None
|
||||
self.is_training_dataset = is_training_dataset
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
|
|
@ -1915,6 +1973,30 @@ class DreamBoothDataset(BaseDataset):
|
|||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
#
|
||||
# We split the dataset for the subset based on if we are doing a validation split
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# if self.is_training_dataset is True -> training dataset
|
||||
# if self.is_training_dataset is False -> validation dataset
|
||||
if self.validation_split > 0.0:
|
||||
# For regularization images we do not want to split this dataset.
|
||||
if subset.is_reg is True:
|
||||
# Skip any validation dataset for regularization images
|
||||
if self.is_training_dataset is False:
|
||||
img_paths = []
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths = split_train_val(
|
||||
img_paths,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
|
|
@ -1973,9 +2055,10 @@ class DreamBoothDataset(BaseDataset):
|
|||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
num_repeats = subset.num_repeats if self.is_training_dataset else 1
|
||||
if num_repeats < 1:
|
||||
logger.warning(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -1993,12 +2076,12 @@ class DreamBoothDataset(BaseDataset):
|
|||
continue
|
||||
|
||||
if subset.is_reg:
|
||||
num_reg_images += subset.num_repeats * len(img_paths)
|
||||
num_reg_images += num_repeats * len(img_paths)
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
|
|
@ -2009,10 +2092,12 @@ class DreamBoothDataset(BaseDataset):
|
|||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
logger.info(f"{num_train_images} train images with repeating.")
|
||||
images_split_name = "train" if self.is_training_dataset else "validation"
|
||||
logger.info(f"{num_train_images} {images_split_name} images with repeats.")
|
||||
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
logger.info(f"{num_reg_images} reg images.")
|
||||
logger.info(f"{num_reg_images} reg images with repeats.")
|
||||
if num_train_images < num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
|
|
@ -2050,6 +2135,8 @@ class FineTuningDataset(BaseDataset):
|
|||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
|
|
@ -2275,7 +2362,9 @@ class ControlNetDataset(BaseDataset):
|
|||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: float,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
|
|
@ -2314,6 +2403,7 @@ class ControlNetDataset(BaseDataset):
|
|||
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
True,
|
||||
batch_size,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
|
|
@ -2324,13 +2414,17 @@ class ControlNetDataset(BaseDataset):
|
|||
bucket_no_upscale,
|
||||
1.0,
|
||||
debug_dataset,
|
||||
validation_split,
|
||||
validation_seed,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||
self.batch_size = batch_size
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
|
|
@ -2800,6 +2894,9 @@ class MinimalDataset(BaseDataset):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_resolutions(self) -> List[Tuple[int, int]]:
|
||||
return []
|
||||
|
||||
|
||||
def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
|
||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||
|
|
@ -4544,7 +4641,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
|||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = os.path.splitext(args.config_file)[0]
|
||||
logger.info(args.config_file)
|
||||
|
||||
return args
|
||||
|
||||
|
|
@ -4887,7 +4983,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
|||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
|
||||
|
||||
if optimizer_type == "RAdamScheduleFree".lower():
|
||||
optimizer_class = sf.RAdamScheduleFree
|
||||
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")
|
||||
|
|
@ -5838,13 +5934,13 @@ def save_sd_model_on_train_end_common(
|
|||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
def get_timesteps(min_timestep, max_timestep, b_size, device):
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
timesteps = timesteps.long().to(device)
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
|
|
@ -5905,11 +6001,16 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
|
|||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
|
||||
):
|
||||
"""
|
||||
NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
|
||||
"""
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
|
|
@ -5917,6 +6018,8 @@ def conditional_loss(
|
|||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif loss_type == "smooth_l1":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
|
|
@ -6329,6 +6432,30 @@ def sample_image_inference(
|
|||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
"""
|
||||
Initialize experiment trackers with tracker specific behaviors
|
||||
"""
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
default_tracker_name if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||
|
||||
# Define specific metrics to handle validation and epochs "steps"
|
||||
wandb_tracker.define_metric("epoch", hidden=True)
|
||||
wandb_tracker.define_metric("val_step", hidden=True)
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
|
@ -6397,4 +6524,7 @@ class LossRecorder:
|
|||
|
||||
@property
|
||||
def moving_average(self) -> float:
|
||||
return self.loss_total / len(self.loss_list)
|
||||
losses = len(self.loss_list)
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ voluptuous==0.13.1
|
|||
huggingface-hub==0.24.5
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
numpy<=2.0
|
||||
# for BLIP captioning
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
|
|
|
|||
|
|
@ -149,9 +149,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import argparse
|
|||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
|
@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
|
|
@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
self.resolutions = train_dataset_group.get_resolutions()
|
||||
resolutions = train_dataset_group.get_resolutions()
|
||||
if val_dataset_group is not None:
|
||||
resolutions = resolutions + val_dataset_group.get_resolutions()
|
||||
self.resolutions = resolutions
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
|
@ -312,6 +317,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
|
|
@ -339,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||
t5_attn_mask = None
|
||||
|
||||
# call model
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# TODO support attention mask
|
||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||
|
||||
|
|
|
|||
|
|
@ -176,9 +176,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import argparse
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
|
@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
|
|
@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import regex
|
||||
|
||||
|
|
@ -18,11 +19,13 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
|
|
|
|||
|
|
@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
|
|
|
|||
|
|
@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -89,9 +89,10 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -2,17 +2,19 @@ import importlib
|
|||
import argparse
|
||||
import math
|
||||
import os
|
||||
import typing
|
||||
from typing import Any, List, Union, Optional
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from torch.types import Number
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
|
@ -20,6 +22,7 @@ init_ipex()
|
|||
from accelerate.utils import set_seed
|
||||
from accelerate import Accelerator
|
||||
from diffusers import DDPMScheduler
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
|
||||
import library.train_util as train_util
|
||||
|
|
@ -114,15 +117,17 @@ class NetworkTrainer:
|
|||
)
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
):
|
||||
):
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
|
@ -196,10 +201,10 @@ class NetworkTrainer:
|
|||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(images).latent_dist.sample()
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return latents * self.vae_scale_factor
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
|
|
@ -214,6 +219,7 @@ class NetworkTrainer:
|
|||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
|
|
@ -227,7 +233,7 @@ class NetworkTrainer:
|
|||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
|
|
@ -271,7 +277,7 @@ class NetworkTrainer:
|
|||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
|
|
@ -308,6 +314,107 @@ class NetworkTrainer:
|
|||
|
||||
# endregion
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
text_encoder_conds = []
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
|
|
@ -373,10 +480,11 @@ class NetworkTrainer:
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
@ -384,8 +492,12 @@ class NetworkTrainer:
|
|||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(val_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
|
|
@ -397,8 +509,12 @@ class NetworkTrainer:
|
|||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
if val_dataset_group is not None:
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group) # may change some args
|
||||
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
|
|
@ -444,6 +560,8 @@ class NetworkTrainer:
|
|||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.new_cache_latents(vae, accelerator)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
|
@ -459,6 +577,8 @@ class NetworkTrainer:
|
|||
if text_encoder_outputs_caching_strategy is not None:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
|
|
@ -567,6 +687,8 @@ class NetworkTrainer:
|
|||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies()
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
|
@ -579,6 +701,15 @@ class NetworkTrainer:
|
|||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
|
|
@ -654,8 +785,8 @@ class NetworkTrainer:
|
|||
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = ds_model
|
||||
else:
|
||||
|
|
@ -676,8 +807,8 @@ class NetworkTrainer:
|
|||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, val_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = network
|
||||
|
||||
|
|
@ -769,6 +900,7 @@ class NetworkTrainer:
|
|||
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
|
|
@ -788,6 +920,7 @@ class NetworkTrainer:
|
|||
"ss_text_encoder_lr": text_encoder_lr,
|
||||
"ss_unet_lr": args.unet_lr,
|
||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||
"ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
|
||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||
"ss_num_epochs": num_train_epochs,
|
||||
|
|
@ -835,6 +968,11 @@ class NetworkTrainer:
|
|||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
"ss_validation_seed": args.validation_seed,
|
||||
"ss_validation_split": args.validation_split,
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
|
|
@ -1051,20 +1189,15 @@ class NetworkTrainer:
|
|||
|
||||
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
train_util.init_trackers(accelerator, args, "network_train")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
del val_dataset_group
|
||||
|
||||
# callback for step start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
|
|
@ -1109,10 +1242,17 @@ class NetworkTrainer:
|
|||
optimizer_eval_fn()
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
is_tracking = len(accelerator.trackers) > 0
|
||||
if is_tracking:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader))
|
||||
if args.max_validation_steps is not None
|
||||
else len(val_dataloader)
|
||||
)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
|
|
@ -1132,13 +1272,14 @@ class NetworkTrainer:
|
|||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
# TRAINING
|
||||
skipped_dataloader = None
|
||||
if initial_step > 0:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
|
|
@ -1156,98 +1297,24 @@ class NetworkTrainer:
|
|||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
|
||||
latents = latents.to(dtype=weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
# get multiplier for each sample
|
||||
if network_has_multiplier:
|
||||
multipliers = batch["network_multipliers"]
|
||||
# if all multipliers are same, use single multiplier
|
||||
if torch.all(multipliers == multipliers[0]):
|
||||
multipliers = multipliers[0].item()
|
||||
else:
|
||||
raise NotImplementedError("multipliers for each sample is not supported yet")
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
text_encoder_conds = []
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
|
|
@ -1302,19 +1369,148 @@ class NetworkTrainer:
|
|||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step != 0 # Skip first step
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
)
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps"
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_steps) + val_step
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
|
||||
logs = {
|
||||
"loss/validation/epoch_average": avr_loss,
|
||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||
"epoch": epoch + 1
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
|
|
@ -1496,9 +1692,36 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
parser.add_argument(
|
||||
"--validation_seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_validation_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import argparse
|
|||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional, Union
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
@ -99,9 +99,12 @@ class TextualInversionTrainer:
|
|||
self.vae_scale_factor = 0.18215
|
||||
self.is_sdxl = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
|
||||
|
|
@ -320,11 +323,12 @@ class TextualInversionTrainer:
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
self.assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ def train(args):
|
|||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ
|
|||
|
||||
[README in English](./README.md) ←更新情報はこちらにあります
|
||||
|
||||
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
|
||||
|
||||
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
|
||||
|
||||
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||
|
||||
以下のスクリプトがあります。
|
||||
|
|
@ -32,6 +36,8 @@ Python 3.10.6およびGitが必要です。
|
|||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
|
||||
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
|
||||
|
||||
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
|
||||
(venvに限らずスクリプトの実行が可能になりますので注意してください。)
|
||||
|
||||
|
|
@ -41,7 +47,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
|
|||
|
||||
## Windows環境でのインストール
|
||||
|
||||
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。
|
||||
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
|
||||
|
||||
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
||||
|
||||
|
|
@ -63,10 +69,12 @@ accelerate config
|
|||
|
||||
コマンドプロンプトでも同一です。
|
||||
|
||||
注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
||||
注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
||||
|
||||
この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
|
||||
|
||||
PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
|
||||
|
||||
accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
|
||||
|
||||
```txt
|
||||
|
|
|
|||
|
|
@ -5,6 +5,11 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
|
||||
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
||||
|
||||
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
|
||||
|
||||
|
||||
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
||||
|
||||
This repository contains the scripts for:
|
||||
|
|
@ -20,7 +25,7 @@ This repository contains the scripts for:
|
|||
|
||||
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
|
||||
|
||||
The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work.
|
||||
The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers.
|
||||
|
||||
## Links to usage documentation
|
||||
|
||||
|
|
@ -47,6 +52,8 @@ Python 3.10.6 and Git:
|
|||
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
||||
- git: https://git-scm.com/download/win
|
||||
|
||||
Python 3.10.x, 3.11.x, and 3.12.x will work but not tested.
|
||||
|
||||
Give unrestricted script access to powershell so venv can work:
|
||||
|
||||
- Open an administrator powershell window
|
||||
|
|
@ -73,10 +80,12 @@ accelerate config
|
|||
|
||||
If `python -m venv` shows only `python`, change `python` to `py`.
|
||||
|
||||
__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
||||
Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
||||
|
||||
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
|
||||
|
||||
If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version.
|
||||
|
||||
<!--
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
|
|
@ -137,6 +146,212 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||
|
||||
## Change History
|
||||
|
||||
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
||||
|
||||
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
|
||||
- If you encounter any issues, please report them.
|
||||
|
||||
- The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
|
||||
- The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
|
||||
- The following changes are included.
|
||||
|
||||
#### Changes
|
||||
|
||||
- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
|
||||
- Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
|
||||
|
||||
- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
|
||||
|
||||
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
|
||||
- There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
|
||||
|
||||
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
|
||||
|
||||
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
||||
|
||||
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
||||
1. Optimization of Calculation Order:
|
||||
- Changed the calculation order in the forward method from (Wx)R to W(xR).
|
||||
- This has improved computational efficiency and processing speed.
|
||||
2. Correction of Bias Application:
|
||||
- In the previous implementation, R was incorrectly applied to the bias.
|
||||
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
|
||||
3. Efficiency Enhancement in Matrix Operations:
|
||||
- Introduced einsum in both the forward and merge_to methods.
|
||||
- This has optimized matrix operations, resulting in further speed improvements.
|
||||
4. Proper Handling of Data Types:
|
||||
- Improved to use torch.float32 during calculations and convert results back to the original data type.
|
||||
- This maintains precision while ensuring compatibility with the original model.
|
||||
5. Unified Processing for Conv2d and Linear Layers:
|
||||
- Implemented a consistent method for applying OFT to both layer types.
|
||||
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
|
||||
|
||||
- Additional Information
|
||||
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
|
||||
|
||||
* Performance Improvement: Training speed has been improved by approximately 30%.
|
||||
|
||||
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
|
||||
|
||||
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
|
||||
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
||||
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
||||
|
||||
- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
|
||||
|
||||
- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
|
||||
|
||||
- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
|
||||
|
||||
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
|
||||
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
|
||||
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
|
||||
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
|
||||
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
|
||||
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
|
||||
- Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts.
|
||||
|
||||
- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
||||
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
|
||||
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
|
||||
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
|
||||
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
|
||||
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
|
||||
|
||||
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
|
||||
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
|
||||
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
|
||||
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
|
||||
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
|
||||
- `network_module` `networks.lora` and `networks.dylora` are available.
|
||||
|
||||
- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru!
|
||||
- The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file.
|
||||
- See [About masked loss](./docs/masked_loss_README.md) for details.
|
||||
|
||||
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
||||
- Specify the learning rate and dim (rank) for each block.
|
||||
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
|
||||
|
||||
- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath!
|
||||
- The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended.
|
||||
- When specifying from the command line, use `=` like `--learning_rate=-1e-7`.
|
||||
|
||||
- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
|
||||
- Some settings, such as API keys and directory specifications, are not output due to security issues.
|
||||
|
||||
- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
|
||||
|
||||
- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
|
||||
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
|
||||
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
|
||||
- If `--resume` is specified, the step saved in the state is used.
|
||||
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
|
||||
|
||||
- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
|
||||
- It seems that the model file loading is faster in the WSL environment etc.
|
||||
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
|
||||
|
||||
- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath!
|
||||
|
||||
- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
|
||||
|
||||
- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO!
|
||||
|
||||
- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th!
|
||||
|
||||
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
||||
|
||||
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
|
||||
|
||||
#### 変更点
|
||||
|
||||
- devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきます。
|
||||
- マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください。
|
||||
- 以下の変更が含まれます。
|
||||
|
||||
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
|
||||
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
|
||||
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
|
||||
- mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
|
||||
- バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
|
||||
- PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。
|
||||
- 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。
|
||||
|
||||
- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
||||
- Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。
|
||||
- `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。
|
||||
- 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。
|
||||
- `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
|
||||
- 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
|
||||
|
||||
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
|
||||
- LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
|
||||
- `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
|
||||
- `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
|
||||
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
|
||||
- `network_module` の `networks.lora` および `networks.dylora` で使用可能です。
|
||||
|
||||
- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。
|
||||
- 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。
|
||||
- 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。
|
||||
|
||||
- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
||||
- ブロックごとに学習率および dim (rank) を指定することができます。
|
||||
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
|
||||
|
||||
- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。
|
||||
- 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。
|
||||
- コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。
|
||||
|
||||
- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
|
||||
- API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。
|
||||
|
||||
- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
|
||||
|
||||
- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
|
||||
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
|
||||
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
|
||||
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
|
||||
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。
|
||||
|
||||
- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
|
||||
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
|
||||
- `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
|
||||
|
||||
- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。
|
||||
|
||||
- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
||||
https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。
|
||||
|
||||
- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。
|
||||
|
||||
- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。
|
||||
|
||||
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
||||
|
||||
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
|
||||
|
||||
|
||||
### Oct 27, 2024 / 2024-10-27:
|
||||
|
||||
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
|
||||
- This will be included in the next release.
|
||||
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
|
||||
- これは次回リリースに含まれます。
|
||||
|
||||
### Oct 26, 2024 / 2024-10-26:
|
||||
|
||||
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
|
||||
- It will be included in the next release.
|
||||
|
||||
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
|
||||
- 以上は次回リリースに含まれます。
|
||||
|
||||
### Sep 13, 2024 / 2024-09-13:
|
||||
|
||||
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
||||
|
||||
[default.extend-identifiers]
|
||||
ddPn08="ddPn08"
|
||||
|
||||
[default.extend-words]
|
||||
NIN="NIN"
|
||||
|
|
@ -27,6 +28,7 @@ rik="rik"
|
|||
koo="koo"
|
||||
yos="yos"
|
||||
wn="wn"
|
||||
hime="hime"
|
||||
|
||||
|
||||
[files]
|
||||
|
|
|
|||
|
|
@ -91,6 +91,8 @@ def train(args):
|
|||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
|
@ -310,7 +312,11 @@ def train(args):
|
|||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers(
|
||||
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
|
@ -354,7 +360,9 @@ def train(args):
|
|||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
|
|
@ -368,7 +376,9 @@ def train(args):
|
|||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
|
|
@ -376,11 +386,13 @@ def train(args):
|
|||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
|
|
@ -471,7 +483,7 @@ def train(args):
|
|||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import cv2
|
|||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torchvision import transforms
|
||||
|
|
@ -18,8 +19,10 @@ from torchvision import transforms
|
|||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
|
|
@ -89,7 +92,9 @@ def main(args):
|
|||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
assert (
|
||||
len(max_reso) == 2
|
||||
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(
|
||||
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
|
||||
|
|
@ -107,7 +112,7 @@ def main(args):
|
|||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
|
||||
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
|
||||
bucket.clear()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
|
|
@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
|
|
@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
"--bucket_no_upscale",
|
||||
action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help="use mixed precision / 混合精度を使う場合、その精度",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_path",
|
||||
|
|
@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
|
||||
"--flip_aug",
|
||||
action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_mask",
|
||||
type=str,
|
||||
default="",
|
||||
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from PIL import Image
|
|||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, pil_resize
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
|
@ -42,8 +42,10 @@ def preprocess_image(image):
|
|||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||
|
||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||
if size > IMAGE_SIZE:
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
|
||||
else:
|
||||
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
|
||||
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
|
|
@ -112,7 +114,6 @@ def main(args):
|
|||
|
||||
# モデルを読み込む
|
||||
if args.onnx:
|
||||
import torch
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|||
"""
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
logger.info("Enable memory efficient attention for U-Net")
|
||||
|
||||
|
|
@ -1435,6 +1436,7 @@ class BatchDataBase(NamedTuple):
|
|||
clip_prompt: str
|
||||
guide_image: Any
|
||||
raw_prompt: str
|
||||
file_name: Optional[str]
|
||||
|
||||
|
||||
class BatchDataExt(NamedTuple):
|
||||
|
|
@ -1493,8 +1495,6 @@ def main(args):
|
|||
highres_fix = args.highres_fix_scale is not None
|
||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
|
@ -2316,7 +2316,7 @@ def main(args):
|
|||
# このバッチの情報を取り出す
|
||||
(
|
||||
return_latents,
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
|
||||
(
|
||||
width,
|
||||
height,
|
||||
|
|
@ -2339,6 +2339,7 @@ def main(args):
|
|||
prompts = []
|
||||
negative_prompts = []
|
||||
raw_prompts = []
|
||||
filenames = []
|
||||
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
noises = [
|
||||
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
|
|
@ -2371,7 +2372,7 @@ def main(args):
|
|||
all_guide_images_are_same = True
|
||||
for i, (
|
||||
_,
|
||||
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
|
||||
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
|
||||
_,
|
||||
) in enumerate(batch):
|
||||
prompts.append(prompt)
|
||||
|
|
@ -2379,6 +2380,7 @@ def main(args):
|
|||
seeds.append(seed)
|
||||
clip_prompts.append(clip_prompt)
|
||||
raw_prompts.append(raw_prompt)
|
||||
filenames.append(filename)
|
||||
|
||||
if init_image is not None:
|
||||
init_images.append(init_image)
|
||||
|
|
@ -2478,8 +2480,8 @@ def main(args):
|
|||
# save image
|
||||
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
|
||||
):
|
||||
if highres_fix:
|
||||
seed -= 1 # record original seed
|
||||
|
|
@ -2505,17 +2507,23 @@ def main(args):
|
|||
metadata.add_text("crop-top", str(crop_top))
|
||||
metadata.add_text("crop-left", str(crop_left))
|
||||
|
||||
if args.use_original_file_name and init_images is not None:
|
||||
if type(init_images) is list:
|
||||
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
|
||||
else:
|
||||
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
|
||||
elif args.sequential_file_name:
|
||||
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
|
||||
if filename is not None:
|
||||
fln = filename
|
||||
else:
|
||||
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
|
||||
if args.use_original_file_name and init_images is not None:
|
||||
if type(init_images) is list:
|
||||
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
|
||||
else:
|
||||
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
|
||||
elif args.sequential_file_name:
|
||||
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
|
||||
else:
|
||||
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
|
||||
|
||||
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
|
||||
if fln.endswith(".webp"):
|
||||
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
|
||||
else:
|
||||
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
|
||||
|
||||
if not args.no_preview and not highres_1st and args.interactive:
|
||||
try:
|
||||
|
|
@ -2562,6 +2570,7 @@ def main(args):
|
|||
# repeat prompt
|
||||
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
|
||||
filename = None
|
||||
|
||||
if pi == 0 or len(raw_prompts) > 1:
|
||||
# parse prompt: if prompt is not changed, skip parsing
|
||||
|
|
@ -2783,6 +2792,12 @@ def main(args):
|
|||
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||
continue
|
||||
|
||||
m = re.match(r"f (.+)", parg, re.IGNORECASE)
|
||||
if m: # filename
|
||||
filename = m.group(1)
|
||||
logger.info(f"filename: {filename}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(f"{ex}")
|
||||
|
|
@ -2873,7 +2888,16 @@ def main(args):
|
|||
b1 = BatchData(
|
||||
False,
|
||||
BatchDataBase(
|
||||
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
|
||||
global_step,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
seed,
|
||||
init_image,
|
||||
mask_image,
|
||||
clip_prompt,
|
||||
guide_image,
|
||||
raw_prompt,
|
||||
filename,
|
||||
),
|
||||
BatchDataExt(
|
||||
width,
|
||||
|
|
@ -2916,7 +2940,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2216,8 +2216,6 @@ def main(args):
|
|||
highres_fix = args.highres_fix_scale is not None
|
||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
import math
|
||||
import torch
|
||||
from transformers import Adafactor
|
||||
|
||||
@torch.no_grad()
|
||||
def adafactor_step_param(self, p, group):
|
||||
if p.grad is None:
|
||||
return
|
||||
grad = p.grad
|
||||
if grad.dtype in {torch.float16, torch.bfloat16}:
|
||||
grad = grad.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adafactor does not support sparse gradients.")
|
||||
|
||||
state = self.state[p]
|
||||
grad_shape = grad.shape
|
||||
|
||||
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
||||
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
state["RMS"] = 0
|
||||
else:
|
||||
if use_first_moment:
|
||||
state["exp_avg"] = state["exp_avg"].to(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
||||
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
||||
|
||||
p_data_fp32 = p
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p_data_fp32 = p_data_fp32.float()
|
||||
|
||||
state["step"] += 1
|
||||
state["RMS"] = Adafactor._rms(p_data_fp32)
|
||||
lr = Adafactor._get_lr(group, state)
|
||||
|
||||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
||||
update = (grad ** 2) + group["eps"][0]
|
||||
if factored:
|
||||
exp_avg_sq_row = state["exp_avg_sq_row"]
|
||||
exp_avg_sq_col = state["exp_avg_sq_col"]
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
||||
update.mul_(lr)
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
||||
update = exp_avg
|
||||
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p.copy_(p_data_fp32)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def adafactor_step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
adafactor_step_param(self, p, group)
|
||||
|
||||
return loss
|
||||
|
||||
def patch_adafactor_fused(optimizer: Adafactor):
|
||||
optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
||||
optimizer.step = adafactor_step.__get__(optimizer)
|
||||
|
|
@ -86,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
|
|||
class_tokens: Optional[str] = None
|
||||
caption_extension: str = ".caption"
|
||||
cache_info: bool = False
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineTuningSubsetParams(BaseSubsetParams):
|
||||
metadata_file: Optional[str] = None
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -191,6 +193,7 @@ class ConfigSanitizer:
|
|||
"keep_tokens": int,
|
||||
"keep_tokens_separator": str,
|
||||
"secondary_separator": str,
|
||||
"caption_separator": str,
|
||||
"enable_wildcard": bool,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Any(float, int),
|
||||
|
|
@ -212,11 +215,13 @@ class ConfigSanitizer:
|
|||
DB_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("image_dir"): str,
|
||||
"is_reg": bool,
|
||||
"alpha_mask": bool,
|
||||
}
|
||||
# FT means FineTuning
|
||||
FT_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("metadata_file"): str,
|
||||
"image_dir": str,
|
||||
"alpha_mask": bool,
|
||||
}
|
||||
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
"caption_extension": str,
|
||||
|
|
@ -523,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
caption_separator: {subset.caption_separator}
|
||||
secondary_separator: {subset.secondary_separator}
|
||||
enable_wildcard: {subset.enable_wildcard}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
|
|
@ -536,6 +542,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask},
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
|
|
|
|||
|
|
@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
|||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
if v_prediction:
|
||||
weight = 1 / (snr_t + 1)
|
||||
else:
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
|
|
@ -480,12 +483,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
|||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
if "conditioning_images" in batch:
|
||||
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
# print(f"conditioning_image: {mask_image.shape}")
|
||||
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
||||
# alpha mask is 0 to 1
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
||||
else:
|
||||
return loss
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from functools import cache
|
|||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
|
||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
||||
|
||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import safetensors
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
|
@ -8,8 +9,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionMode
|
|||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VAE_SCALE_FACTOR = 0.13025
|
||||
|
|
@ -163,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
|||
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
|
||||
# model_version is reserved for future use
|
||||
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
||||
|
||||
# Load the state dict
|
||||
if model_util.is_safetensors(ckpt_path):
|
||||
checkpoint = None
|
||||
try:
|
||||
state_dict = load_file(ckpt_path, device=map_location)
|
||||
except:
|
||||
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
||||
if disable_mmap:
|
||||
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
|
||||
else:
|
||||
try:
|
||||
state_dict = load_file(ckpt_path, device=map_location)
|
||||
except:
|
||||
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
||||
epoch = None
|
||||
global_step = None
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
|
@ -13,8 +14,10 @@ from transformers import CLIPTokenizer
|
|||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
|
|
@ -44,6 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
model_dtype,
|
||||
args.disable_mmap_load_safetensors,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
|
|
@ -60,7 +64,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||
|
||||
|
||||
def _load_target_model(
|
||||
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
|
||||
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
|
||||
):
|
||||
# model_dtype only work with full fp16/bf16
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
|
|
@ -75,7 +79,7 @@ def _load_target_model(
|
|||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
|
@ -332,6 +336,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
|||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_mmap_load_safetensors",
|
||||
action="store_true",
|
||||
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,6 +7,9 @@ from typing import *
|
|||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
|
|
@ -79,6 +82,24 @@ def setup_logging(args=None, log_level=None, reset=False):
|
|||
logger.info(msg_init)
|
||||
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
if has_alpha:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
||||
else:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
resized_pil = pil_image.resize(size, interpolation)
|
||||
|
||||
# Convert back to cv2 format
|
||||
if has_alpha:
|
||||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
|
||||
else:
|
||||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
|
||||
|
||||
return resized_cv2
|
||||
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def main(file):
|
|||
|
||||
keys = list(sd.keys())
|
||||
for key in keys:
|
||||
if "lora_up" in key or "lora_down" in key:
|
||||
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
|
||||
values.append((key, sd[key]))
|
||||
print(f"number of LoRA modules: {len(values)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ from typing import Optional, List, Type
|
|||
import torch
|
||||
from library import sdxl_original_unet
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
|
|
@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
|
|||
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
||||
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
self.cond_image = cond_image
|
||||
self.cond_emb = None
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return super().forward(x)
|
||||
|
||||
if self.cond_emb is None:
|
||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
||||
cx = self.cond_emb
|
||||
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
|
||||
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
|
|
@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
|
|||
if not self.enabled:
|
||||
return super().forward(x)
|
||||
|
||||
if self.cond_emb is None:
|
||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
||||
cx = self.cond_emb
|
||||
cx = self.lllite_conditioning1(self.cond_image)
|
||||
|
||||
cx = torch.cat([cx, self.down(x)], dim=1)
|
||||
cx = self.mid(cx)
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@ from transformers import CLIPTextModel
|
|||
import torch
|
||||
from torch import nn
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DyLoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
|
|
@ -195,7 +198,7 @@ def create_network(
|
|||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
|
||||
if unit is not None:
|
||||
unit = int(unit)
|
||||
else:
|
||||
|
|
@ -211,6 +214,16 @@ def create_network(
|
|||
unit=unit,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
|
|
@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
|
|||
self.alpha = alpha
|
||||
self.apply_to_conv = apply_to_conv
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info("create LoRA network from weights")
|
||||
else:
|
||||
|
|
@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module):
|
|||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
|
||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||
|
||||
|
||||
self.text_encoder_loras = []
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if len(text_encoders) > 1:
|
||||
|
|
@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module):
|
|||
else:
|
||||
index = None
|
||||
logger.info("create LoRA for Text Encoder")
|
||||
|
||||
|
||||
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
|
||||
|
|
@ -346,6 +363,14 @@ class DyLoRANetwork(torch.nn.Module):
|
|||
self.unet_loras = create_modules(True, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
|
|
@ -406,27 +431,53 @@ class DyLoRANetwork(torch.nn.Module):
|
|||
logger.info(f"weights are merged")
|
||||
"""
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_B" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
return all_params
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
|
@ -385,14 +386,14 @@ class LoRAInfModule(LoRAModule):
|
|||
return out
|
||||
|
||||
|
||||
def parse_block_lr_kwargs(nw_kwargs):
|
||||
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
|
||||
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
||||
|
||||
# 以上のいずれにも設定がない場合は無効としてNoneを返す
|
||||
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
||||
return None, None, None
|
||||
return None
|
||||
|
||||
# extract learning rate weight for each block
|
||||
if down_lr_weight is not None:
|
||||
|
|
@ -401,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
|
|||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
return get_block_lr_weight(
|
||||
is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
|
|
@ -424,6 +423,9 @@ def create_network(
|
|||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
||||
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
|
||||
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
|
|
@ -441,21 +443,21 @@ def create_network(
|
|||
|
||||
# block dim/alpha/lr
|
||||
block_dims = kwargs.get("block_dims", None)
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
||||
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
|
||||
|
||||
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
||||
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
||||
if block_dims is not None or block_lr_weight is not None:
|
||||
block_alphas = kwargs.get("block_alphas", None)
|
||||
conv_block_dims = kwargs.get("conv_block_dims", None)
|
||||
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
||||
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
)
|
||||
|
||||
# remove block dim/alpha without learning rate
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
@ -488,10 +490,20 @@ def create_network(
|
|||
conv_block_dims=conv_block_dims,
|
||||
conv_block_alphas=conv_block_alphas,
|
||||
varbose=True,
|
||||
is_sdxl=is_sdxl,
|
||||
)
|
||||
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||
|
||||
if block_lr_weight is not None:
|
||||
network.set_block_lr_weight(block_lr_weight)
|
||||
|
||||
return network
|
||||
|
||||
|
|
@ -501,9 +513,13 @@ def create_network(
|
|||
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
|
||||
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
|
||||
def get_block_dims_and_alphas(
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
):
|
||||
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
||||
if not is_sdxl:
|
||||
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
|
||||
else:
|
||||
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
|
||||
num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
|
||||
|
||||
def parse_ints(s):
|
||||
return [int(i) for i in s.split(",")]
|
||||
|
|
@ -514,9 +530,10 @@ def get_block_dims_and_alphas(
|
|||
# block_dimsとblock_alphasをパースする。必ず値が入る
|
||||
if block_dims is not None:
|
||||
block_dims = parse_ints(block_dims)
|
||||
assert (
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
assert len(block_dims) == num_total_blocks, (
|
||||
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
|
||||
+ f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
|
|
@ -567,15 +584,25 @@ def get_block_dims_and_alphas(
|
|||
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく
|
||||
# 戻り値は block ごとの倍率のリスト
|
||||
def get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
||||
) -> Tuple[List[float], List[float], List[float]]:
|
||||
is_sdxl,
|
||||
down_lr_weight: Union[str, List[float]],
|
||||
mid_lr_weight: List[float],
|
||||
up_lr_weight: Union[str, List[float]],
|
||||
zero_threshold: float,
|
||||
) -> Optional[List[float]]:
|
||||
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
||||
return None, None, None
|
||||
return None
|
||||
|
||||
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
||||
if not is_sdxl:
|
||||
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
|
||||
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
|
||||
else:
|
||||
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
|
||||
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
|
||||
|
||||
def get_list(name_with_suffix) -> List[float]:
|
||||
import math
|
||||
|
|
@ -585,15 +612,18 @@ def get_block_lr_weight(
|
|||
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
||||
|
||||
if name == "cosine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
||||
return [
|
||||
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
|
||||
for i in reversed(range(max_len_for_down_or_up))
|
||||
]
|
||||
elif name == "sine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
||||
return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
|
||||
elif name == "linear":
|
||||
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
||||
return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
|
||||
elif name == "reverse_linear":
|
||||
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
||||
return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
|
||||
elif name == "zeros":
|
||||
return [0.0 + base_lr] * max_len
|
||||
return [0.0 + base_lr] * max_len_for_down_or_up
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
|
|
@ -606,20 +636,36 @@ def get_block_lr_weight(
|
|||
if type(up_lr_weight) == str:
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
|
||||
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
|
||||
):
|
||||
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
|
||||
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
|
||||
up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
|
||||
down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
|
||||
logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
|
||||
logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
|
||||
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
|
||||
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
|
||||
):
|
||||
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
|
||||
logger.warning(
|
||||
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
|
||||
)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
|
||||
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
|
||||
|
||||
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
|
||||
logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
|
||||
logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
|
||||
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
||||
|
|
@ -627,78 +673,139 @@ def get_block_lr_weight(
|
|||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
||||
else:
|
||||
down_lr_weight = [1.0] * max_len_for_down_or_up
|
||||
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
|
||||
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
||||
else:
|
||||
logger.info("mid_lr_weight: 1.0")
|
||||
mid_lr_weight = [1.0] * max_len_for_mid
|
||||
logger.info("mid_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
||||
else:
|
||||
up_lr_weight = [1.0] * max_len_for_down_or_up
|
||||
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
|
||||
|
||||
if is_sdxl:
|
||||
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
|
||||
|
||||
assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
|
||||
is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
|
||||
), f"lr_weight length is invalid: {len(lr_weight)}"
|
||||
|
||||
return lr_weight
|
||||
|
||||
|
||||
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
|
||||
def remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
|
||||
):
|
||||
# set 0 to block dim without learning rate to remove the block
|
||||
if down_lr_weight != None:
|
||||
for i, lr in enumerate(down_lr_weight):
|
||||
if block_lr_weight is not None:
|
||||
for i, lr in enumerate(block_lr_weight):
|
||||
if lr == 0:
|
||||
block_dims[i] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[i] = 0
|
||||
if mid_lr_weight != None:
|
||||
if mid_lr_weight == 0:
|
||||
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
||||
if up_lr_weight != None:
|
||||
for i, lr in enumerate(up_lr_weight):
|
||||
if lr == 0:
|
||||
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
||||
|
||||
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# 外部から呼び出す可能性を考慮しておく
|
||||
def get_block_index(lora_name: str) -> int:
|
||||
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
||||
block_idx = -1 # invalid lora name
|
||||
if not is_sdxl:
|
||||
m = RE_UPDOWN.search(lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
i = int(g[1])
|
||||
j = int(g[3])
|
||||
if g[2] == "resnets":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "attentions":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
||||
idx = 3 * i + 2
|
||||
|
||||
m = RE_UPDOWN.search(lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
i = int(g[1])
|
||||
j = int(g[3])
|
||||
if g[2] == "resnets":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "attentions":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
||||
idx = 3 * i + 2
|
||||
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
||||
elif g[0] == "up":
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
||||
|
||||
elif "mid_block_" in lora_name:
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
||||
elif g[0] == "up":
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
||||
elif "mid_block_" in lora_name:
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
||||
else:
|
||||
# copy from sdxl_train
|
||||
if lora_name.startswith("lora_unet_"):
|
||||
name = lora_name[len("lora_unet_") :]
|
||||
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
|
||||
block_idx = 0 # 0
|
||||
elif name.startswith("input_blocks_"): # 1-9
|
||||
block_idx = 1 + int(name.split("_")[2])
|
||||
elif name.startswith("middle_block_"): # 10-12
|
||||
block_idx = 10 + int(name.split("_")[2])
|
||||
elif name.startswith("output_blocks_"): # 13-21
|
||||
block_idx = 13 + int(name.split("_")[2])
|
||||
elif name.startswith("out_"): # 22, out, no LoRA
|
||||
block_idx = 22
|
||||
|
||||
return block_idx
|
||||
|
||||
|
||||
def convert_diffusers_to_sai_if_needed(weights_sd):
|
||||
# only supports U-Net LoRA modules
|
||||
|
||||
found_up_down_blocks = False
|
||||
for k in list(weights_sd.keys()):
|
||||
if "down_blocks" in k:
|
||||
found_up_down_blocks = True
|
||||
break
|
||||
if "up_blocks" in k:
|
||||
found_up_down_blocks = True
|
||||
break
|
||||
if not found_up_down_blocks:
|
||||
return
|
||||
|
||||
from library.sdxl_model_util import make_unet_conversion_map
|
||||
|
||||
unet_conversion_map = make_unet_conversion_map()
|
||||
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
||||
|
||||
# # add extra conversion
|
||||
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
|
||||
|
||||
logger.info(f"Converting LoRA keys from Diffusers to SAI")
|
||||
lora_unet_prefix = "lora_unet_"
|
||||
for k in list(weights_sd.keys()):
|
||||
if not k.startswith(lora_unet_prefix):
|
||||
continue
|
||||
|
||||
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
|
||||
|
||||
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
|
||||
for hf_module_name, sd_module_name in unet_conversion_map.items():
|
||||
if hf_module_name in unet_module_name:
|
||||
new_key = (
|
||||
lora_unet_prefix
|
||||
+ unet_module_name.replace(hf_module_name, sd_module_name)
|
||||
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
|
||||
)
|
||||
weights_sd[new_key] = weights_sd.pop(k)
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
logger.warning(f"Key {k} is not found in unet_conversion_map")
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
||||
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
|
||||
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
|
@ -707,6 +814,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# if keys are Diffusers based, convert to SAI based
|
||||
if is_sdxl:
|
||||
convert_diffusers_to_sai_if_needed(weights_sd)
|
||||
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
|
|
@ -730,19 +841,28 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
modules_dim=modules_dim,
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
is_sdxl=is_sdxl,
|
||||
)
|
||||
|
||||
# block lr
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
|
||||
if block_lr_weight is not None:
|
||||
network.set_block_lr_weight(block_lr_weight)
|
||||
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||
NUM_OF_MID_BLOCKS = 1
|
||||
SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23
|
||||
SDXL_NUM_OF_MID_BLOCKS = 3
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
|
|
@ -774,6 +894,7 @@ class LoRANetwork(torch.nn.Module):
|
|||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
varbose: Optional[bool] = False,
|
||||
is_sdxl: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""
|
||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||
|
|
@ -794,6 +915,10 @@ class LoRANetwork(torch.nn.Module):
|
|||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
|
|
@ -855,7 +980,7 @@ class LoRANetwork(torch.nn.Module):
|
|||
alpha = modules_alpha[lora_name]
|
||||
elif is_unet and block_dims is not None:
|
||||
# U-Netでblock_dims指定あり
|
||||
block_idx = get_block_index(lora_name)
|
||||
block_idx = get_block_index(lora_name, is_sdxl)
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = block_dims[block_idx]
|
||||
alpha = block_alphas[block_idx]
|
||||
|
|
@ -925,9 +1050,7 @@ class LoRANetwork(torch.nn.Module):
|
|||
for name in skipped:
|
||||
logger.info(f"\t{name}")
|
||||
|
||||
self.up_lr_weight: List[float] = None
|
||||
self.down_lr_weight: List[float] = None
|
||||
self.mid_lr_weight: float = None
|
||||
self.block_lr_weight = None
|
||||
self.block_lr = False
|
||||
|
||||
# assertion
|
||||
|
|
@ -958,12 +1081,12 @@ class LoRANetwork(torch.nn.Module):
|
|||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
|
|
@ -1004,81 +1127,117 @@ class LoRANetwork(torch.nn.Module):
|
|||
logger.info(f"weights are merged")
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||
def set_block_lr_weight(
|
||||
self,
|
||||
up_lr_weight: List[float] = None,
|
||||
mid_lr_weight: float = None,
|
||||
down_lr_weight: List[float] = None,
|
||||
):
|
||||
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
|
||||
self.block_lr = True
|
||||
self.down_lr_weight = down_lr_weight
|
||||
self.mid_lr_weight = mid_lr_weight
|
||||
self.up_lr_weight = up_lr_weight
|
||||
self.block_lr_weight = block_lr_weight
|
||||
|
||||
def get_lr_weight(self, lora: LoRAModule) -> float:
|
||||
lr_weight = 1.0
|
||||
block_idx = get_block_index(lora.lora_name)
|
||||
if block_idx < 0:
|
||||
return lr_weight
|
||||
def get_lr_weight(self, block_idx: int) -> float:
|
||||
if not self.block_lr or self.block_lr_weight is None:
|
||||
return 1.0
|
||||
return self.block_lr_weight[block_idx]
|
||||
|
||||
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.down_lr_weight != None:
|
||||
lr_weight = self.down_lr_weight[block_idx]
|
||||
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.mid_lr_weight != None:
|
||||
lr_weight = self.mid_lr_weight
|
||||
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.up_lr_weight != None:
|
||||
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||
|
||||
return lr_weight
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
||||
# if (
|
||||
# self.loraplus_lr_ratio is not None
|
||||
# or self.loraplus_text_encoder_lr_ratio is not None
|
||||
# or self.loraplus_unet_lr_ratio is not None
|
||||
# ):
|
||||
# assert (
|
||||
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
||||
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
return params
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
descriptions.append("plus" if key == "plus" else "")
|
||||
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params, descriptions = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
if self.block_lr:
|
||||
is_sdxl = False
|
||||
for lora in self.unet_loras:
|
||||
if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
|
||||
is_sdxl = True
|
||||
break
|
||||
|
||||
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
block_idx_to_lora = {}
|
||||
for lora in self.unet_loras:
|
||||
idx = get_block_index(lora.lora_name)
|
||||
idx = get_block_index(lora.lora_name, is_sdxl)
|
||||
if idx not in block_idx_to_lora:
|
||||
block_idx_to_lora[idx] = []
|
||||
block_idx_to_lora[idx].append(lora)
|
||||
|
||||
# blockごとにパラメータを設定する
|
||||
for idx, block_loras in block_idx_to_lora.items():
|
||||
param_data = {"params": enumerate_params(block_loras)}
|
||||
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||
elif default_lr is not None:
|
||||
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
params, descriptions = assemble_params(
|
||||
block_loras,
|
||||
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
|
||||
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
else:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
return all_params
|
||||
return all_params, lr_descriptions
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
|
|
|
|||
|
|
@ -4,13 +4,17 @@ import math
|
|||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
import einops
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
|
@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module):
|
|||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().numpy()
|
||||
self.constraint = alpha * out_dim
|
||||
|
||||
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
|
||||
# original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha
|
||||
self.constraint = alpha * out_dim
|
||||
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
self.block_size = out_dim // self.num_blocks
|
||||
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
||||
self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu
|
||||
|
||||
self.out_dim = out_dim
|
||||
self.shape = org_module.weight.shape
|
||||
|
|
@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module):
|
|||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
|
||||
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||
|
||||
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
|
||||
R = torch.block_diag(*block_R_weighted)
|
||||
|
||||
return R
|
||||
if self.I.device != block_Q.device:
|
||||
self.I = self.I.to(block_Q.device)
|
||||
I = self.I
|
||||
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
|
||||
block_R_weighted = self.multiplier * (block_R - I) + I
|
||||
return block_R_weighted
|
||||
|
||||
def forward(self, x, scale=None):
|
||||
x = self.org_forward(x)
|
||||
if self.multiplier == 0.0:
|
||||
return x
|
||||
return self.org_forward(x)
|
||||
org_module = self.org_module[0]
|
||||
org_dtype = x.dtype
|
||||
|
||||
R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = torch.matmul(x, R)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = torch.matmul(x, R)
|
||||
return x
|
||||
R = self.get_weight().to(torch.float32)
|
||||
W = org_module.weight.to(torch.float32)
|
||||
|
||||
if len(W.shape) == 4: # Conv2d
|
||||
W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
|
||||
RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
|
||||
RW = einops.rearrange(RW, "k m ... -> (k m) ...")
|
||||
result = F.conv2d(
|
||||
x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
|
||||
)
|
||||
else: # Linear
|
||||
W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
|
||||
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
|
||||
RW = einops.rearrange(RW, "k m p -> (k m) p")
|
||||
result = F.linear(x, RW.to(org_dtype), org_module.bias)
|
||||
return result
|
||||
|
||||
|
||||
class OFTInfModule(OFTModule):
|
||||
|
|
@ -115,18 +133,19 @@ class OFTInfModule(OFTModule):
|
|||
return self.org_forward(x)
|
||||
return super().forward(x, scale)
|
||||
|
||||
def merge_to(self, multiplier=None, sign=1):
|
||||
R = self.get_weight(multiplier) * sign
|
||||
|
||||
def merge_to(self, multiplier=None):
|
||||
# get org weight
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
org_weight = org_sd["weight"]
|
||||
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||
org_weight = org_sd["weight"].to(torch.float32)
|
||||
|
||||
if org_weight.dim() == 4:
|
||||
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||
else:
|
||||
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||
R = self.get_weight(multiplier).to(torch.float32)
|
||||
|
||||
weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
|
||||
weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
|
||||
weight = weight.reshape(org_weight.shape)
|
||||
|
||||
# convert back to original dtype
|
||||
weight = weight.to(org_sd["weight"].dtype)
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight
|
||||
|
|
@ -145,8 +164,16 @@ def create_network(
|
|||
):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
if network_alpha is None: # should be set
|
||||
logger.info(
|
||||
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
|
||||
)
|
||||
network_alpha = 1e-3
|
||||
elif network_alpha >= 1:
|
||||
logger.warning(
|
||||
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
|
||||
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
|
||||
)
|
||||
|
||||
enable_all_linear = kwargs.get("enable_all_linear", None)
|
||||
enable_conv = kwargs.get("enable_conv", None)
|
||||
|
|
@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||
else:
|
||||
if dim is None:
|
||||
dim = param.size()[0]
|
||||
if has_conv2d is None and param.dim() == 4:
|
||||
if has_conv2d is None and "in_layers_2" in name:
|
||||
has_conv2d = True
|
||||
if all_linear is None:
|
||||
if param.dim() == 3 and "attn" not in name:
|
||||
all_linear = True
|
||||
if dim is not None and alpha is not None and has_conv2d is not None:
|
||||
if all_linear is None and "_ff_" in name:
|
||||
all_linear = True
|
||||
if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
|
||||
break
|
||||
if has_conv2d is None:
|
||||
has_conv2d = False
|
||||
|
|
@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module):
|
|||
self.alpha = alpha
|
||||
|
||||
logger.info(
|
||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
|
|
|
|||
|
|
@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype):
|
|||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
def save_to_file(file_name, state_dict, metadata):
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(state_dict, file_name, metadata)
|
||||
else:
|
||||
|
|
@ -349,12 +344,18 @@ def resize(args):
|
|||
metadata["ss_network_dim"] = "Dynamic"
|
||||
metadata["ss_network_alpha"] = "Dynamic"
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
|
|
|||
|
|
@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype):
|
|||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata=metadata)
|
||||
else:
|
||||
|
|
@ -430,6 +425,12 @@ def merge(args):
|
|||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
|
|
@ -445,7 +446,7 @@ def merge(args):
|
|||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
|
|
|||
|
|
@ -216,12 +216,7 @@ def load_state_dict(file_name, dtype):
|
|||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
def save_to_file(file_name, state_dict, metadata):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
else:
|
||||
|
|
@ -306,10 +301,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
|||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
|
||||
# merge to weight
|
||||
if device:
|
||||
|
|
@ -341,13 +336,16 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
|||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
merged_sd[lora_module_name] = weight
|
||||
merged_sd[lora_module_name] = weight.to("cpu")
|
||||
|
||||
# extract from merged weights
|
||||
logger.info("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
|
@ -386,7 +384,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
|||
|
||||
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
|
||||
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
|
||||
|
||||
# build minimum metadata
|
||||
dims = f"{new_rank}"
|
||||
|
|
@ -430,6 +428,12 @@ def merge(args):
|
|||
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||
)
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
|
|
@ -451,7 +455,7 @@ def merge(args):
|
|||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
accelerate==0.25.0
|
||||
transformers==4.36.2
|
||||
accelerate==0.30.0
|
||||
transformers==4.44.0
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.8.1.78
|
||||
einops==0.7.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.43.0
|
||||
bitsandbytes==0.44.0
|
||||
prodigyopt==1.0
|
||||
lion-pytorch==0.0.6
|
||||
tensorboard
|
||||
|
|
@ -16,7 +16,7 @@ altair==4.2.2
|
|||
easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.20.1
|
||||
huggingface-hub==0.24.5
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
# for BLIP captioning
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ def train(args):
|
|||
# 学習を準備する:モデルを適切な状態にする
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
train_unet = args.learning_rate > 0
|
||||
train_unet = args.learning_rate != 0
|
||||
train_text_encoder1 = False
|
||||
train_text_encoder2 = False
|
||||
|
||||
|
|
@ -284,8 +284,8 @@ def train(args):
|
|||
text_encoder2.gradient_checkpointing_enable()
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
train_text_encoder2 = lr_te2 > 0
|
||||
train_text_encoder1 = lr_te1 != 0
|
||||
train_text_encoder2 = lr_te2 != 0
|
||||
|
||||
# caching one text encoder output is not supported
|
||||
if not train_text_encoder1:
|
||||
|
|
@ -345,8 +345,8 @@ def train(args):
|
|||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
for group in params_to_optimize:
|
||||
for p in group["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
|
||||
|
|
@ -355,7 +355,53 @@ def train(args):
|
|||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
if args.fused_optimizer_groups:
|
||||
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
|
||||
# This balances memory usage and management complexity.
|
||||
|
||||
# calculate total number of parameters
|
||||
n_total_params = sum(len(params["params"]) for params in params_to_optimize)
|
||||
params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
|
||||
|
||||
# split params into groups, keeping the learning rate the same for all params in a group
|
||||
# this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
|
||||
grouped_params = []
|
||||
param_group = []
|
||||
param_group_lr = -1
|
||||
for group in params_to_optimize:
|
||||
lr = group["lr"]
|
||||
for p in group["params"]:
|
||||
# if the learning rate is different for different params, start a new group
|
||||
if lr != param_group_lr:
|
||||
if param_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
param_group = []
|
||||
param_group_lr = lr
|
||||
|
||||
param_group.append(p)
|
||||
|
||||
# if the group has enough parameters, start a new group
|
||||
if len(param_group) == params_per_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
param_group = []
|
||||
param_group_lr = -1
|
||||
|
||||
if param_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
|
||||
# prepare optimizers for each group
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(optimizer)
|
||||
optimizer = optimizers[0] # avoid error in the following code
|
||||
|
||||
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
|
||||
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
|
|
@ -382,7 +428,12 @@ def train(args):
|
|||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
if args.fused_optimizer_groups:
|
||||
# prepare lr schedulers for each optimizer
|
||||
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
||||
else:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
|
|
@ -450,6 +501,57 @@ def train(args):
|
|||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(__grad_hook)
|
||||
|
||||
elif args.fused_optimizer_groups:
|
||||
# prepare for additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||
|
||||
# counters are used to determine when to step the optimizer
|
||||
global optimizer_hooked_count
|
||||
global num_parameters_per_group
|
||||
global parameter_optimizer_map
|
||||
|
||||
optimizer_hooked_count = {}
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def optimizer_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
|
@ -487,7 +589,11 @@ def train(args):
|
|||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers(
|
||||
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
|
|
@ -504,6 +610,10 @@ def train(args):
|
|||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.fused_optimizer_groups:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
|
|
@ -582,7 +692,9 @@ def train(args):
|
|||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
|
|
@ -590,7 +702,11 @@ def train(args):
|
|||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
|
||||
target = noise
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if (
|
||||
args.min_snr_gamma
|
||||
|
|
@ -600,34 +716,46 @@ def train(args):
|
|||
or args.masked_loss
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
if not (args.fused_backward_pass or args.fused_optimizer_groups):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.fused_optimizer_groups:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
|
@ -736,7 +864,7 @@ def train(args):
|
|||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
|
@ -805,6 +933,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
|
||||
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fused_optimizer_groups",
|
||||
type=int,
|
||||
default=None,
|
||||
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from tqdm import tqdm
|
|||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
|
@ -288,6 +289,9 @@ def train(args):
|
|||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if isinstance(unet, DDP):
|
||||
unet._set_static_graph() # avoid error for multiple use of the parameter
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
|
|
@ -353,7 +357,7 @@ def train(args):
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
|
@ -439,7 +443,9 @@ def train(args):
|
|||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
|
|
@ -458,7 +464,9 @@ def train(args):
|
|||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
|
@ -471,13 +479,13 @@ def train(args):
|
|||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = unet.get_trainable_params()
|
||||
params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
|
|
|
|||
|
|
@ -324,7 +324,7 @@ def train(args):
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
|
@ -439,7 +439,7 @@ def train(args):
|
|||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@ from library.config_util import (
|
|||
from library.utils import setup_logging, add_logging_arguments
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
setup_logging(args, reset=True)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
|
@ -109,7 +111,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||
else:
|
||||
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
|
|
@ -138,6 +140,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||
b_size = len(batch["images"])
|
||||
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
|
||||
flip_aug = batch["flip_aug"]
|
||||
alpha_mask = batch["alpha_mask"]
|
||||
random_crop = batch["random_crop"]
|
||||
bucket_reso = batch["bucket_reso"]
|
||||
|
||||
|
|
@ -156,14 +159,16 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
||||
if train_util.is_disk_cached_latents_is_expected(
|
||||
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
|
||||
):
|
||||
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
|
||||
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import os
|
|||
from anime_face_detector import create_detector
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, pil_resize
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -172,7 +172,10 @@ def process(args):
|
|||
if scale != 1.0:
|
||||
w = int(w * scale + .5)
|
||||
h = int(h * scale + .5)
|
||||
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
|
||||
if scale < 1.0:
|
||||
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
face_img = pil_resize(face_img, (w, h))
|
||||
cx = int(cx * scale + .5)
|
||||
cy = int(cy * scale + .5)
|
||||
fw = int(fw * scale + .5)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import shutil
|
|||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, pil_resize
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||
|
||||
# Select interpolation method
|
||||
if interpolation == 'lanczos4':
|
||||
cv2_interpolation = cv2.INTER_LANCZOS4
|
||||
pil_interpolation = Image.LANCZOS
|
||||
elif interpolation == 'cubic':
|
||||
cv2_interpolation = cv2.INTER_CUBIC
|
||||
pil_interpolation = Image.BICUBIC
|
||||
else:
|
||||
cv2_interpolation = cv2.INTER_AREA
|
||||
|
||||
|
|
@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
# Resize image
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
if cv2_interpolation:
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
else:
|
||||
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
|
||||
else:
|
||||
new_height, new_width = img.shape[0:2]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ import os
|
|||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
|
||||
# from omegaconf import OmegaConf
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
@ -13,6 +14,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from library import deepspeed_utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
|
@ -105,6 +107,8 @@ def train(args):
|
|||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
|
@ -148,8 +152,10 @@ def train(args):
|
|||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": [5, 10, 20, 20],
|
||||
"num_class_embeds": None,
|
||||
"only_cross_attention": False,
|
||||
"out_channels": 4,
|
||||
|
|
@ -179,8 +185,10 @@ def train(args):
|
|||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": 8,
|
||||
"out_channels": 4,
|
||||
"sample_size": 64,
|
||||
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
|
|
@ -193,7 +201,23 @@ def train(args):
|
|||
"resnet_time_scale_shift": "default",
|
||||
"projection_class_embeddings_input_dim": None,
|
||||
}
|
||||
unet.config = SimpleNamespace(**unet.config)
|
||||
# unet.config = OmegaConf.create(unet.config)
|
||||
|
||||
# make unet.config iterable and accessible by attribute
|
||||
class CustomConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __contains__(self, name):
|
||||
return name in self.__dict__
|
||||
|
||||
unet.config = CustomConfig(**unet.config)
|
||||
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
|
|
@ -226,7 +250,7 @@ def train(args):
|
|||
)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
|
|
@ -235,7 +259,7 @@ def train(args):
|
|||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = controlnet.parameters()
|
||||
trainable_params = list(controlnet.parameters())
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
|
|
@ -344,7 +368,9 @@ def train(args):
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
|
@ -420,7 +446,9 @@ def train(args):
|
|||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
|
||||
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
|
||||
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
|
||||
)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
|
@ -452,7 +480,9 @@ def train(args):
|
|||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
|
|
|||
|
|
@ -93,6 +93,8 @@ def train(args):
|
|||
if args.no_token_padding:
|
||||
train_dataset_group.disable_token_padding()
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
|
@ -290,7 +292,7 @@ def train(args):
|
|||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
|
@ -359,7 +361,7 @@ def train(args):
|
|||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
|
|
@ -371,7 +373,7 @@ def train(args):
|
|||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,15 @@ class NetworkTrainer:
|
|||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(
|
||||
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
|
||||
self,
|
||||
args: argparse.Namespace,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
|
|
@ -63,39 +71,31 @@ class NetworkTrainer:
|
|||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
for i, lr in enumerate(lrs):
|
||||
if lr_descriptions is not None:
|
||||
lr_desc = lr_descriptions[i]
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
idx = i - (0 if args.network_train_unet_only else -1)
|
||||
if idx == -1:
|
||||
lr_desc = "textencoder"
|
||||
else:
|
||||
if len(lrs) > 2:
|
||||
lr_desc = f"group{idx}"
|
||||
else:
|
||||
lr_desc = "unet"
|
||||
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
logs[f"lr/{lr_desc}"] = lr
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
# tracking d*lr value
|
||||
logs[f"lr/d*lr/{lr_desc}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
pass
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
|
@ -323,6 +323,7 @@ class NetworkTrainer:
|
|||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
|
||||
if args.network_weights is not None:
|
||||
# FIXME consider alpha of weights
|
||||
info = network.load_weights(args.network_weights)
|
||||
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
||||
|
||||
|
|
@ -338,12 +339,30 @@ class NetworkTrainer:
|
|||
|
||||
# 後方互換性を確保するよ
|
||||
try:
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
except TypeError:
|
||||
accelerator.print(
|
||||
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
)
|
||||
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
if type(results) is tuple:
|
||||
trainable_params = results[0]
|
||||
lr_descriptions = results[1]
|
||||
else:
|
||||
trainable_params = results
|
||||
lr_descriptions = None
|
||||
except TypeError as e:
|
||||
# logger.warning(f"{e}")
|
||||
# accelerator.print(
|
||||
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
# )
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
lr_descriptions = None
|
||||
|
||||
# if len(trainable_params) == 0:
|
||||
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
|
||||
# for params in trainable_params:
|
||||
# for k, v in params.items():
|
||||
# if type(v) == float:
|
||||
# pass
|
||||
# else:
|
||||
# v = len(v)
|
||||
# accelerator.print(f"trainable_params: {k} = {v}")
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
|
|
@ -474,15 +493,26 @@ class NetworkTrainer:
|
|||
# before resuming make hook for saving/loading to save/load the network weights only
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# pop weights of other models than network to save only network weights
|
||||
if accelerator.is_main_process:
|
||||
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
||||
if accelerator.is_main_process or args.deepspeed:
|
||||
remove_indices = []
|
||||
for i, model in enumerate(models):
|
||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||
remove_indices.append(i)
|
||||
for i in reversed(remove_indices):
|
||||
weights.pop(i)
|
||||
if len(weights) > i:
|
||||
weights.pop(i)
|
||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||
|
||||
# save current ecpoch and step
|
||||
train_state_file = os.path.join(output_dir, "train_state.json")
|
||||
# +1 is needed because the state is saved before current_step is set from global_step
|
||||
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
||||
with open(train_state_file, "w", encoding="utf-8") as f:
|
||||
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
||||
|
||||
steps_from_state = None
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
# remove models except network
|
||||
remove_indices = []
|
||||
|
|
@ -493,6 +523,15 @@ class NetworkTrainer:
|
|||
models.pop(i)
|
||||
# print(f"load model hook: {len(models)} models will be loaded")
|
||||
|
||||
# load current epoch and step to
|
||||
nonlocal steps_from_state
|
||||
train_state_file = os.path.join(input_dir, "train_state.json")
|
||||
if os.path.exists(train_state_file):
|
||||
with open(train_state_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
steps_from_state = data["current_step"]
|
||||
logger.info(f"load train state from {train_state_file}: {data}")
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
|
|
@ -736,7 +775,54 @@ class NetworkTrainer:
|
|||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
# calculate steps to skip when resuming or starting from a specific step
|
||||
initial_step = 0
|
||||
if args.initial_epoch is not None or args.initial_step is not None:
|
||||
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
||||
if steps_from_state is not None:
|
||||
logger.warning(
|
||||
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
||||
)
|
||||
if args.initial_step is not None:
|
||||
initial_step = args.initial_step
|
||||
else:
|
||||
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
||||
initial_step = (args.initial_epoch - 1) * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
||||
if steps_from_state is not None:
|
||||
initial_step = steps_from_state
|
||||
steps_from_state = None
|
||||
|
||||
if initial_step > 0:
|
||||
assert (
|
||||
args.max_train_steps > initial_step
|
||||
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||
)
|
||||
|
||||
epoch_to_start = 0
|
||||
if initial_step > 0:
|
||||
if args.skip_until_initial_step:
|
||||
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
||||
if not args.resume:
|
||||
logger.info(
|
||||
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
||||
)
|
||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||
initial_step *= args.gradient_accumulation_steps
|
||||
|
||||
# set epoch to start to make initial_step less than len(train_dataloader)
|
||||
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
else:
|
||||
# if not, only epoch no is skipped for informative purpose
|
||||
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
initial_step = 0 # do not skip
|
||||
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
|
|
@ -753,7 +839,9 @@ class NetworkTrainer:
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
|
@ -793,7 +881,13 @@ class NetworkTrainer:
|
|||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
global_step = initial_step
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
|
|
@ -801,8 +895,17 @@ class NetworkTrainer:
|
|||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
skipped_dataloader = None
|
||||
if initial_step > 0:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
initial_step = 1
|
||||
|
||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
|
|
@ -881,7 +984,7 @@ class NetworkTrainer:
|
|||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss:
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
|
|
@ -895,7 +998,7 @@ class NetworkTrainer:
|
|||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
|
@ -950,7 +1053,9 @@ class NetworkTrainer:
|
|||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
|
|
@ -1101,6 +1206,28 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_until_initial_step",
|
||||
action="store_true",
|
||||
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_epoch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
|
||||
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_step",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class TextualInversionTrainer:
|
|||
self.is_sdxl = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
pass
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
|
@ -510,7 +510,7 @@ class TextualInversionTrainer:
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
# function for saving/removing
|
||||
|
|
@ -589,7 +589,7 @@ class TextualInversionTrainer:
|
|||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
|
|
@ -603,7 +603,7 @@ class TextualInversionTrainer:
|
|||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
|
|
|||
|
|
@ -407,7 +407,7 @@ def train(args):
|
|||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
# function for saving/removing
|
||||
|
|
@ -474,7 +474,7 @@ def train(args):
|
|||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
|
|
@ -486,7 +486,7 @@ def train(args):
|
|||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue