mirror of https://github.com/bmaltais/kohya_ss
commit eec7d05e1c

Dockerfile

@@ -46,4 +46,4 @@ COPY --chown=appuser . .
 STOPSIGNAL SIGINT
 ENV LD_PRELOAD=libtcmalloc.so
 ENV PATH="$PATH:/home/appuser/.local/bin"
-CMD python3 "./kohya_gui.py" ${CLI_ARGS} --listen 0.0.0.0 --server_port 7860 --headless
+CMD python3 "./kohya_gui.py" ${CLI_ARGS} --listen 0.0.0.0 --server_port 7860

README.md

@@ -98,6 +98,14 @@ The following limitations apply:

 If you run on Linux, there is an alternative docker container port with fewer limitations. You can find the project [here](https://github.com/P2Enjoy/kohya_ss-docker).

+### Linux and macOS
+
+#### Linux pre-requirements
+
+venv support must be pre-installed. On Ubuntu 22.04 it can be installed with `apt install python3.10-venv`.
+
+Make sure to use a Python version >= 3.10.6 and < 3.11.0.
+
 #### Setup

 In the terminal, run

 ```bash

@@ -345,6 +353,26 @@ This will store a backup file with your current locally installed pip packages a

 ## Change History

+* 2023/06/14 (v21.7.8)
+- Add tkinter to the dockerised version (thanks to @burdokow)
+- Add an option to create caption files from folder names to the `group_images.py` tool.
+- The Prodigy optimizer is supported in each training script. It is a member of the D-Adaptation family of methods and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
+  - Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`.
+- Arbitrary datasets are supported in each training script (except XTI). You can use one by defining a Dataset class that returns images and captions.
+  - Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script.
+  - Please refer to `MinimalDataset` for implementation details. A sample will be prepared later.
+- The following features have been added to the generation script.
+  - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. fix. Please try it if the image is disturbed by some ControlNet such as Canny.
+  - Added variants similar to sd-dynamic-prompts to the prompt.
+    - If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
+    - If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected.
+    - If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `.
+    - You can specify the number of candidates as a range like `0-2`. You cannot omit one side, like `-2` or `1-`.
+    - The variants can also be used in the prompt options.
+    - If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). This may be useful for creating X/Y plots.
+    - You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3.
+    - There is no weighting function.
+  - Add prefix and postfix options to the wd14 captioning tool.
 * 2023/06/12 (v21.7.7)
 - Add `Print only` button to all training tabs
 - Sort json file vars for easier visual search
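The `{2$$chocolate|vanilla|strawberry}` form described above boils down to random sampling without replacement plus a join. A minimal standalone sketch (the helper name is assumed, not part of the repository):

```python
import random

def pick_variants(options, count, sep=", "):
    # "{2$$chocolate|vanilla|strawberry}" style: choose `count` distinct
    # options at random and join them with `sep`.
    return sep.join(random.sample(options, count))

flavors = "chocolate|vanilla|strawberry".split("|")
picked = pick_variants(flavors, 2, " and ")
```

With the `{1-2$$ and $$...}` form, `count` itself would first be drawn from the given range before sampling.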

docker-compose.yaml

@@ -12,6 +12,7 @@ services:
     environment:
       CLI_ARGS: ""
       SAFETENSORS_FAST_GPU: 1
+      DISPLAY: $DISPLAY
     tmpfs:
       - /tmp
     volumes:
@@ -21,6 +22,7 @@ services:
       - ./.cache/config:/app/appuser/.config
       - ./.cache/nv:/home/appuser/.nv
       - ./.cache/keras:/home/appuser/.keras
+      - /tmp/.X11-unix:/tmp/.X11-unix
     deploy:
       resources:
         reservations:

docs/train_README-ja.md

@@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking back
 - DAdaptAdanIP : arguments same as above
 - DAdaptLion : arguments same as above
 - DAdaptSGD : arguments same as above
+- Prodigy : https://github.com/konstmish/prodigy
 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
 - any optimizer

docs/train_README-zh.md

@@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking back
 - DAdaptAdam : parameters same as above
 - DAdaptAdaGrad : parameters same as above
 - DAdaptAdan : parameters same as above
-- DAdaptAdanIP : 引数は同上
+- DAdaptAdanIP : 参数同上
 - DAdaptLion : parameters same as above
 - DAdaptSGD : parameters same as above
+- Prodigy : https://github.com/konstmish/prodigy
 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
 - any optimizer

fine_tune.py

@@ -42,33 +42,37 @@ def train(args):

     tokenizer = train_util.load_tokenizer(args)

-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "in_json"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
-                )
-            )
-    else:
-        user_config = {
-            "datasets": [
-                {
-                    "subsets": [
-                        {
-                            "image_dir": args.train_data_dir,
-                            "metadata_file": args.in_json,
-                        }
-                    ]
-                }
-            ]
-        }
-
-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            user_config = {
+                "datasets": [
+                    {
+                        "subsets": [
+                            {
+                                "image_dir": args.train_data_dir,
+                                "metadata_file": args.in_json,
+                            }
+                        ]
+                    }
+                ]
+            }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)

@@ -393,7 +397,7 @@ def train(args):
                 current_loss = loss.detach().item()  # this is the mean, so batch size should not matter
                 if args.logging_dir is not None:
                     logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                    if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                    if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                         logs["lr/d*lr"] = (
                             lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                         )
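The widened condition above reduces to a case-insensitive prefix/equality check. A small sketch of the same predicate (function name assumed):

```python
def tracks_d_times_lr(optimizer_type: str) -> bool:
    # D-Adaptation optimizers and Prodigy adapt an internal "d" factor,
    # so the value worth logging is d * lr rather than lr alone.
    t = optimizer_type.lower()
    return t.startswith("dadapt") or t == "prodigy"
```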

finetune/make_captions.py

@@ -3,6 +3,7 @@ import glob
 import os
 import json
 import random
+import sys

 from pathlib import Path
 from PIL import Image

@@ -11,6 +12,7 @@ import numpy as np
 import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder
 import library.train_util as train_util

gen_img_diffusers.py

@@ -46,6 +46,7 @@ VGG(
 )
 """

+import itertools
 import json
 from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
 import glob

@@ -614,11 +615,15 @@ class PipelineLike:

         # ControlNet
         self.control_nets: List[ControlNetInfo] = []
+        self.control_net_enabled = True  # if control_nets is empty, ControlNet does nothing whether this is True or False

     # Textual Inversion
     def add_token_replacement(self, target_token_id, rep_token_ids):
         self.token_replacements[target_token_id] = rep_token_ids

+    def set_enable_control_net(self, en: bool):
+        self.control_net_enabled = en
+
     def replace_token(self, tokens, layer=None):
         new_tokens = []
         for token in tokens:

@@ -1111,7 +1116,7 @@ class PipelineLike:
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
-                if self.control_nets:
+                if self.control_nets and self.control_net_enabled:
                     if reginonal_network:
                         num_sub_and_neg_prompts = len(text_embeddings) // batch_size
                         text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt

@@ -2159,6 +2164,110 @@ def preprocess_mask(mask):
     return mask


+# regular expression for dynamic prompts:
+# starts and ends with "{" and "}"
+# contains at least one variant divided by "|"
+# optional fragments divided by "$$" at the start
+# if the first fragment is "E" or "e", enumerate all variants
+# if the second fragment is a number or two numbers, repeat the variants in that range
+# if the third fragment is a string, use it as a separator
+
+RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
+
+
+def handle_dynamic_prompt_variants(prompt, repeat_count):
+    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
+    if not founds:
+        return [prompt]
+
+    # make a replacement function for each variant
+    enumerating = False
+    replacers = []
+    for found in founds:
+        # if "e$$" is found, enumerate all variants
+        found_enumerating = found.group(2) is not None
+        enumerating = enumerating or found_enumerating
+
+        separator = ", " if found.group(6) is None else found.group(6)
+        variants = found.group(7).split("|")
+
+        # parse count range
+        count_range = found.group(4)
+        if count_range is None:
+            count_range = [1, 1]
+        else:
+            count_range = count_range.split("-")
+            if len(count_range) == 1:
+                count_range = [int(count_range[0]), int(count_range[0])]
+            elif len(count_range) == 2:
+                count_range = [int(count_range[0]), int(count_range[1])]
+            else:
+                print(f"invalid count range: {count_range}")
+                count_range = [1, 1]
+        if count_range[0] > count_range[1]:
+            count_range = [count_range[1], count_range[0]]
+        if count_range[0] < 0:
+            count_range[0] = 0
+        if count_range[1] > len(variants):
+            count_range[1] = len(variants)
+
+        if found_enumerating:
+            # make a function to enumerate all combinations
+            def make_replacer_enum(vari, cr, sep):
+                def replacer():
+                    values = []
+                    for count in range(cr[0], cr[1] + 1):
+                        for comb in itertools.combinations(vari, count):
+                            values.append(sep.join(comb))
+                    return values
+
+                return replacer
+
+            replacers.append(make_replacer_enum(variants, count_range, separator))
+        else:
+            # make a function to choose a random combination
+            def make_replacer_single(vari, cr, sep):
+                def replacer():
+                    count = random.randint(cr[0], cr[1])
+                    comb = random.sample(vari, count)
+                    return [sep.join(comb)]
+
+                return replacer
+
+            replacers.append(make_replacer_single(variants, count_range, separator))
+
+    # build each prompt
+    if not enumerating:
+        # if not enumerating, repeat the prompt and replace each variant randomly
+        prompts = []
+        for _ in range(repeat_count):
+            current = prompt
+            for found, replacer in zip(founds, replacers):
+                current = current.replace(found.group(0), replacer()[0], 1)
+            prompts.append(current)
+    else:
+        # if enumerating, iterate all combinations for previous prompts
+        prompts = [prompt]
+
+        for found, replacer in zip(founds, replacers):
+            if found.group(2) is not None:
+                # make all combinations for existing prompts
+                new_prompts = []
+                for current in prompts:
+                    replacements = replacer()
+                    for replacement in replacements:
+                        new_prompts.append(current.replace(found.group(0), replacement, 1))
+                prompts = new_prompts
+
+        for found, replacer in zip(founds, replacers):
+            # make a random selection for existing prompts
+            if found.group(2) is None:
+                for i in range(len(prompts)):
+                    prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
+
+    return prompts
+
+
+# endregion
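For the `e$$` enumeration form, the number of generated prompts is the product of the variant counts — e.g. `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0}` yields 5*3 = 15 prompts. A simplified, standalone sketch of that expansion (count ranges and custom separators omitted):

```python
import re

# Same variant pattern as RE_DYNAMIC_PROMPT above.
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")

def enumerate_variants(prompt):
    # Expand every variant into all of its single choices; each match in the
    # original prompt multiplies the number of output prompts.
    prompts = [prompt]
    for found in RE_DYNAMIC_PROMPT.finditer(prompt):
        variants = found.group(7).split("|")
        prompts = [p.replace(found.group(0), v, 1) for p in prompts for v in variants]
    return prompts

muls = enumerate_variants("{e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0}")
# 5 choices for the first multiplier times 3 for the second
```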

@@ -2776,6 +2885,7 @@ def main(args):

     # decide the seeds in advance when --seed is specified
     if args.seed is not None:
+        # dynamic prompts can exhaust the seeds → have the user set images_per_prompt large enough
         random.seed(args.seed)
         predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
         if len(predefined_seeds) == 1:

@@ -2827,6 +2937,8 @@ def main(args):
                         ext.num_sub_prompts,
                     )
                     batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
+
+            pipe.set_enable_control_net(True)  # enable ControlNet in the 1st stage
             images_1st = process_batch(batch_1st, True, True)

             # create the 2nd stage batches and process them below

@@ -2870,6 +2982,9 @@ def main(args):
                 batch_2nd.append(bd_2nd)
             batch = batch_2nd

+            if args.highres_fix_disable_control_net:
+                pipe.set_enable_control_net(False)  # when the option is set, disable ControlNet in the 2nd stage
+
             # extract the information of this batch
             (
                 return_latents,

@@ -3058,121 +3173,138 @@ def main(args):
             while not valid:
                 print("\nType prompt:")
                 try:
-                    prompt = input()
+                    raw_prompt = input()
                 except EOFError:
                     break

-                valid = len(prompt.strip().split(" --")[0].strip()) > 0
+                valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
             if not valid:  # EOF, end app
                 break
         else:
-            prompt = prompt_list[prompt_index]
+            raw_prompt = prompt_list[prompt_index]

-        # parse prompt
-        width = args.W
-        height = args.H
-        scale = args.scale
-        negative_scale = args.negative_scale
-        steps = args.steps
-        seeds = None
-        strength = 0.8 if args.strength is None else args.strength
-        negative_prompt = ""
-        clip_prompt = None
-        network_muls = None
+        # sd-dynamic-prompts like variants:
+        # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
+        raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)

-        prompt_args = prompt.strip().split(" --")
-        prompt = prompt_args[0]
-        print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
+        # repeat prompt
+        for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
+            raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]

-        for parg in prompt_args[1:]:
-            try:
-                m = re.match(r"w (\d+)", parg, re.IGNORECASE)
-                if m:
-                    width = int(m.group(1))
-                    print(f"width: {width}")
-                    continue
+            if pi == 0 or len(raw_prompts) > 1:
+                # parse prompt: if the prompt has not changed, skip parsing
+                width = args.W
+                height = args.H
+                scale = args.scale
+                negative_scale = args.negative_scale
+                steps = args.steps
+                seed = None
+                seeds = None
+                strength = 0.8 if args.strength is None else args.strength
+                negative_prompt = ""
+                clip_prompt = None
+                network_muls = None

-                m = re.match(r"h (\d+)", parg, re.IGNORECASE)
-                if m:
-                    height = int(m.group(1))
-                    print(f"height: {height}")
-                    continue
+                prompt_args = raw_prompt.strip().split(" --")
+                prompt = prompt_args[0]
+                print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")

-                m = re.match(r"s (\d+)", parg, re.IGNORECASE)
-                if m:  # steps
-                    steps = max(1, min(1000, int(m.group(1))))
-                    print(f"steps: {steps}")
-                    continue
+                for parg in prompt_args[1:]:
+                    try:
+                        m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+                        if m:
+                            width = int(m.group(1))
+                            print(f"width: {width}")
+                            continue

-                m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
-                if m:  # seed
-                    seeds = [int(d) for d in m.group(1).split(",")]
-                    print(f"seeds: {seeds}")
-                    continue
+                        m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+                        if m:
+                            height = int(m.group(1))
+                            print(f"height: {height}")
+                            continue

-                m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
-                if m:  # scale
-                    scale = float(m.group(1))
-                    print(f"scale: {scale}")
-                    continue
+                        m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+                        if m:  # steps
+                            steps = max(1, min(1000, int(m.group(1))))
+                            print(f"steps: {steps}")
+                            continue

-                m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
-                if m:  # negative scale
-                    if m.group(1).lower() == "none":
-                        negative_scale = None
-                    else:
-                        negative_scale = float(m.group(1))
-                    print(f"negative scale: {negative_scale}")
-                    continue
+                        m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
+                        if m:  # seed
+                            seeds = [int(d) for d in m.group(1).split(",")]
+                            print(f"seeds: {seeds}")
+                            continue

-                m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
-                if m:  # strength
-                    strength = float(m.group(1))
-                    print(f"strength: {strength}")
-                    continue
+                        m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+                        if m:  # scale
+                            scale = float(m.group(1))
+                            print(f"scale: {scale}")
+                            continue

-                m = re.match(r"n (.+)", parg, re.IGNORECASE)
-                if m:  # negative prompt
-                    negative_prompt = m.group(1)
-                    print(f"negative prompt: {negative_prompt}")
-                    continue
+                        m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
+                        if m:  # negative scale
+                            if m.group(1).lower() == "none":
+                                negative_scale = None
+                            else:
+                                negative_scale = float(m.group(1))
+                            print(f"negative scale: {negative_scale}")
+                            continue

-                m = re.match(r"c (.+)", parg, re.IGNORECASE)
-                if m:  # clip prompt
-                    clip_prompt = m.group(1)
-                    print(f"clip prompt: {clip_prompt}")
-                    continue
+                        m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
+                        if m:  # strength
+                            strength = float(m.group(1))
+                            print(f"strength: {strength}")
+                            continue

-                m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
-                if m:  # network multiplies
-                    network_muls = [float(v) for v in m.group(1).split(",")]
-                    while len(network_muls) < len(networks):
-                        network_muls.append(network_muls[-1])
-                    print(f"network mul: {network_muls}")
-                    continue
+                        m = re.match(r"n (.+)", parg, re.IGNORECASE)
+                        if m:  # negative prompt
+                            negative_prompt = m.group(1)
+                            print(f"negative prompt: {negative_prompt}")
+                            continue

-            except ValueError as ex:
-                print(f"Exception in parsing / 解析エラー: {parg}")
-                print(ex)
+                        m = re.match(r"c (.+)", parg, re.IGNORECASE)
+                        if m:  # clip prompt
+                            clip_prompt = m.group(1)
+                            print(f"clip prompt: {clip_prompt}")
+                            continue

-        if seeds is not None:
-            # repeat the seeds if there are not enough
-            if len(seeds) < args.images_per_prompt:
-                seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds)))
-            seeds = seeds[: args.images_per_prompt]
-        else:
-            if predefined_seeds is not None:
-                seeds = predefined_seeds[-args.images_per_prompt :]
-                predefined_seeds = predefined_seeds[: -args.images_per_prompt]
-            elif args.iter_same_seed:
-                seeds = [iter_seed] * args.images_per_prompt
-            else:
-                seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)]
-                if args.interactive:
-                    print(f"seed: {seeds}")
+                        m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
+                        if m:  # network multiplies
+                            network_muls = [float(v) for v in m.group(1).split(",")]
+                            while len(network_muls) < len(networks):
+                                network_muls.append(network_muls[-1])
+                            print(f"network mul: {network_muls}")
+                            continue

-        init_image = mask_image = guide_image = None
-        for seed in seeds:  # as many as images_per_prompt
+                    except ValueError as ex:
+                        print(f"Exception in parsing / 解析エラー: {parg}")
+                        print(ex)
+
+            # prepare seed
+            if seeds is not None:  # given in the prompt
+                # keep the previous seed if the list runs out
+                if len(seeds) > 0:
+                    seed = seeds.pop(0)
+            else:
+                if predefined_seeds is not None:
+                    if len(predefined_seeds) > 0:
+                        seed = predefined_seeds.pop(0)
+                    else:
+                        print("predefined seeds are exhausted")
+                        seed = None
+                elif args.iter_same_seed:
+                    seed = iter_seed
+                else:
+                    seed = None  # clear the previous one
+
+            if seed is None:
+                seed = random.randint(0, 0x7FFFFFFF)
+            if args.interactive:
+                print(f"seed: {seed}")
+
+            # prepare init image, guide image and mask
+            init_image = mask_image = guide_image = None
+
             # when reusing the same image, converting it to a latent once would avoid waste, but for simplicity it is processed every time
             if init_images is not None:
                 init_image = init_images[global_step % len(init_images)]
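The seed precedence implemented above — prompt `--d` seeds first, then `--seed`-derived predefined seeds, then `--iter_same_seed`, otherwise a fresh random seed — can be sketched as a pure function (name assumed):

```python
import random

def next_seed(prompt_seeds, predefined_seeds, iter_same_seed, iter_seed):
    # Precedence mirrored from the hunk above: seeds given in the prompt with
    # "--d", then seeds pre-generated from --seed, then --iter_same_seed,
    # otherwise a fresh random seed.
    if prompt_seeds:
        return prompt_seeds.pop(0)
    if predefined_seeds:
        return predefined_seeds.pop(0)
    if iter_same_seed:
        return iter_seed
    return random.randint(0, 0x7FFFFFFF)
```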
@@ -3454,6 +3586,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
     )
+    parser.add_argument(
+        "--highres_fix_disable_control_net",
+        action="store_true",
+        help="disable ControlNet for highres fix / highres fixでControlNetを使わない",
+    )

     parser.add_argument(
         "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"

library/common_gui.py

@@ -800,6 +800,7 @@ def gradio_training(
             'DAdaptSGD',
             'Lion',
             'Lion8bit',
+            'Prodigy',
             'SGDNesterov',
             'SGDNesterov8bit',
         ],

@@ -1302,8 +1303,10 @@ def verify_image_folder_pattern(folder_path):
     ]

     # Check if all sub-folders match the pattern
-    if len(matching_subfolders) != len(os.listdir(folder_path)):
-        log.error(f"Not all images folders have proper name patterns in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ...")
+    filenames = [filename for filename in os.listdir(folder_path) if not filename.startswith('.')]
+    if len(matching_subfolders) != len(filenames):
+        log.error(f"Not all image folders have proper name patterns <number>_<text> in {folder_path}. Please follow the folder structure documentation found at docs/image_folder_structure.md ...")
         log.error(f"Only folders are allowed in {folder_path}...")
         return False

     # Check if no sub-folders exist

@@ -1312,4 +1315,4 @@ def verify_image_folder_pattern(folder_path):
         return False

     log.info(f'Valid image folder names found in: {folder_path}')
-        return True
+    return True
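The folder-name check above expects sub-folders named `<number>_<text>`. An assumed regex equivalent of that pattern:

```python
import re

# Assumed equivalent of the <number>_<text> folder pattern enforced above.
FOLDER_PATTERN = re.compile(r"^\d+_.+$")

def is_valid_image_folder_name(name: str) -> bool:
    return FOLDER_PATTERN.match(name) is not None
```

Filtering out names starting with `.` (as the fixed check does) keeps hidden folders such as `.cache` from tripping the validation.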

library/group_images_gui.py

@@ -16,7 +16,9 @@ def group_images(
     output_folder,
     group_size,
     include_subfolders,
-    do_not_copy_other_files
+    do_not_copy_other_files,
+    generate_captions,
+    caption_ext
 ):
     if input_folder == '':
         msgbox('Input folder is missing...')

@@ -36,6 +38,10 @@ def group_images(
         run_cmd += f' --include_subfolders'
     if do_not_copy_other_files:
         run_cmd += f' --do_not_copy_other_files'
+    if generate_captions:
+        run_cmd += f' --caption'
+    if caption_ext:
+        run_cmd += f' --caption_ext={caption_ext}'

     log.info(run_cmd)
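The flag assembly above can be sketched as a pure function (name assumed; note that, mirroring the `if` chain, `--caption_ext` is appended whenever an extension is set, independent of `--caption`):

```python
def group_images_flags(include_subfolders=False, do_not_copy_other_files=False,
                       generate_captions=False, caption_ext='.txt'):
    # Build the optional-flag tail of the group_images.py command line.
    cmd = ''
    if include_subfolders:
        cmd += ' --include_subfolders'
    if do_not_copy_other_files:
        cmd += ' --do_not_copy_other_files'
    if generate_captions:
        cmd += ' --caption'
    if caption_ext:
        cmd += f' --caption_ext={caption_ext}'
    return cmd
```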

@@ -99,6 +105,19 @@ def gradio_group_images_gui_tab(headless=False):
                 value=False,
                 info='Do not copy other files in the input folder to the output folder',
             )

+            generate_captions = gr.Checkbox(
+                label='Generate Captions',
+                value=False,
+                info='Generate caption files for the grouped images based on their folder name',
+            )
+
+            caption_ext = gr.Textbox(
+                label='Caption Extension',
+                placeholder='Caption file extension (e.g., .txt)',
+                value='.txt',
+                interactive=True,
+            )
+
         group_images_button = gr.Button('Group images')

@@ -109,7 +128,9 @@ def gradio_group_images_gui_tab(headless=False):
                 output_folder,
                 group_size,
                 include_subfolders,
-                do_not_copy_other_files
+                do_not_copy_other_files,
+                generate_captions,
+                caption_ext,
             ],
             show_progress=False,
         )

library/train_util.py

@@ -1518,6 +1518,76 @@ def glob_images_pathlib(dir_path, recursive):
     return image_paths


+class MinimalDataset(BaseDataset):
+    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
+        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+
+        self.num_train_images = 0  # update in subclass
+        self.num_reg_images = 0  # update in subclass
+        self.datasets = [self]
+        self.batch_size = 1  # update in subclass
+
+        self.subsets = [self]
+        self.num_repeats = 1  # update in subclass if needed
+        self.img_count = 1  # update in subclass if needed
+        self.bucket_info = {}
+        self.is_reg = False
+        self.image_dir = "dummy"  # for metadata
+
+    def is_latent_cacheable(self) -> bool:
+        return False
+
+    def __len__(self):
+        raise NotImplementedError
+
+    # override to avoid shuffling buckets
+    def set_current_epoch(self, epoch):
+        self.current_epoch = epoch
+
+    def __getitem__(self, idx):
+        r"""
+        The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
+
+        Returns an example built like this:
+
+            for i in range(batch_size):
+                image_key = ...  # whatever hashable
+                image_keys.append(image_key)
+
+                image = ...  # PIL Image
+                img_tensor = self.image_transforms(image)
+                images.append(img_tensor)
+
+                caption = ...  # str
+                input_ids = self.get_input_ids(caption)
+                input_ids_list.append(input_ids)
+
+                captions.append(caption)
+
+            images = torch.stack(images, dim=0)
+            input_ids_list = torch.stack(input_ids_list, dim=0)
+            example = {
+                "images": images,
+                "input_ids": input_ids_list,
+                "captions": captions,  # for debug_dataset
+                "latents": None,
+                "image_keys": image_keys,  # for debug_dataset
+                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
+            }
+            return example
+        """
+        raise NotImplementedError
+
+
+def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
+    module = ".".join(args.dataset_class.split(".")[:-1])
+    dataset_class = args.dataset_class.split(".")[-1]
+    module = importlib.import_module(module)
+    dataset_class = getattr(module, dataset_class)
+    train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
+    return train_dataset_group
+
+
+# endregion

 # region module replacement
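`load_arbitrary_dataset` resolves the `--dataset_class package.module.Class` string with `importlib`. The mechanism in isolation (helper name assumed):

```python
import importlib

def load_class(path: str):
    # "package.module.Class" -> class object, using the same
    # split / import_module / getattr steps as load_arbitrary_dataset.
    module_name = ".".join(path.split(".")[:-1])
    class_name = path.split(".")[-1]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

cls = load_class("collections.OrderedDict")
```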

@@ -2394,7 +2464,6 @@ def add_dataset_arguments(
         default=1,
         help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
     )

     parser.add_argument(
         "--token_warmup_step",
         type=float,

@@ -2402,6 +2471,13 @@ def add_dataset_arguments(
         help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
     )

+    parser.add_argument(
+        "--dataset_class",
+        type=str,
+        default=None,
+        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
+    )
+
     if support_caption_dropout:
         # Textual Inversion does not support caption dropout
         # prefixed with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults to None like the other options

@@ -2676,15 +2752,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.SGD
         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

-    elif optimizer_type.startswith("DAdapt".lower()):
-        # DAdaptation family
-        # check dadaptation is installed
-        try:
-            import dadaptation
-            import dadaptation.experimental as experimental
-        except ImportError:
-            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-
+    elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
         # check lr and lr_count, and print warning
         actual_lr = lr
         lr_count = 1

@@ -2697,40 +2765,60 @@ def get_optimizer(args, trainable_params):

         if actual_lr <= 0.1:
             print(
-                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
+                f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
             )
             print("recommend option: lr=1.0 / 推奨は1.0です")
         if lr_count > 1:
             print(
-                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+                f"when multiple learning rates are specified with D-Adaptation or Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
             )

-        # set optimizer
-        if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
-            optimizer_class = experimental.DAdaptAdamPreprint
-            print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdaGrad".lower():
-            optimizer_class = dadaptation.DAdaptAdaGrad
-            print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdam".lower():
-            optimizer_class = dadaptation.DAdaptAdam
-            print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdan".lower():
-            optimizer_class = dadaptation.DAdaptAdan
-            print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdanIP".lower():
-            optimizer_class = experimental.DAdaptAdanIP
-            print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptLion".lower():
-            optimizer_class = dadaptation.DAdaptLion
-            print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptSGD".lower():
-            optimizer_class = dadaptation.DAdaptSGD
-            print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
-        else:
-            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
-
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+        if optimizer_type.startswith("DAdapt".lower()):
+            # DAdaptation family
+            # check dadaptation is installed
+            try:
+                import dadaptation
+                import dadaptation.experimental as experimental
+            except ImportError:
+                raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+
+            # set optimizer
+            if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
+                optimizer_class = experimental.DAdaptAdamPreprint
+                print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdaGrad".lower():
+                optimizer_class = dadaptation.DAdaptAdaGrad
+                print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdam".lower():
+                optimizer_class = dadaptation.DAdaptAdam
+                print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdan".lower():
+                optimizer_class = dadaptation.DAdaptAdan
+                print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdanIP".lower():
+                optimizer_class = experimental.DAdaptAdanIP
+                print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptLion".lower():
+                optimizer_class = dadaptation.DAdaptLion
+                print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
else:
|
||||
# Prodigy
|
||||
# check Prodigy is installed
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
|
||||
|
||||
print(f"use Prodigy optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "Adafactor".lower():
|
||||
# 引数を確認して適宜補正する
|
||||
|
|
|
|||
|
|
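The adaptive optimizers handled above (the D-Adaptation family and Prodigy) share one learning-rate sanity check before dispatch. A minimal standalone sketch of that check; the function name and return shape are illustrative, not the repo's API:

```python
# Illustrative sketch of the lr warning logic applied to adaptive optimizers
# (D-Adaptation family and Prodigy). check_adaptive_lr is a made-up name.
def check_adaptive_lr(optimizer_type: str, lr: float, lr_count: int) -> list:
    warnings = []
    opt = optimizer_type.lower()
    if opt.startswith("dadapt") or opt == "prodigy":
        if lr <= 0.1:
            # these optimizers estimate the step size themselves; lr acts as
            # a multiplier, so values around 1.0 are recommended
            warnings.append(f"learning rate is too low, use around 1.0 (lr={lr})")
        if lr_count > 1:
            warnings.append("only the first of multiple learning rates takes effect")
    return warnings

print(check_adaptive_lr("prodigy", 1e-4, 1))
```

The same string-matching style (`optimizer_type.lower()` compared against lowered names) drives both the warning and the optimizer-class dispatch in the hunk above.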
@@ -1,7 +1,7 @@
 import gradio as gr
 from easygui import msgbox
 import subprocess
-from .common_gui import get_folder_path
+from .common_gui import get_folder_path, add_pre_postfix
 import os

 from library.custom_logging import setup_logging

@@ -23,6 +23,8 @@ def caption_images(
     debug,
     undesired_tags,
     frequency_tags,
+    prefix,
+    postfix,
 ):
     # Check for images_dir_input
     if train_data_dir == '':

@@ -65,6 +67,14 @@ def caption_images(
     else:
         subprocess.run(run_cmd)

+    # Add prefix and postfix
+    add_pre_postfix(
+        folder=train_data_dir,
+        caption_file_ext=caption_extension,
+        prefix=prefix,
+        postfix=postfix,
+    )
+
     log.info('...captioning done')

@@ -109,6 +119,19 @@ def gradio_wd14_caption_gui_tab(headless=False):
         interactive=True,
     )

+    with gr.Row():
+        prefix = gr.Textbox(
+            label='Prefix to add to WD14 caption',
+            placeholder='(Optional)',
+            interactive=True,
+        )
+
+        postfix = gr.Textbox(
+            label='Postfix to add to WD14 caption',
+            placeholder='(Optional)',
+            interactive=True,
+        )
+
     with gr.Row():
         replace_underscores = gr.Checkbox(
             label='Replace underscores in filenames with spaces',

@@ -189,6 +212,8 @@ def gradio_wd14_caption_gui_tab(headless=False):
             debug,
             undesired_tags,
             frequency_tags,
+            prefix,
+            postfix,
         ],
         show_progress=False,
     )
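The new `prefix`/`postfix` textboxes are forwarded to `add_pre_postfix` after WD14 tagging finishes. A simplified stand-in for what such a helper plausibly does; the real implementation lives in `common_gui.py` and may differ in details:

```python
import os

# Simplified stand-in for common_gui.add_pre_postfix: rewrite every caption
# file in `folder` with the prefix/postfix attached, comma-separated.
# Mirrors the keyword arguments used in the call above; the body is an
# assumption, not the repo's code.
def add_pre_postfix(folder, caption_file_ext='.txt', prefix='', postfix=''):
    for name in os.listdir(folder):
        if not name.endswith(caption_file_ext):
            continue
        path = os.path.join(folder, name)
        with open(path, encoding='utf-8') as f:
            caption = f.read().strip()
        with open(path, 'w', encoding='utf-8') as f:
            f.write(', '.join(p for p in (prefix, caption, postfix) if p))
```

Running it with `prefix='masterpiece'` would turn a caption file containing `1girl, solo` into `masterpiece, 1girl, solo`.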
@@ -1,4 +1,4 @@
-accelerate==0.15.0
+accelerate==0.19.0
 albumentations==1.3.0
 altair==4.2.2
 bitsandbytes==0.35.0

@@ -10,11 +10,12 @@ fairscale==0.4.13
 ftfy==6.1.1
 gradio==3.23.0; sys_platform == 'darwin'
 gradio==3.32.0; sys_platform != 'darwin'
-huggingface-hub==0.13.0; sys_platform == 'darwin'
+huggingface-hub==0.13.3; sys_platform == 'darwin'
 huggingface-hub==0.13.3; sys_platform != 'darwin'
 lion-pytorch==0.0.6
 lycoris_lora==0.1.6
 opencv-python==4.7.0.68
+prodigyopt==1.0
 pytorch-lightning==1.9.0
 rich==13.4.1
 safetensors==0.2.6

@@ -13,6 +13,7 @@ huggingface-hub==0.13.3
 lion-pytorch==0.0.6
 lycoris_lora==0.1.6
 opencv-python==4.7.0.68
+prodigyopt==1.0
 pytorch-lightning==1.9.0
 rich==13.4.1
 safetensors==0.2.6

@@ -13,6 +13,7 @@ huggingface-hub==0.15.1
 lion-pytorch==0.0.6
 lycoris_lora==0.1.6
 opencv-python==4.7.0.68
+prodigyopt==1.0
 pytorch-lightning==1.9.0
 rich==13.4.1
 safetensors==0.2.6
@@ -25,3 +25,6 @@ python .\tools\check_local_modules.py
 call .\venv\Scripts\activate.bat

 python .\tools\setup_windows.py
+
+:: Deactivate the virtual environment
+call .\venv\Scripts\deactivate.bat

@@ -0,0 +1,27 @@
+# Check if Python version meets the recommended version
+$pythonVersion = & .\venv\Scripts\python.exe --version 2>$null
+if ($pythonVersion -notmatch "^Python $PYTHON_VER") {
+    Write-Host "Warning: Python version $PYTHON_VER is recommended."
+}
+
+if (-not (Test-Path -Path "venv")) {
+    Write-Host "Creating venv..."
+    python -m venv venv
+}
+
+# Create the directory if it doesn't exist
+$null = New-Item -ItemType Directory -Force -Path ".\logs\setup"
+
+# Deactivate the virtual environment
+& .\venv\Scripts\deactivate.bat
+
+# Calling external python program to check for local modules
+& .\venv\Scripts\python.exe .\tools\check_local_modules.py
+
+& .\venv\Scripts\activate.bat
+
+& .\venv\Scripts\python.exe .\tools\setup_windows.py
+
+# Deactivate the virtual environment
+& .\venv\Scripts\deactivate.bat
setup.sh
@@ -240,10 +240,9 @@ install_python_dependencies() {
     echo "Installing python dependencies. This could take a few minutes as it downloads files."
     echo "If this operation ever runs too long, you can rerun this script in verbose mode to check."
     case "$OSTYPE" in
-        "linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \
-            --extra-index-url https://download.pytorch.org/whl/cu116 >&3 &&
-            pip install -U -I --no-deps \
-            https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;;
+        "linux-gnu"*) pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 \
+            --extra-index-url https://download.pytorch.org/whl/cu118 >&3 &&
+            pip install -U -I xformers==0.0.20 >&3 ;;
         "darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \
            -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;;
        "cygwin")
@@ -11,13 +11,15 @@ log = setup_logging()

 class ImageProcessor:

-    def __init__(self, input_folder, output_folder, group_size, include_subfolders, do_not_copy_other_files, pad):
+    def __init__(self, input_folder, output_folder, group_size, include_subfolders, do_not_copy_other_files, pad, caption, caption_ext):
         self.input_folder = input_folder
         self.output_folder = output_folder
         self.group_size = group_size
         self.include_subfolders = include_subfolders
         self.do_not_copy_other_files = do_not_copy_other_files
         self.pad = pad
+        self.caption = caption
+        self.caption_ext = caption_ext
         self.image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.webp')

     def get_image_paths(self):

@@ -40,8 +42,12 @@ class ImageProcessor:
         if len(group) > 0:
             aspect_ratios = self.get_aspect_ratios(group)
             avg_aspect_ratio = np.mean(aspect_ratios)
-            cropped_images = self.crop_images(group, avg_aspect_ratio)
-            self.resize_and_save_images(cropped_images, group_index)
+            if self.pad:
+                padded_images = self.pad_images(group, avg_aspect_ratio)
+                self.resize_and_save_images(padded_images, group_index, group)
+            else:
+                cropped_images = self.crop_images(group, avg_aspect_ratio)
+                self.resize_and_save_images(cropped_images, group_index, group)
             if not self.do_not_copy_other_files:
                 self.copy_other_files(group, group_index)

@@ -78,15 +84,29 @@ class ImageProcessor:
             img = img.crop((0, top, img.width, bottom))
         return img

-    def resize_and_save_images(self, cropped_images, group_index):
+    def resize_and_save_images(self, cropped_images, group_index, source_paths):
         max_width = max(img.width for img in cropped_images)
         max_height = max(img.height for img in cropped_images)
         for j, img in enumerate(cropped_images):
             img = img.resize((max_width, max_height))
             os.makedirs(self.output_folder, exist_ok=True)
-            output_path = os.path.join(self.output_folder, f"group-{group_index+1}-image-{j+1}.jpg")
+            original_filename = os.path.basename(source_paths[j])
+            filename_without_ext = os.path.splitext(original_filename)[0]
+            output_path = os.path.join(self.output_folder, f"group-{group_index+1}-{filename_without_ext}.jpg")
             log.info(f" Saving processed image to {output_path}")
             img.convert('RGB').save(output_path)

+            if self.caption:
+                self.create_caption_file(source_paths[j], group_index, filename_without_ext)
+
+    def create_caption_file(self, source_path, group_index, caption_filename):
+        dirpath = os.path.dirname(source_path)
+        caption = os.path.basename(dirpath).split('_')[-1]
+        caption_filename = caption_filename + self.caption_ext
+        caption_path = os.path.join(self.output_folder, f"group-{group_index+1}-{caption_filename}")
+        with open(caption_path, 'w') as f:
+            f.write(caption)
+
     def copy_other_files(self, group, group_index):
         for j, path in enumerate(group):

@@ -112,10 +132,10 @@ class ImageProcessor:
         avg_aspect_ratio = np.mean(aspect_ratios)
         if self.pad:
             padded_images = self.pad_images(group, avg_aspect_ratio)
-            self.resize_and_save_images(padded_images, group_index)
+            self.resize_and_save_images(padded_images, group_index, group)
         else:
             cropped_images = self.crop_images(group, avg_aspect_ratio)
-            self.resize_and_save_images(cropped_images, group_index)
+            self.resize_and_save_images(cropped_images, group_index, group)
         if not self.do_not_copy_other_files:
             self.copy_other_files(group, group_index)

@@ -150,10 +170,12 @@ def main():
     parser.add_argument('--include_subfolders', action='store_true', help='Include subfolders in search for images')
     parser.add_argument('--do_not_copy_other_files', '--no_copy', dest='do_not_copy_other_files', action='store_true', help='Do not copy other files with the same name as images')
     parser.add_argument('--pad', action='store_true', help='Pad images instead of cropping them')
+    parser.add_argument('--caption', action='store_true', help='Create a caption file for each image')
+    parser.add_argument('--caption_ext', type=str, default='.txt', help='Extension for the caption file')

     args = parser.parse_args()

-    processor = ImageProcessor(args.input_folder, args.output_folder, args.group_size, args.include_subfolders, args.do_not_copy_other_files, args.pad)
+    processor = ImageProcessor(args.input_folder, args.output_folder, args.group_size, args.include_subfolders, args.do_not_copy_other_files, args.pad, args.caption, args.caption_ext)
     processor.process_images()

 if __name__ == "__main__":
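The new `create_caption_file` derives the caption from the image's parent folder name, keeping everything after the last underscore (which fits folder names like `10_cat pictures`, where the leading number is a repeat count). The derivation in isolation:

```python
import os

# The caption derivation used by create_caption_file above: take the parent
# folder's basename, split on '_', and keep the last piece.
def caption_from_folder(image_path):
    dirpath = os.path.dirname(image_path)
    return os.path.basename(dirpath).split('_')[-1]

print(caption_from_folder('/data/10_cat pictures/img001.png'))
```

Note that a folder name without an underscore is used as the caption unchanged, since `split('_')[-1]` then returns the whole name.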
@@ -1,8 +1,21 @@
 import os
+import re
 import sys
 import shutil
 import argparse
-from setup_windows import install
+from setup_windows import install, check_repo_version
+
+# Get the absolute path of the current file's directory (Kohya_SS project directory)
+project_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Check if the "tools" directory is present in the project_directory
+if "tools" in project_directory:
+    # If the "tools" directory is present, move one level up to the parent directory
+    project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+# Add the project directory to the beginning of the Python search path
+sys.path.insert(0, project_directory)

 from library.custom_logging import setup_logging

 # Set up logging

@@ -74,10 +87,14 @@ def install_requirements(requirements_file):

     # Iterate over each line and install the requirements
     for line in lines:
-        install(line)
+        # Remove brackets and their contents from the line using regular expressions
+        # e.g. diffusers[torch]==0.10.2 becomes diffusers==0.10.2
+        package_name = re.sub(r'\[.*?\]', '', line)
+        install(line, package_name)

 def main():
+    check_repo_version()
     # Parse command line arguments
     parser = argparse.ArgumentParser(
         description='Validate that requirements are satisfied.'

@@ -91,11 +108,14 @@ def main():
     parser.add_argument('--debug', action='store_true', help='Debug on')
     args = parser.parse_args()

-    # Check Torch
-    if check_torch() == 1:
-        install_requirements('requirements_windows_torch1.txt')
+    if not args.requirements:
+        # Check Torch
+        if check_torch() == 1:
+            install_requirements('requirements_windows_torch1.txt')
+        else:
+            install_requirements('requirements_windows_torch2.txt')
     else:
-        install_requirements('requirements_windows_torch2.txt')
+        install_requirements(args.requirements)

 if __name__ == '__main__':
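The `re.sub(r'\[.*?\]', '', line)` call added above strips pip "extras" from a requirement specifier so the bare package name can be handed to `install` alongside the full line. In isolation:

```python
import re

# Strip pip extras ("[...]") from a requirement specifier, as the hunk above
# does before passing the bare name to install(). Non-greedy match so only
# the bracketed extras are removed.
def strip_extras(requirement):
    return re.sub(r'\[.*?\]', '', requirement)

print(strip_extras('diffusers[torch]==0.10.2'))
```

Specifiers without extras pass through unchanged, since the pattern simply finds no match.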
train_db.py
@@ -46,26 +46,30 @@ def train(args):

     tokenizer = train_util.load_tokenizer(args)

-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "reg_data_dir"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    # データセットを準備する (prepare the dataset)
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        user_config = {
-            "datasets": [
-                {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
-            ]
-        }
+        else:
+            user_config = {
+                "datasets": [
+                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                ]
+            }

-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)

@@ -380,7 +384,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
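The new `else:` branch hands off to `train_util.load_arbitrary_dataset(args, tokenizer)`, which (per the release notes) imports the class named by `--dataset_class package.module.DatasetClass`. A plausible sketch of that dotted-path resolution; the helper name is hypothetical and the repo's actual implementation may differ:

```python
import importlib

# Resolve a "package.module.ClassName" string to the class object, the way a
# --dataset_class option is typically handled. Hypothetical helper name.
def resolve_dataset_class(dotted_path):
    module_name, _, class_name = dotted_path.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# stdlib target used purely for demonstration:
cls = resolve_dataset_class('collections.OrderedDict')
print(cls.__name__)
```

In the real scripts, the resolved class is expected to inherit `train_util.MinimalDataset` and return images and captions like the built-in datasets do.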
@@ -57,7 +57,7 @@ def generate_step_logs(
         logs["lr/textencoder"] = float(lrs[0])
         logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder

-        if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value of unet.
+        if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value of unet.
             logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     else:
         idx = 0

@@ -67,7 +67,7 @@ def generate_step_logs(

         for i in range(idx, len(lrs)):
             logs[f"lr/group{i}"] = float(lrs[i])
-            if args.optimizer_type.lower().startswith("DAdapt".lower()):
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
                 logs[f"lr/d*lr/group{i}"] = (
                     lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )

@@ -92,42 +92,50 @@ def train(args):
     tokenizer = train_util.load_tokenizer(args)

     # データセットを準備する (prepare the dataset)
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
-    if use_user_config:
-        print(f"Loading dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
+        if use_user_config:
+            print(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        if use_dreambooth_method:
-            print("Using DreamBooth method.")
-            user_config = {
-                "datasets": [
-                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
-                ]
-            }
-        else:
-            print("Training with captions.")
-            user_config = {
-                "datasets": [
-                    {
-                        "subsets": [
-                            {
-                                "image_dir": args.train_data_dir,
-                                "metadata_file": args.in_json,
-                            }
-                        ]
-                    }
-                ]
-            }
+        else:
+            if use_dreambooth_method:
+                print("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                print("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }

-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)

@@ -185,6 +193,7 @@ def train(args):
             module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")

+        print(f"all weights merged: {', '.join(args.base_weights)}")

     # 学習を準備する (prepare for training)
     if cache_latents:
         vae.to(accelerator.device, dtype=weight_dtype)
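All of the logging hunks in these training scripts compute the same quantity for adaptive optimizers: the product of the optimizer's estimated `d` and the configured `lr`, both read from `param_groups`. A minimal illustration, with a plain dict standing in for a torch optimizer's param group:

```python
# The effective step size logged as "lr/d*lr" above: adaptive optimizers
# (D-Adaptation, Prodigy) expose their distance estimate as "d" in each
# param group, and the effective rate is d * lr. The dict below is a
# stand-in, not a real optimizer param group.
def d_times_lr(param_group):
    return param_group["d"] * param_group["lr"]

logs = {"lr/d*lr": d_times_lr({"d": 0.02, "lr": 1.0, "params": []})}
print(logs)
```

This is why the warning earlier recommends `lr` near 1.0 for these optimizers: the logged `d*lr` value, not `lr` alone, is the step size actually applied.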
@@ -153,43 +153,46 @@ def train(args):
     print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

     # データセットを準備する (prepare the dataset)
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        use_dreambooth_method = args.in_json is None
-        if use_dreambooth_method:
-            print("Use DreamBooth method.")
-            user_config = {
-                "datasets": [
-                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
-                ]
-            }
-        else:
-            print("Train with captions.")
-            user_config = {
-                "datasets": [
-                    {
-                        "subsets": [
-                            {
-                                "image_dir": args.train_data_dir,
-                                "metadata_file": args.in_json,
-                            }
-                        ]
-                    }
-                ]
-            }
+        else:
+            use_dreambooth_method = args.in_json is None
+            if use_dreambooth_method:
+                print("Use DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                    ]
+                }
+            else:
+                print("Train with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }

-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)

@@ -473,7 +476,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
@@ -20,7 +20,13 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

 imagenet_templates_small = [

@@ -88,6 +94,9 @@ def train(args):
         print(
             "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
         )
+    assert (
+        args.dataset_class is None
+    ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"

     cache_latents = args.cache_latents

@@ -506,7 +515,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )