mirror of https://github.com/bmaltais/kohya_ss
Merge branch 'dev2' into cleanup
commit 9619a8214f
.github/dependabot.yml (new file)
@@ -0,0 +1,7 @@
---
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "monthly"
@@ -15,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

       - name: typos-action
-        uses: crate-ci/typos@v1.13.10
+        uses: crate-ci/typos@v1.16.15
README.md
@@ -627,6 +627,32 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

## Change History

* 2023/10/10 (v22.1.0)
  - Remove support for torch 1 to align with the kohya_ss sd-scripts code base.
  - Add Intel ARC GPU support with IPEX support on Linux / WSL.
    - Users need to set these manually:
      * Mixed precision to BF16,
      * Attention to SDPA,
      * Optimizer to AdamW (or any other non-8-bit one).
    - Run setup with: `./setup.sh --use-ipex`
    - Run the GUI with: `./gui.sh --use-ipex`
  - Merging main branch of sd-scripts:
    - `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is no longer required. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
      - An `--onnx` option is added. If you use Onnx, specify the `--onnx` option.
      - Please install Onnx and the other required packages:
        1. Uninstall TensorFlow.
        2. `pip install tensorboard==2.14.1` (required for the specified version of protobuf)
        3. `pip install protobuf==3.20.3` (required for Onnx)
        4. `pip install onnx==1.14.1`
        5. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
    - An `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
    - [OFT](https://oft.wyliu.com/) is now supported.
      - You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
      - `sdxl_gen_img.py` also supports OFT as `--network_module`.
      - OFT currently supports only SDXL, because it tweaks Q/K/V and O in the Transformer blocks, and SD1/2 have far fewer of them than SDXL.
      - The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
  - Other bug fixes and improvements. An illustrative invocation of the new tagger flags is sketched below.

* 2023/10/01 (v22.0.0)
  - Merging main branch of sd-scripts:
    - [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
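To make the new tagger options concrete, here is an illustrative sketch only: the flags are the ones documented above, while the script path, image folder, and batch size are hypothetical placeholders.

```python
# Hedged sketch: drive the updated WD14 tagger with the new flags.
# The flags (--onnx, --append_tags) come from the notes above;
# "train_images" and the script location are hypothetical.
import subprocess

subprocess.run(
    [
        "python", "tag_images_by_wd_14_tagger.py",
        "--onnx",          # run inference through the ONNX model (no TensorFlow needed)
        "--append_tags",   # merge new tags into existing captions instead of overwriting
        "--batch_size", "4",
        "train_images",
    ],
    check=True,
)
```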
@@ -650,16 +676,3 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
   - Update wandb module version
   - Add support for Chinese zh-CN localisation. You can use it with `.\gui.bat --language=zh-CN`
   - Add presets support to `Finetuning`. You can add your own finetuning user presets under the `/presets/finetune/user_presets` folder.
-
-* 2023/09/23 (v21.8.10)
-  - Minor point upgrade. Mostly adding a new preset.
-
-* 2023/08/05 (v21.8.9)
-  - Update sd-scripts to code as of Sept 3, 2023
-    * ControlNet-LLLite is added. See documentation for details.
-    * JPEG XL is supported. #786
-    * Peak memory usage is reduced. #791
-    * Input perturbation noise is added. See #798 for details.
-    * Dataset subsets now have caption_prefix and caption_suffix options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in .toml.
-    * Other minor changes.
-  - Added support for Chinese localisation
@@ -0,0 +1,188 @@
Hi! I have translated the main content of the Japanese README file as follows:

## About this repository

This repository contains scripts for Stable Diffusion model training, image generation, and other utilities.

[English README](./README.md) <-- update information is here

A GUI, PowerShell scripts, and other usability features are provided in [bmaltais's repository](https://github.com/bmaltais/kohya_ss) (in English); see it as well. Thanks to bmaltais.

The following scripts are included:

* Support for DreamBooth, U-Net, and Text Encoder training
* Support for fine-tuning
* Image generation
* Model conversion (between Stable Diffusion ckpt/safetensors and Diffusers formats)

## Usage (users in China only need to follow this installation tutorial)

- Go to the root of the kohya_ss folder and run setup.bat to start the installer *(requires proxy/VPN network access)
- Follow the English options shown in the interface:

```txt
Kohya_ss GUI setup menu:

1. Install kohya_ss gui
2. (Optional) Install cudann files (avoid unless you really need it)
3. (Optional) Install specific bitsandbytes versions
4. (Optional) Manually configure accelerate
5. (Optional) Start Kohya_ss GUI in browser
6. Quit

Enter your choice: 1

1. Torch 1 (legacy, no longer supported. Will be removed in v21.9.x)
2. Torch 2 (recommended)
3. Cancel

Enter your choice: 2
```

The environment dependencies will start installing; when the next prompts appear, choose the following options:
```txt
- This machine
- No distributed training
- NO
- NO
- NO
- all
- bf16
```
--------------------------------------------------------------------
Once all of these are selected, you can close the terminal window and simply double-click gui.bat or kohya中文启动器.bat (the Chinese launcher) to run kohya.


For topics covered by articles in this repository or on note.com, please refer to those. (They may all be moved here in the future.)

* [Training, common guide](./docs/train_README-ja.md): data preparation, options, etc.
* [Dataset configuration](./docs/config_README-ja.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Fine-tuning guide](./docs/fine_tune_README_ja.md)
* [LoRA training guide](./docs/train_network_README-ja.md)
* [Textual Inversion training guide](./docs/train_ti_README-ja.md)
* [Image generation script](./docs/gen_img_README-ja.md)
* note.com [Model conversion script](https://note.com/kohya_ss/n/n374f316fe4ad)

## Required programs for the Windows environment

Python 3.10.6 and Git are required.

- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win

To use venv from PowerShell, change the security settings as follows:
(Note that this allows script execution in general, not just for venv.)

- Open PowerShell as administrator
- Enter "Set-ExecutionPolicy Unrestricted" and answer Y
- Close the administrator PowerShell

## Installation on Windows

The example below installs PyTorch 1.12.1 / CUDA 11.6. Adjust accordingly if you want to use CUDA 11.3 or PyTorch 1.13.

(If only "python" is displayed, replace "python" in the example below with "py".)

Run the following commands in order in a regular (non-administrator) PowerShell:

```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

In a command prompt:

```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

Answer the accelerate config questions as follows:
(If you want to train with bf16, answer bf16 to the last question.)

```
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
```

### Notes on PyTorch and xformers versions

Training may fail with other versions. Unless there is a particular reason, please use the specified versions.

### Optional: Using Lion8bit

To use Lion8bit, `bitsandbytes` must be upgraded to 0.38.0 or later. Uninstall `bitsandbytes` first, and then, on Windows, install a Windows-compatible whl file, for example one from [here](https://github.com/jllllll/bitsandbytes-windows-webui). For example:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```

When upgrading, update this repository with `pip install .`, and upgrade other packages as needed.

### Optional: Using PagedAdamW8bit and PagedLion8bit

To use PagedAdamW8bit and PagedLion8bit, `bitsandbytes` must be upgraded to 0.39.0 or later. Uninstall `bitsandbytes` first, and then, on Windows, install a Windows-compatible whl file, for example one from [here](https://github.com/jllllll/bitsandbytes-windows-webui). For example:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

When upgrading, update this repository with `pip install .`, and upgrade other packages as needed.

## Upgrading

If a new release is available, you can update with the following commands:

```powershell
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --use-pep517 --upgrade -r requirements.txt
```

If the commands succeed, the new version is ready to use.

## Acknowledgements

The LoRA implementation is based on [cloneofsimo's repository](https://github.com/cloneofsimo/lora). Many thanks.

Extending Conv2d 3x3 to all layers was first released by [cloneofsimo](https://github.com/cloneofsimo/lora), and [KohakuBlueleaf](https://github.com/KohakuBlueleaf/LoCon) demonstrated its effectiveness. Deep thanks to KohakuBlueleaf.

## License

The scripts are licensed under ASL 2.0, but they include code under other licenses (from the Diffusers and cloneofsimo's repositories).

[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
_typos.toml
@@ -9,7 +9,25 @@ parms="parms"
 nin="nin"
 extention="extention" # Intentionally left
 nd="nd"
+shs="shs"
+sts="sts"
+scs="scs"
+cpc="cpc"
+coc="coc"
+cic="cic"
+msm="msm"
+usu="usu"
+ici="ici"
+lvl="lvl"
+dii="dii"
+muk="muk"
+ori="ori"
+hru="hru"
+rik="rik"
+koo="koo"
+yos="yos"
+wn="wn"

 [files]
-extend-exclude = ["_typos.toml"]
+extend-exclude = ["_typos.toml", "venv"]
@@ -45,7 +45,7 @@ from library.utilities import utilities_tab
 from library.class_sample_images import SampleImages, run_cmd_sample

 from library.custom_logging import setup_logging
-from localization_ext import add_javascript
+from library.localization_ext import add_javascript

 # Set up logging
 log = setup_logging()
@@ -80,8 +80,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)

@@ -208,7 +208,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
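The hunks above are a pure spelling fix (`collater` to `collator`); the call sites keep the same shape. For readers outside the repo, a minimal sketch of the pattern, assuming (not verified here) that `train_util.collator_class` forwards batches while giving dataset code access to shared epoch/step counters:

```python
from torch.utils.data.dataloader import default_collate

class collator_class:
    """Sketch only: pass batches through while exposing shared counters."""
    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch    # a multiprocessing.Value("i", 0)
        self.current_step = step      # a multiprocessing.Value("i", 0)
        self.dataset = dataset        # set only when max_data_loader_n_workers == 0

    def __call__(self, examples):
        if self.dataset is not None:
            # in-process loading: tell the dataset where training currently is
            # (assumes the dataset exposes these hooks)
            self.dataset.set_current_epoch(self.current_epoch.value)
            self.dataset.set_current_step(self.current_step.value)
        return default_collate(examples)
```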
@@ -1,17 +1,15 @@
 import argparse
 import csv
 import glob
 import os

-from PIL import Image
-import cv2
-from tqdm import tqdm
-import numpy as np
-from tensorflow.keras.models import load_model
-from huggingface_hub import hf_hub_download
-import torch
-from pathlib import Path
+import cv2
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from tqdm import tqdm

 import library.train_util as train_util

 # from wd14 tagger

@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2 / wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
 DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
 FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
+FILES_ONNX = ["model.onnx"]
 SUB_DIR = "variables"
 SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
 CSV_FILE = FILES[-1]
@@ -81,7 +80,10 @@ def main(args):
     # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
     if not os.path.exists(args.model_dir) or args.force_download:
         print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
-        for file in FILES:
+        files = FILES
+        if args.onnx:
+            files += FILES_ONNX
+        for file in files:
             hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
         for file in SUB_DIR_FILES:
             hf_hub_download(
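One subtlety in the hunk above: `files = FILES` binds a second name to the module-level list, so `files += FILES_ONNX` extends `FILES` in place. It is harmless here because the download block runs once, but the aliasing is easy to trip over:

```python
# Demonstrates the aliasing in the hunk above (values copied from the diff).
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]

files = FILES          # alias, not a copy
files += FILES_ONNX    # list.__iadd__ mutates FILES as well
assert FILES[-1] == "model.onnx"

# a side-effect-free alternative would be:
# files = FILES + FILES_ONNX
```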
@@ -96,7 +98,46 @@ def main(args):
         print("using existing wd14 tagger model")

     # load the model
-    model = load_model(args.model_dir)
+    if args.onnx:
+        import onnx
+        import onnxruntime as ort
+
+        onnx_path = f"{args.model_dir}/model.onnx"
+        print("Running wd14 tagger with onnx")
+        print(f"loading onnx model: {onnx_path}")
+
+        if not os.path.exists(onnx_path):
+            raise Exception(
+                f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+            )
+
+        model = onnx.load(onnx_path)
+        input_name = model.graph.input[0].name
+        try:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
+        except:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
+
+        if args.batch_size != batch_size and type(batch_size) != str:
+            # some rebatch models may use 'N' as a dynamic axis
+            print(
+                f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
+            )
+            args.batch_size = batch_size
+
+        del model
+
+        ort_sess = ort.InferenceSession(
+            onnx_path,
+            providers=["CUDAExecutionProvider"]
+            if "CUDAExecutionProvider" in ort.get_available_providers()
+            else ["CPUExecutionProvider"],
+        )
+    else:
+        from tensorflow.keras.models import load_model
+
+        model = load_model(f"{args.model_dir}")

     # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
     # read the csv ourselves to avoid adding another dependency
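For readers who want to poke at the downloaded model directly, a minimal standalone sketch under stated assumptions: the model path is hypothetical, the input layout is the (batch, 448, 448, 3) float32 layout the script feeds, and the batch axis is assumed dynamic.

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "wd14_tagger_model/model.onnx",  # hypothetical local path
    providers=["CUDAExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"],   # same CPU fallback as the hunk above
)
input_name = sess.get_inputs()[0].name
dummy = np.zeros((1, 448, 448, 3), dtype=np.float32)  # one blank image
probs = sess.run(None, {input_name: dummy})[0]        # (1, num_tags) scores
```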
@@ -124,8 +165,14 @@ def main(args):
     def run_batch(path_imgs):
         imgs = np.array([im for _, im in path_imgs])

-        probs = model(imgs, training=False)
-        probs = probs.numpy()
+        if args.onnx:
+            if len(imgs) < args.batch_size:
+                imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
+            probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
+            probs = probs[: len(path_imgs)]
+        else:
+            probs = model(imgs, training=False)
+            probs = probs.numpy()

         for (image_path, _), prob in zip(path_imgs, probs):
             # the first 4 entries are rating tags, so ignore them
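The padding branch above exists because an ONNX graph exported with a fixed batch dimension rejects a smaller final batch. The pad-then-truncate pattern in isolation (batch size 4 is a hypothetical value):

```python
import numpy as np

IMAGE_SIZE, batch_size = 448, 4                  # 448 from the constants above; 4 is hypothetical
imgs = np.zeros((3, IMAGE_SIZE, IMAGE_SIZE, 3))  # a short final batch of 3 images
pad = np.zeros((batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))
padded = np.concatenate([imgs, pad], axis=0)     # shape (4, 448, 448, 3)
# after inference, keep only the rows that correspond to real images:
# probs = probs[: len(imgs)]
```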
@@ -165,9 +212,27 @@ def main(args):
             if len(character_tag_text) > 0:
                 character_tag_text = character_tag_text[2:]

+            caption_file = os.path.splitext(image_path)[0] + args.caption_extension
+
             tag_text = ", ".join(combined_tags)

-            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
+            if args.append_tags:
+                # Check if file exists
+                if os.path.exists(caption_file):
+                    with open(caption_file, "rt", encoding="utf-8") as f:
+                        # Read file and remove new lines
+                        existing_content = f.read().strip("\n")  # Remove newlines
+
+                    # Split the content into tags and store them in a list
+                    existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
+
+                    # Check and remove repeating tags in tag_text
+                    new_tags = [tag for tag in combined_tags if tag not in existing_tags]
+
+                    # Create new tag_text
+                    tag_text = ", ".join(existing_tags + new_tags)
+
+            with open(caption_file, "wt", encoding="utf-8") as f:
                 f.write(tag_text + "\n")
                 if args.debug:
                     print(f"\n{image_path}:\n  Character tags: {character_tag_text}\n  General tags: {general_tag_text}")
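The append path merges tag lists rather than concatenating text: existing tags keep their order and only unseen tags are appended. The invariant in isolation:

```python
def merge_tags(existing_tags, new_tags):
    """Order-preserving union: existing tags first, then unseen newcomers."""
    seen = set(existing_tags)
    return existing_tags + [t for t in new_tags if t not in seen]

assert merge_tags(["1girl", "solo"], ["solo", "smile"]) == ["1girl", "solo", "smile"]
```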
@@ -283,12 +348,15 @@ def setup_parser() -> argparse.ArgumentParser:
         help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
     )
     parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
+    parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
+    parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")

     return parser


 if __name__ == "__main__":
     parser = setup_parser()

     args = parser.parse_args()

     # restore options that were previously misspelled
@@ -34,7 +34,7 @@ from library.utilities import utilities_tab
 from library.class_sample_images import SampleImages, run_cmd_sample

 from library.custom_logging import setup_logging
-from localization_ext import add_javascript
+from library.localization_ext import add_javascript

 # Set up logging
 log = setup_logging()
@@ -3364,7 +3364,7 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
     parser.add_argument(
-        "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
+        "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
     )
     parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
     parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
@@ -3390,7 +3390,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--max_embeddings_multiples",
         type=int,
         default=None,
-        help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
+        help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
     )
     parser.add_argument(
         "--clip_guidance_scale",
@@ -3449,7 +3449,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--highres_fix_upscaler_args",
         type=str,
         default=None,
-        help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
+        help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
     )
     parser.add_argument(
         "--highres_fix_disable_control_net",
gui.sh
@@ -59,12 +59,32 @@ if [[ "$OSTYPE" == "darwin"* ]]; then
     fi
 else
     if [ "$RUNPOD" = false ]; then
-        REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt"
+        if [[ "$@" == *"--use-ipex"* ]]; then
+            REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt"
+        else
+            REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt"
+        fi
     else
         REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_runpod.txt"
     fi
 fi

+# Set OneAPI if it's not set by the user
+if [[ "$@" == *"--use-ipex"* ]]
+then
+    echo "Setting OneAPI environment"
+    if [ ! -x "$(command -v sycl-ls)" ]
+    then
+        if [[ -z "$ONEAPI_ROOT" ]]
+        then
+            ONEAPI_ROOT=/opt/intel/oneapi
+        fi
+        source $ONEAPI_ROOT/setvars.sh
+    fi
+    export NEOReadDebugKeys=1
+    export ClDeviceGlobalMemSizeAvailablePercent=100
+fi
+
 # Validate the requirements and run the script if successful
 if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then
     python "$SCRIPT_DIR/kohya_gui.py" "$@"
@@ -10,7 +10,7 @@ from library.class_lora_tab import LoRATools

 import os
 from library.custom_logging import setup_logging
-from localization_ext import add_javascript
+from library.localization_ext import add_javascript

 # Set up logging
 log = setup_logging()
@@ -133,6 +133,10 @@ if __name__ == '__main__':
         '--language', type=str, default=None, help='Set custom language'
     )

+    parser.add_argument(
+        '--use-ipex', action='store_true', help='Use IPEX environment'
+    )
+
     args = parser.parse_args()

     UI(
@@ -1,6 +1,6 @@
 import os
 import gradio as gr
-import localization
+import library.localization as localization


 def file_path(fn):
@@ -16,7 +16,7 @@ def js_html_str(language):

 def add_javascript(language):
     if language is None:
-        print('no language')
+        # print('no language')
         return
     jsStr = js_html_str(language)
@@ -131,7 +131,7 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo
 UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]


-# region memory effcient attention
+# region memory efficient attention

 # CrossAttention that uses FlashAttention
 # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
@@ -41,7 +41,7 @@ TIME_EMBED_DIM = 320 * 4

 USE_REENTRANT = True

-# region memory effcient attention
+# region memory efficient attention

 # CrossAttention that uses FlashAttention
 # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
@@ -96,6 +96,7 @@ try:
 except:
     pass

+# JPEG-XL on Linux
 try:
     from jxlpy import JXLImagePlugin

@@ -103,6 +104,14 @@ try:
 except:
     pass

+# JPEG-XL on Windows
+try:
+    import pillow_jxl
+
+    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+    pass
+
 IMAGE_TRANSFORMS = transforms.Compose(
     [
         transforms.ToTensor(),
@@ -4658,7 +4667,7 @@ class ImageLoadingDataset(torch.utils.data.Dataset):


 # for collate_fn; epoch and step are multiprocessing.Value
-class collater_class:
+class collator_class:
     def __init__(self, epoch, step, dataset):
         self.current_epoch = epoch
         self.current_step = step
@@ -25,6 +25,9 @@ def caption_images(
     frequency_tags,
     prefix,
     postfix,
+    onnx,
+    append_tags,
+    force_download
 ):
     # Check for images_dir_input
     if train_data_dir == '':

@@ -54,6 +57,12 @@ def caption_images(
         run_cmd += f' --remove_underscore'
     if frequency_tags:
         run_cmd += f' --frequency_tags'
+    if onnx:
+        run_cmd += f' --onnx'
+    if append_tags:
+        run_cmd += f' --append_tags'
+    if force_download:
+        run_cmd += f' --force_download'

     if not undesired_tags == '':
         run_cmd += f' --undesired_tags="{undesired_tags}"'
@@ -132,6 +141,20 @@ def gradio_wd14_caption_gui_tab(headless=False):
                 interactive=True,
             )

+            with gr.Row():
+                onnx = gr.Checkbox(
+                    label='Use onnx',
+                    value=False,
+                    interactive=True,
+                    info="https://github.com/onnx/onnx"
+                )
+                append_tags = gr.Checkbox(
+                    label='Append TAGs',
+                    value=False,
+                    interactive=True,
+                    info="This option appends the tags to the existing tags, instead of replacing them."
+                )
+
         with gr.Row():
             replace_underscores = gr.Checkbox(
                 label='Replace underscores in filenames with spaces',
@@ -168,6 +191,12 @@ def gradio_wd14_caption_gui_tab(headless=False):
             ],
             value='SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
         )

+        force_download = gr.Checkbox(
+            label='Force model re-download',
+            value=False,
+            info='Useful to force a model re-download when switching to onnx',
+        )
+
         general_threshold = gr.Slider(
             value=0.35,
@@ -215,6 +244,9 @@ def gradio_wd14_caption_gui_tab(headless=False):
             frequency_tags,
             prefix,
             postfix,
+            onnx,
+            append_tags,
+            force_download
         ],
         show_progress=False,
     )
@@ -41,7 +41,7 @@ from library.dreambooth_folder_creation_gui import (
 from library.dataset_balancing_gui import gradio_dataset_balancing_tab

 from library.custom_logging import setup_logging
-from localization_ext import add_javascript
+from library.localization_ext import add_javascript

 # Set up logging
 log = setup_logging()
@@ -735,7 +735,7 @@ def train_model(
             )
             return
         run_cmd += f' --network_module=lycoris.kohya'
-        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
+        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=locon"'

     if LoRA_type == 'LyCORIS/LoHa':
         try:
networks/oft.py (new file)
@@ -0,0 +1,430 @@
# OFT network module

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re


RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


class OFTModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
    ):
        """
        dim -> num blocks
        alpha -> constraint
        """
        super().__init__()
        self.oft_name = oft_name

        self.num_blocks = dim

        if "Linear" in org_module.__class__.__name__:
            out_dim = org_module.out_features
        elif "Conv" in org_module.__class__.__name__:
            out_dim = org_module.out_channels

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        self.constraint = alpha * out_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        self.block_size = out_dim // self.num_blocks
        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))

        self.out_dim = out_dim
        self.shape = org_module.weight.shape

        self.multiplier = multiplier
        self.org_module = [org_module]  # put in a list so it does not become a submodule

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(block_Q.flatten())
        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())

        block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
        R = torch.block_diag(*block_R_weighted)

        return R

    def forward(self, x, scale=None):
        x = self.org_forward(x)
        if self.multiplier == 0.0:
            return x

        R = self.get_weight().to(x.device, dtype=x.dtype)
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1)
            x = torch.matmul(x, R)
            x = x.permute(0, 3, 1, 2)
        else:
            x = torch.matmul(x, R)
        return x


class OFTInfModule(OFTModule):
    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(oft_name, org_module, multiplier, dim, alpha)
        self.enabled = True
        self.network: OFTNetwork = None

    def set_network(self, network):
        self.network = network

    def forward(self, x, scale=None):
        if not self.enabled:
            return self.org_forward(x)
        return super().forward(x, scale)

    def merge_to(self, multiplier=None, sign=1):
        R = self.get_weight(multiplier) * sign

        # get org weight
        org_sd = self.org_module[0].state_dict()
        org_weight = org_sd["weight"]
        R = R.to(org_weight.device, dtype=org_weight.dtype)

        if org_weight.dim() == 4:
            weight = torch.einsum("oihw, op -> pihw", org_weight, R)
        else:
            weight = torch.einsum("oi, op -> pi", org_weight, R)

        # set weight to org_module
        org_sd["weight"] = weight
        self.org_module[0].load_state_dict(org_sd)


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    enable_all_linear = kwargs.get("enable_all_linear", None)
    enable_conv = kwargs.get("enable_conv", None)
    if enable_all_linear is not None:
        enable_all_linear = bool(enable_all_linear)
    if enable_conv is not None:
        enable_conv = bool(enable_conv)

    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=network_dim,
        alpha=network_alpha,
        enable_all_linear=enable_all_linear,
        enable_conv=enable_conv,
        varbose=True,
    )
    return network


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # check dim, alpha and if weights have for conv2d
    dim = None
    alpha = None
    has_conv2d = None
    all_linear = None
    for name, param in weights_sd.items():
        if name.endswith(".alpha"):
            if alpha is None:
                alpha = param.item()
        else:
            if dim is None:
                dim = param.size()[0]
            if has_conv2d is None and param.dim() == 4:
                has_conv2d = True
            if all_linear is None:
                if param.dim() == 3 and "attn" not in name:
                    all_linear = True
        if dim is not None and alpha is not None and has_conv2d is not None:
            break
    if has_conv2d is None:
        has_conv2d = False
    if all_linear is None:
        all_linear = False

    module_class = OFTInfModule if for_inference else OFTModule
    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=dim,
        alpha=alpha,
        enable_all_linear=all_linear,
        enable_conv=has_conv2d,
        module_class=module_class,
    )
    return network, weights_sd


class OFTNetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
    UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    OFT_PREFIX_UNET = "oft_unet"  # probably better not to change this

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        dim: int = 4,
        alpha: float = 1,
        enable_all_linear: Optional[bool] = False,
        enable_conv: Optional[bool] = False,
        module_class: Type[object] = OFTModule,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier

        self.dim = dim
        self.alpha = alpha

        print(
            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
        )

        # create module instances
        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[torch.nn.Module],
        ) -> List[OFTModule]:
            prefix = self.OFT_PREFIX_UNET
            ofts = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = "Linear" in child_module.__class__.__name__
                        is_conv2d = "Conv2d" in child_module.__class__.__name__
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
                            oft_name = prefix + "." + name + "." + child_name
                            oft_name = oft_name.replace(".", "_")
                            # print(oft_name)

                            oft = module_class(
                                oft_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                            )
                            ofts.append(oft)
            return ofts

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        if enable_all_linear:
            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
        else:
            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
        if enable_conv:
            target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")

        # assertion
        names = set()
        for oft in self.unet_ofts:
            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
            names.add(oft.oft_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for oft in self.unet_ofts:
            oft.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        assert apply_unet, "apply_unet must be True"

        for oft in self.unet_ofts:
            oft.apply_to()
            self.add_module(oft.oft_name, oft)

    # return whether the weights can be merged
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        print("enable OFT for U-Net")

        for oft in self.unet_ofts:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(oft.oft_name):
                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
            oft.load_state_dict(sd_for_lora, False)
            oft.merge_to()

        print(f"weights are merged")

    # it might be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(ofts):
            params = []
            for oft in ofts:
                params.extend(oft.parameters())

            # print num of params
            num_params = 0
            for p in params:
                num_params += p.numel()
            print(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
        if unet_lr is not None:
            param_data["lr"] = unet_lr
        all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
            org_module._lora_restored = True

    def restore_weights(self):
        # restore the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # perform pre-calculation (merge ahead of time)
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            oft.merge_to()
            # sd = org_module.state_dict()
            # org_weight = sd["weight"]
            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            # sd["weight"] = org_weight + lora_weight
            # assert sd["weight"].shape == org_weight.shape
            # org_module.load_state_dict(sd)

            org_module._lora_restored = False
            oft.enabled = False
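A note on the math in `get_weight` above (my notation, restating what the code computes): the learned block `B` is antisymmetrized, its norm is clamped to the constraint, and the Cayley transform turns it into an orthogonal matrix; the multiplier then interpolates between that rotation and the identity, with the full transform assembled block-diagonally.

```latex
Q = B - B^{\top}, \qquad Q^{\top} = -Q
R = (I + Q)(I - Q)^{-1}, \qquad R^{\top} R = I \quad \text{(Cayley transform of skew-symmetric } Q\text{)}
R_{\mathrm{eff}} = m\,R + (1 - m)\,I, \qquad
R_{\mathrm{full}} = \operatorname{blockdiag}\!\big(R_{\mathrm{eff}}^{(1)}, \dots, R_{\mathrm{eff}}^{(\mathrm{num\_blocks})}\big)
```

Note that the interpolated matrix is exactly orthogonal only at m = 0 or m = 1; intermediate multipliers blend the learned rotation with the identity, which is what lets `set_multiplier` scale the effect smoothly.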
@@ -13,13 +13,20 @@ huggingface-hub==0.15.1
 # for loading Diffusers' SDXL
 invisible-watermark==0.2.0
 lion-pytorch==0.0.6
-lycoris_lora==1.8.3
+lycoris_lora==1.9.0
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
 # fairscale==0.4.13
-# for WD14 captioning
-# tensorflow==2.10.1
+# for WD14 captioning (tensorflow)
+# tensorflow==2.14.0
+# for WD14 captioning (onnx)
+onnx==1.14.1
+onnxruntime-gpu==1.16.0
+# onnxruntime==1.16.0
+# this is for onnx:
+# tensorboard==2.14.1
+protobuf==3.20.3
 # open clip for SDXL
 open-clip-torch==2.20.0
 opencv-python==4.7.0.68
@@ -1,4 +1,4 @@
 torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this at the verification stage
 xformers==0.0.21 bitsandbytes==0.41.1
-tensorboard==2.12.3 tensorflow==2.12.0
+tensorboard==2.14.1 tensorflow==2.14.0
 -r requirements.txt
@@ -1,5 +1,5 @@
 xformers==0.0.20
 bitsandbytes==0.41.1
 accelerate==0.19.0
-tensorboard==2.12.1
-tensorflow==2.12.0
+tensorboard==2.14.1
+tensorflow==2.14.0
requirements_linux_ipex.txt (new file)
@@ -0,0 +1,3 @@
torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
tensorboard==2.14.1 tensorflow==2.14.0 intel-extension-for-tensorflow[gpu]
-r requirements.txt
@@ -1,4 +1,4 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 xformers bitsandbytes==0.41.1
-tensorflow-macos tensorboard==2.12.1
+tensorflow-macos tensorboard==2.14.1
 -r requirements.txt
@@ -1,4 +1,4 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 xformers bitsandbytes==0.41.1
-tensorflow-macos tensorflow-metal tensorboard==2.12.1
+tensorflow-macos tensorflow-metal tensorboard==2.14.1
 -r requirements.txt
@@ -1,5 +1,5 @@
 torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this at the verification stage
 xformers==0.0.21 bitsandbytes==0.41.1
-tensorboard==2.12.3 tensorflow==2.12.0 wheel
+tensorboard==2.14.1 tensorflow==2.14.0 wheel
 tensorrt
 -r requirements.txt
@@ -2,5 +2,5 @@ torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorc
 xformers==0.0.21
 bitsandbytes==0.35.0 # no_verify
 # https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl # no_verify
-tensorboard==2.12.3 tensorflow==2.12.0
+tensorboard==2.14.1 tensorflow==2.14.0
 -r requirements.txt
@@ -2612,7 +2612,7 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
     parser.add_argument(
-        "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
+        "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
     )
     parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
     parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")

@@ -2631,7 +2631,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--max_embeddings_multiples",
         type=int,
         default=None,
-        help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
+        help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
     )
     parser.add_argument(
         "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像"

@@ -2666,7 +2666,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--highres_fix_upscaler_args",
         type=str,
         default=None,
-        help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
+        help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
     )
     parser.add_argument(
         "--highres_fix_disable_control_net",
@@ -101,7 +101,7 @@ if __name__ == "__main__":
         type=str,
         nargs="*",
         default=[],
-        help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
+        help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
     )
     parser.add_argument("--interactive", action="store_true")
     args = parser.parse_args()
@@ -172,8 +172,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     train_dataset_group.verify_bucket_reso_steps(32)

@@ -348,7 +348,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -106,8 +106,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     train_dataset_group.verify_bucket_reso_steps(32)

@@ -245,7 +245,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -102,8 +102,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     train_dataset_group.verify_bucket_reso_steps(32)

@@ -213,7 +213,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
    )
setup.sh
@@ -27,6 +27,7 @@ Options:
   -s, --skip-space-check    Skip the 10Gb minimum storage space check.
   -u, --no-gui              Skips launching the GUI.
   -v, --verbose             Increase verbosity levels up to 3.
+  --use-ipex                Use IPEX with Intel ARC GPUs.
 EOF
 }

@@ -87,6 +88,7 @@ MAXVERBOSITY=6
 DIR=""
 PARENT_DIR=""
 VENV_DIR=""
+USE_IPEX=false

 # Function to get the distro name
 get_distro_name() {

@@ -203,6 +205,8 @@ install_python_dependencies() {
     "lin"*)
         if [ "$RUNPOD" = true ]; then
             python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
+        elif [ "$USE_IPEX" = true ]; then
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
         else
             python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
         fi

@@ -318,6 +322,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
     s | skip-space-check) SKIP_SPACE_CHECK=true ;;
     u | no-gui) SKIP_GUI=true ;;
     v) ((VERBOSITY = VERBOSITY + 1)) ;;
+    use-ipex) USE_IPEX=true ;;
     h) display_help && exit 0 ;;
     *) display_help && exit 0 ;;
     esac
@@ -195,12 +195,24 @@ def check_torch():
         '/opt/rocm/bin/rocminfo'
     ):
         log.info('AMD toolkit detected')
+    elif (shutil.which('sycl-ls') is not None
+    or os.environ.get('ONEAPI_ROOT') is not None
+    or os.path.exists('/opt/intel/oneapi')):
+        log.info('Intel OneAPI toolkit detected')
     else:
         log.info('Using CPU-only Torch')

     try:
         import torch
+
+        try:
+            import intel_extension_for_pytorch as ipex
+            if torch.xpu.is_available():
+                from library.ipex import ipex_init
+                ipex_init()
+                os.environ.setdefault('NEOReadDebugKeys', '1')
+                os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
+        except Exception:
+            pass
         log.info(f'Torch {torch.__version__}')

         # Check if CUDA is available

@@ -208,10 +220,14 @@ def check_torch():
             log.warning('Torch reports CUDA not available')
         else:
             if torch.version.cuda:
-                # Log nVidia CUDA and cuDNN versions
-                log.info(
-                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
-                )
+                if hasattr(torch, "xpu") and torch.xpu.is_available():
+                    # Log Intel IPEX OneAPI version
+                    log.info(f'Torch backend: Intel IPEX OneAPI {ipex.__version__}')
+                else:
+                    # Log nVidia CUDA and cuDNN versions
+                    log.info(
+                        f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
+                    )
             elif torch.version.hip:
                 # Log AMD ROCm HIP version
                 log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')

@@ -222,9 +238,14 @@ def check_torch():
             for device in [
                 torch.cuda.device(i) for i in range(torch.cuda.device_count())
             ]:
-                log.info(
-                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
-                )
+                if hasattr(torch, "xpu") and torch.xpu.is_available():
+                    log.info(
+                        f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
+                    )
+                else:
+                    log.info(
+                        f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
+                    )
         return int(torch.__version__[0])
     except Exception as e:
         # log.warning(f'Could not load torch: {e}')
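One detail worth knowing when touching this function: the return value is the Torch major version parsed from the first character of the version string, which works for `1.x`/`2.x` releases but would break at a hypothetical `10.x`:

```python
# how check_torch() derives its return value (illustrative):
version = "2.0.1+cu118"        # example torch.__version__
major = int(version[0])        # -> 2; single-character parse
# a more robust parse would be: int(version.split(".")[0])
```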
@@ -98,31 +98,31 @@ def sync_bits_and_bytes_files():
         log.error(f'An unexpected error occurred: {e}')


-def install_kohya_ss_torch1():
-    setup_common.check_repo_version()
-    setup_common.check_python()
+# def install_kohya_ss_torch1():
+#     setup_common.check_repo_version()
+#     setup_common.check_python()

-    # Upgrade pip if needed
-    setup_common.install('--upgrade pip')
+#     # Upgrade pip if needed
+#     setup_common.install('--upgrade pip')

-    if setup_common.check_torch() == 2:
-        input(
-            f'{YELLOW}\nTorch 2 is already installed in the venv. To install Torch 1 delete the venv and re-run setup.bat\n\nHit enter to continue...{RESET_COLOR}'
-        )
-        return
+#     if setup_common.check_torch() == 2:
+#         input(
+#             f'{YELLOW}\nTorch 2 is already installed in the venv. To install Torch 1 delete the venv and re-run setup.bat\n\nHit enter to continue...{RESET_COLOR}'
+#         )
+#         return

-    # setup_common.install(
-    #     'torch==1.12.1+cu116 torchvision==0.13.1+cu116 --index-url https://download.pytorch.org/whl/cu116',
-    #     'torch torchvision'
-    # )
-    # setup_common.install(
-    #     'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl -U -I --no-deps',
-    #     'xformers-0.0.14'
-    # )
-    setup_common.install_requirements('requirements_windows_torch1.txt', check_no_verify_flag=False)
-    sync_bits_and_bytes_files()
-    setup_common.configure_accelerate(run_accelerate=True)
-    # run_cmd(f'accelerate config')
+#     # setup_common.install(
+#     #     'torch==1.12.1+cu116 torchvision==0.13.1+cu116 --index-url https://download.pytorch.org/whl/cu116',
+#     #     'torch torchvision'
+#     # )
+#     # setup_common.install(
+#     #     'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl -U -I --no-deps',
+#     #     'xformers-0.0.14'
+#     # )
+#     setup_common.install_requirements('requirements_windows_torch1.txt', check_no_verify_flag=False)
+#     sync_bits_and_bytes_files()
+#     setup_common.configure_accelerate(run_accelerate=True)
+#     # run_cmd(f'accelerate config')


 def install_kohya_ss_torch2():

@@ -132,11 +132,11 @@ def install_kohya_ss_torch2():
     # Upgrade pip if needed
     setup_common.install('--upgrade pip')

-    if setup_common.check_torch() == 1:
-        input(
-            f'{YELLOW}\nTorch 1 is already installed in the venv. To install Torch 2 delete the venv and re-run setup.bat\n\nHit any key to acknowledge.{RESET_COLOR}'
-        )
-        return
+    # if setup_common.check_torch() == 1:
+    #     input(
+    #         f'{YELLOW}\nTorch 1 is already installed in the venv. To install Torch 2 delete the venv and re-run setup.bat\n\nHit any key to acknowledge.{RESET_COLOR}'
+    #     )
+    #     return

     # setup_common.install(
     #     'torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118',

@@ -177,23 +177,24 @@ def main_menu():
         print('')

         if choice == '1':
-            while True:
-                print('1. Torch 1 (legacy, no longer supported. Will be removed in v21.9.x)')
-                print('2. Torch 2 (recommended)')
-                print('3. Cancel')
-                choice_torch = input('\nEnter your choice: ')
-                print('')
+            install_kohya_ss_torch2()
+            # while True:
+            #     print('1. Torch 1 (legacy, no longer supported. Will be removed in v21.9.x)')
+            #     print('2. Torch 2 (recommended)')
+            #     print('3. Cancel')
+            #     choice_torch = input('\nEnter your choice: ')
+            #     print('')

-                if choice_torch == '1':
-                    install_kohya_ss_torch1()
-                    break
-                elif choice_torch == '2':
-                    install_kohya_ss_torch2()
-                    break
-                elif choice_torch == '3':
-                    break
-                else:
-                    print('Invalid choice. Please enter a number between 1-3.')
+            # if choice_torch == '1':
+            #     install_kohya_ss_torch1()
+            #     break
+            # elif choice_torch == '2':
+            #     install_kohya_ss_torch2()
+            #     break
+            # elif choice_torch == '3':
+            #     break
+            # else:
+            #     print('Invalid choice. Please enter a number between 1-3.')
         elif choice == '2':
             cudann_install()
         elif choice == '3':
@@ -35,12 +35,22 @@ def check_torch():
         '/opt/rocm/bin/rocminfo'
     ):
         log.info('AMD toolkit detected')
+    elif (shutil.which('sycl-ls') is not None
+    or os.environ.get('ONEAPI_ROOT') is not None
+    or os.path.exists('/opt/intel/oneapi')):
+        log.info('Intel OneAPI toolkit detected')
     else:
         log.info('Using CPU-only Torch')

     try:
         import torch
+        try:
+            import intel_extension_for_pytorch as ipex
+            if torch.xpu.is_available():
+                from library.ipex import ipex_init
+                ipex_init()
+        except Exception:
+            pass
         log.info(f'Torch {torch.__version__}')

         # Check if CUDA is available
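The new `elif` branch probes for an Intel OneAPI installation three ways: a SYCL tool on `PATH`, the `ONEAPI_ROOT` environment variable, or the default install prefix. Extracted into a standalone sketch (same calls as the hunk above, wrapped in a hypothetical helper):

```python
import os
import shutil

def has_oneapi_toolkit() -> bool:
    # Same three probes as the hunk above: sycl-ls on PATH, the
    # ONEAPI_ROOT env var, or the default Linux install prefix.
    return (
        shutil.which('sycl-ls') is not None
        or os.environ.get('ONEAPI_ROOT') is not None
        or os.path.exists('/opt/intel/oneapi')
    )

print('Intel OneAPI toolkit detected' if has_oneapi_toolkit() else 'no OneAPI')
```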
@@ -48,10 +58,14 @@ def check_torch():
             log.warning('Torch reports CUDA not available')
         else:
             if torch.version.cuda:
-                # Log nVidia CUDA and cuDNN versions
-                log.info(
-                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
-                )
+                if hasattr(torch, "xpu") and torch.xpu.is_available():
+                    # Log Intel IPEX OneAPI version
+                    log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
+                else:
+                    # Log nVidia CUDA and cuDNN versions
+                    log.info(
+                        f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
+                    )
             elif torch.version.hip:
                 # Log AMD ROCm HIP version
                 log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
@@ -62,9 +76,14 @@ def check_torch():
             for device in [
                 torch.cuda.device(i) for i in range(torch.cuda.device_count())
             ]:
-                log.info(
-                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
-                )
+                if hasattr(torch, "xpu") and torch.xpu.is_available():
+                    log.info(
+                        f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
+                    )
+                else:
+                    log.info(
+                        f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
+                    )
         return int(torch.__version__[0])
     except Exception as e:
         log.error(f'Could not load torch: {e}')
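With the hunks above, `check_torch()` reports an Intel XPU backend ahead of CUDA/ROCm whenever IPEX is importable and `torch.xpu.is_available()` is true. A reduced sketch of that reporting order, assuming `torch` is installed (the `ipex` import only succeeds on Intel setups):

```python
import torch

# Reporting order mirroring the hunks above: XPU first, then CUDA,
# then ROCm HIP, then CPU-only.
if hasattr(torch, 'xpu') and torch.xpu.is_available():
    import intel_extension_for_pytorch as ipex
    print(f'Torch backend: Intel IPEX {ipex.__version__}')
elif torch.cuda.is_available() and torch.version.cuda:
    print(f'Torch backend: nVidia CUDA {torch.version.cuda}')
elif torch.cuda.is_available() and torch.version.hip:
    print(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
else:
    print('Torch backend: CPU only')
```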
@@ -91,10 +110,7 @@ def main():
     if args.requirements:
         setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
     else:
-        if torch_ver == 1:
-            setup_common.install_requirements('requirements_windows_torch1.txt', check_no_verify_flag=True)
-        else:
-            setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=True)
+        setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=True)


 if __name__ == '__main__':
@@ -45,7 +45,7 @@ from library.utilities import utilities_tab
 from library.class_sample_images import SampleImages, run_cmd_sample

 from library.custom_logging import setup_logging
-from localization_ext import add_javascript
+from library.localization_ext import add_javascript

 # Set up logging
 log = setup_logging()
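The one-line import fix above matters because the GUI is launched from the repository root, so modules under `library/` must be imported through the package; a bare `localization_ext` only resolves if its directory happens to be on `sys.path`. A runnable sketch of checking which spelling resolves (the helper is illustrative):

```python
import importlib.util

def resolvable(name: str) -> bool:
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:  # parent package itself is missing
        return False

# From the kohya_ss repo root, only the package-qualified name resolves.
print(resolvable('localization_ext'))
print(resolvable('library.localization_ext'))
```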
@@ -86,8 +86,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     # prepare the accelerator
     print("prepare accelerator")
@@ -120,7 +120,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
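The `collater` → `collator` rename repeated across the hunks above and below is a pure spelling fix; the object itself is the picklable callable every trainer hands to `DataLoader` as `collate_fn`. A minimal sketch of that pattern (class name and internals are illustrative, not the actual `train_util.collator_class`):

```python
from multiprocessing import Value

class ExampleCollator:
    # Illustrative stand-in for train_util.collator_class: a picklable
    # callable object, so DataLoader worker processes can invoke it.
    def __init__(self, current_epoch, current_step, dataset=None):
        self.current_epoch = current_epoch  # shared Value("i", ...) counters
        self.current_step = current_step
        self.dataset = dataset              # only set when num_workers == 0

    def __call__(self, examples):
        # With batch_size=1 and a dataset that already returns full
        # batches, collation reduces to unwrapping the single element.
        return examples[0]

collator = ExampleCollator(Value("i", 0), Value("i", 0))
# Usage mirroring the diffs:
# loader = torch.utils.data.DataLoader(train_dataset_group, batch_size=1,
#                                      shuffle=True, collate_fn=collator)
```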
@@ -91,8 +91,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     # prepare the accelerator
     print("prepare accelerator")
@@ -125,7 +125,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -98,8 +98,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
@@ -245,7 +245,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -78,8 +78,8 @@ def train(args):

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     if args.no_token_padding:
         train_dataset_group.disable_token_padding()
@@ -177,7 +177,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -192,8 +192,8 @@ class NetworkTrainer:

         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
-        ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-        collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+        ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+        collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

         if args.debug_dataset:
             train_util.debug_dataset(train_dataset_group)
@@ -283,7 +283,10 @@ class NetworkTrainer:
         if args.dim_from_weights:
             network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
         else:
-            # LyCORIS will work with this...
+            if "dropout" not in net_kwargs:
+                # workaround for LyCORIS (;^ω^)
+                net_kwargs["dropout"] = args.network_dropout
+
             network = network_module.create_network(
                 1.0,
                 args.network_dim,
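The three added lines above change how `--network_dropout` reaches the network module: the trainer now injects it into `net_kwargs` unless the user already passed a `dropout` key through `--network_args`, so the explicit per-network value keeps precedence. A reduced sketch of that precedence rule (values are illustrative):

```python
# Reduced sketch of the precedence rule above; values are illustrative.
net_kwargs = {"conv_dim": "4"}   # parsed from --network_args
network_dropout = 0.1            # from --network_dropout

if "dropout" not in net_kwargs:
    net_kwargs["dropout"] = network_dropout  # fall back to the global flag

print(net_kwargs)  # {'conv_dim': '4', 'dropout': 0.1}
```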
@@ -342,7 +345,7 @@ class NetworkTrainer:
             train_dataset_group,
             batch_size=1,
             shuffle=True,
-            collate_fn=collater,
+            collate_fn=collator,
             num_workers=n_workers,
             persistent_workers=args.persistent_data_loader_workers,
         )
@@ -954,7 +957,7 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
     )
     parser.add_argument(
-        "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
+        "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
     )
     parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
     parser.add_argument(
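For context on the flag whose help text is corrected above: `nargs="*"` collects each `key=value` token into a list, which the trainer later splits into the `net_kwargs` dict handed to the network module. A self-contained approximation of that round trip (the split-into-dict step mirrors how the trainer consumes the flag):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--network_args", type=str, default=None, nargs="*",
    help="additional arguments for network (key=value)",
)
args = parser.parse_args(["--network_args", "conv_dim=4", "dropout=0.1"])

# Each "key=value" token becomes one net_kwargs entry (values stay strings).
net_kwargs = dict(token.split("=", 1) for token in args.network_args)
print(net_kwargs)  # {'conv_dim': '4', 'dropout': '0.1'}
```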
@@ -312,8 +312,8 @@ class TextualInversionTrainer:

         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
-        ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-        collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+        ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+        collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

         # make captions: a very crude implementation that rewrites each caption into the string "tokenstring tokenstring1 tokenstring2 ...tokenstringn"
         if use_template:
@@ -389,7 +389,7 @@ class TextualInversionTrainer:
             train_dataset_group,
             batch_size=1,
             shuffle=True,
-            collate_fn=collater,
+            collate_fn=collator,
             num_workers=n_workers,
             persistent_workers=args.persistent_data_loader_workers,
         )
@@ -236,8 +236,8 @@ def train(args):
     train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

     # make captions: a very crude implementation that rewrites each caption into the string "tokenstring tokenstring1 tokenstring2 ...tokenstringn"
     if use_template:
@@ -309,7 +309,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collater,
+        collate_fn=collator,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )