mirror of https://github.com/bmaltais/kohya_ss
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
commit 4703c7fa83
@@ -1005,6 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype)  # U-Net changes dtype in LoRA training

# perform guidance
if do_classifier_free_guidance:
@@ -1,7 +1,7 @@
import torch
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
@@ -9,6 +9,57 @@ from library import sdxl_original_unet
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"

# reference model for loading the Diffusers config
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-0.9"  # access permission required

DIFFUSERS_SDXL_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": False,
    "use_linear_projection": True,
}


def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
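Aside (not part of the commit): two of the less obvious values in DIFFUSERS_SDXL_UNET_CONFIG follow from the two SDXL text encoders and the extra size/crop conditioning. A quick, hypothetical sanity check of that arithmetic:

    # hypothetical sanity check of two DIFFUSERS_SDXL_UNET_CONFIG values
    text_encoder1_dim = 768     # hidden size of text encoder 1 (CLIP ViT-L/14)
    text_encoder2_dim = 1280    # hidden size of text encoder 2 (OpenCLIP ViT-bigG)
    addition_time_embed_dim = 256
    num_extra_cond_values = 6   # original size (h, w), crop top-left (y, x), target size (h, w)

    assert text_encoder1_dim + text_encoder2_dim == 2048  # "cross_attention_dim"
    assert text_encoder2_dim + num_extra_cond_values * addition_time_embed_dim == 2816  # "projection_class_embeddings_input_dim"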
@@ -119,7 +170,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is the same as SDXL
    # Text Encoder 1 is the same as Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
@@ -143,7 +194,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
    )
    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 is different from SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Note: the Tokenizer from HuggingFace is different from SDXL's. We must use open clip's tokenizer.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
@@ -198,6 +249,122 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info


def make_unet_conversion_map():
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commented out for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
    unet_conversion_map = make_unet_conversion_map()

    conversion_map = {hf: sd for sd, hf in unet_conversion_map}
    return convert_unet_state_dict(du_sd, conversion_map)


def convert_unet_state_dict(src_sd, conversion_map):
    converted_sd = {}
    for src_key, value in src_sd.items():
        # walking the whole map for every key would be too slow, so strip elements from the right and look up the prefix
        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
        while len(src_key_fragments) > 0:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_key = converted_prefix + src_key[len(src_key_prefix) :]
                converted_sd[converted_key] = value
                break
            src_key_fragments.pop(-1)
        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"

    return converted_sd


def convert_sdxl_unet_state_dict_to_diffusers(sd):
    unet_conversion_map = make_unet_conversion_map()

    conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
    return convert_unet_state_dict(sd, conversion_dict)


def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    def convert_key(key):
        # remove position_ids
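Aside (not part of the commit): a minimal sketch of how the new helpers can be used. The conversion map is a list of symmetric prefix pairs, so converting a Diffusers-layout U-Net state dict to the original SDXL layout and back should preserve the key set:

    # hypothetical round-trip check of the prefix-based key conversion
    from diffusers import UNet2DConditionModel
    from library import sdxl_model_util

    du_unet = UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG)
    du_sd = du_unet.state_dict()

    # Diffusers layout -> original (LDM-style) SDXL layout, then back again
    sdxl_sd = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(du_sd)
    du_sd2 = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(sdxl_sd)

    assert set(du_sd.keys()) == set(du_sd2.keys())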
@@ -314,3 +481,53 @@ def save_stable_diffusion_checkpoint(
    torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    # convert U-Net
    unet_sd = unet.state_dict()
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # create pipeline to save
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    # prevent local path from being saved
    def remove_name_or_path(model):
        if hasattr(model, "config"):
            model.config._name_or_path = None
            model.config._name_or_path = None

    remove_name_or_path(diffusers_unet)
    remove_name_or_path(text_encoder1)
    remove_name_or_path(text_encoder2)
    remove_name_or_path(scheduler)
    remove_name_or_path(tokenizer1)
    remove_name_or_path(tokenizer2)
    remove_name_or_path(vae)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
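Aside (not part of the commit, paths are hypothetical): combined with load_models_from_sdxl_checkpoint above, the new save_diffusers_checkpoint gives a single-file-checkpoint-to-Diffusers conversion along these lines:

    # hypothetical conversion of a single-file SDXL checkpoint to a Diffusers folder
    import torch
    from library import sdxl_model_util

    ckpt_path = "/path/to/sd_xl_base.safetensors"   # hypothetical input
    output_dir = "/path/to/sdxl-diffusers"          # hypothetical output

    text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = sdxl_model_util.load_models_from_sdxl_checkpoint(
        sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, ckpt_path, "cpu"
    )

    # passing None falls back to DIFFUSERS_REF_MODEL_ID_SDXL for the scheduler and tokenizer configs (gated repo, access required)
    sdxl_model_util.save_diffusers_checkpoint(
        output_dir, text_encoder1, text_encoder2, unet, None, vae=vae, use_safetensors=True, save_dtype=torch.float16
    )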
@@ -734,24 +734,6 @@ class Transformer2DModel(nn.Module):

        return output

    def forward_xxx(self, hidden_states, encoder_hidden_states=None, timestep=None):
        if self.training and self.gradient_checkpointing:
            # print("Transformer2DModel: Using gradient checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, encoder_hidden_states, timestep
            )
        else:
            output = self.forward_body(hidden_states, encoder_hidden_states, timestep)

        return output


class Upsample2D(nn.Module):
    def __init__(self, channels, out_channels):
@@ -8,10 +8,12 @@ import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
from library import model_util, sdxl_model_util, train_util
from diffusers import StableDiffusionXLPipeline
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

TOKENIZER_PATH = "openai/clip-vit-large-patch14"
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

DEFAULT_NOISE_OFFSET = 0.0357
@@ -50,23 +52,54 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):


def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
    # only supports StableDiffusion
    name_or_path = args.pretrained_model_name_or_path
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
    assert (
        load_stable_diffusion_format
    ), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"

    print(f"load StableDiffusion checkpoint: {name_or_path}")
    (
        text_encoder1,
        text_encoder2,
        vae,
        unet,
        logit_scale,
        ckpt_info,
    ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
    if load_stable_diffusion_format:
        print(f"load StableDiffusion checkpoint: {name_or_path}")
        (
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
    else:
        # Diffusers model is loaded to CPU
        variant = "fp16" if weight_dtype == torch.float16 else None
        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
        try:
            try:
                pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None)
            except EnvironmentError as ex:
                if variant is not None:
                    print("try to load fp32 model")
                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
                else:
                    raise ex
        except EnvironmentError as ex:
            print(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
            )
            raise ex

        text_encoder1 = pipe.text_encoder
        text_encoder2 = pipe.text_encoder_2
        vae = pipe.vae
        unet = pipe.unet
        del pipe

        # Diffusers U-Net to original U-Net
        original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
        original_unet.load_state_dict(state_dict)
        unet = original_unet
        print("U-Net converted to original U-Net")

        logit_scale = None
        ckpt_info = None

    # load the VAE
    if args.vae is not None:
@@ -76,101 +109,32 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


class WrapperTokenizer:
    # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
    # make open clip tokenizer compatible with HuggingFace tokenizer
    def __init__(self):
        open_clip_tokenizer = open_clip.tokenizer._tokenizer
        self.model_max_length = 77
        self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
        self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
        self.pad_token_id = 0  # 結果から推定している assumption from result

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.tokenize(*args, **kwds)

    def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
        if padding == "max_length":
            # for training
            assert max_length is not None
            assert truncation == True
            assert return_tensors == "pt"
            input_ids = open_clip.tokenize(text, context_length=max_length)
            return SimpleNamespace(**{"input_ids": input_ids})

        # for weighted prompt
        assert isinstance(text, str), f"input must be str: {text}"

        input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0]  # tokenizer returns list

        # find eos
        eos_index = (input_ids == self.eos_token_id).nonzero().max()
        input_ids = input_ids[: eos_index + 1]  # include eos
        return SimpleNamespace(**{"input_ids": input_ids})

    # for Textual Inversion
    # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?

    def encode(self, text, add_special_tokens=False):
        assert not add_special_tokens
        input_ids = open_clip.tokenizer._tokenizer.encode(text)
        return input_ids

    def add_tokens(self, new_tokens):
        tokens_to_add = []
        for token in new_tokens:
            token = token.lower()
            if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
                tokens_to_add.append(token)

        # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
        for token in tokens_to_add:
            open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
            open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
            open_clip.tokenizer._tokenizer.vocab_size += 1

            # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
            # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
            # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
            # this is very rough, so it will not work if the specification of open clip tokenizer changes
            open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"

        return len(tokens_to_add)

    def convert_tokens_to_ids(self, tokens):
        input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
        return input_ids

    def __len__(self):
        return open_clip.tokenizer._tokenizer.vocab_size


def load_tokenizers(args: argparse.Namespace):
    print("prepare tokenizers")
    original_path = TOKENIZER_PATH

    tokenizer1: CLIPTokenizer = None
    if args.tokenizer_cache_dir:
        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
        if os.path.exists(local_tokenizer_path):
            print(f"load tokenizer from cache: {local_tokenizer_path}")
            tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
    tokeniers = []
    for original_path in original_paths:
        tokenizer: CLIPTokenizer = None
        if args.tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                print(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

    if tokenizer1 is None:
        tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
        print(f"save Tokenizer to cache: {local_tokenizer_path}")
        tokenizer1.save_pretrained(local_tokenizer_path)
        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            print(f"save Tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        tokeniers.append(tokenizer)

    if hasattr(args, "max_token_length") and args.max_token_length is not None:
        print(f"update token length: {args.max_token_length}")

    # tokenizer2 is from open_clip
    # TODO caching
    tokenizer2 = WrapperTokenizer()

    return [tokenizer1, tokenizer2]
    return tokeniers


def get_hidden_states(
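Aside (not part of the commit): the removed WrapperTokenizer emulated the HuggingFace tokenizer interface on top of open_clip; the reworked load_tokenizers instead returns two CLIPTokenizer instances built from TOKENIZER1_PATH and TOKENIZER2_PATH. A minimal sketch of the equivalent direct usage:

    # hypothetical direct use of the two SDXL tokenizers
    from transformers import CLIPTokenizer

    tokenizer1 = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

    prompt = "a photo of a cat"
    ids1 = tokenizer1(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
    ids2 = tokenizer2(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
    print(ids1.shape, ids2.shape)  # both are (1, 77)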
@@ -296,7 +260,16 @@ def save_sd_model_on_train_end(
        )

    def diffusers_saver(out_dir):
        raise NotImplementedError("diffusers_saver is not implemented")
        sdxl_model_util.save_diffusers_checkpoint(
            out_dir,
            text_encoder1,
            text_encoder2,
            unet,
            src_path,
            vae,
            use_safetensors=use_safetensors,
            save_dtype=save_dtype,
        )

    train_util.save_sd_model_on_train_end_common(
        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
@@ -338,7 +311,16 @@ def save_sd_model_on_epoch_end_or_stepwise(
        )

    def diffusers_saver(out_dir):
        raise NotImplementedError("diffusers_saver is not implemented")
        sdxl_model_util.save_diffusers_checkpoint(
            out_dir,
            text_encoder1,
            text_encoder2,
            unet,
            src_path,
            vae,
            use_safetensors=use_safetensors,
            save_dtype=save_dtype,
        )

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
@@ -1,7 +1,8 @@
accelerate==0.19.0
albumentations==1.3.0
altair==4.2.2
dadaptation==3.1
diffusers[torch]==0.17.1
diffusers[torch]==0.18.2
easygui==0.98.3
einops==0.6.0
fairscale==0.4.13
@@ -1,4 +1,4 @@
torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage
xformers==0.0.20 bitsandbytes==0.35.0
accelerate==0.19.0 tensorboard==2.12.1 tensorflow==2.12.0
tensorboard==2.12.1 tensorflow==2.12.0
-r requirements.txt
@@ -1,4 +1,4 @@
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
xformers bitsandbytes==0.35.0
accelerate==0.19.0 tensorflow-macos tensorboard==2.12.1
tensorflow-macos tensorboard==2.12.1
-r requirements.txt
@@ -1,4 +1,4 @@
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
xformers bitsandbytes==0.35.0
accelerate==0.19.0 tensorflow-metal tensorboard==2.12.1
tensorflow-metal tensorboard==2.12.1
-r requirements.txt
@@ -1,5 +1,5 @@
torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage
xformers==0.0.20 bitsandbytes==0.35.0
accelerate==0.19.0 tensorboard==2.12.1 tensorflow==2.12.0 wheel
tensorboard==2.12.1 tensorflow==2.12.0 wheel
tensorrt
-r requirements.txt
@@ -1,5 +1,5 @@
torch==1.12.1+cu116 torchvision==0.13.1+cu116 --index-url https://download.pytorch.org/whl/cu116 # no_verify
https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl -U -I --no-deps # no_verify
bitsandbytes==0.35.0
accelerate==0.19.0 tensorboard==2.10.1 tensorflow==2.10.1
tensorboard==2.10.1 tensorflow==2.10.1
-r requirements.txt
@@ -1,4 +1,4 @@
torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 # no_verify
xformers==0.0.20 bitsandbytes==0.35.0
accelerate==0.19.0 tensorboard==2.12.3 tensorflow==2.12.0
tensorboard==2.12.3 tensorflow==2.12.0
-r requirements.txt
@@ -1605,18 +1605,14 @@ def main(args):
    num_vectors_per_token = embeds1.size()[0]
    token_string = os.path.splitext(os.path.basename(embeds_file))[0]

    # remove non-alphabet characters to avoid splitting by tokenizer
    # TODO make random alphabet string
    token_string = "".join([c for c in token_string if c.isalpha()])

    token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
    token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]

    # add new word to tokenizer, count is num_vectors_per_token
    num_added_tokens1 = tokenizer1.add_tokens(token_strings)
    num_added_tokens2 = tokenizer2.add_tokens(token_strings)
    assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
        f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
        + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
        f"tokenizer has same word to token string (filename): {embeds_file}"
        + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}"
    )

    token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
@@ -171,7 +171,7 @@ def train(args):
    # set_diffusers_xformers_flag(unet, True)
    set_diffusers_xformers_flag(vae, True)
else:
    # the Windows build of xformers sometimes cannot train with float, and it must also be possible to disable xformers
    # the Windows build of xformers sometimes cannot train with float, so it must also be possible to disable xformers
    accelerator.print("Disable Diffusers' xformers")
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    vae.set_use_memory_efficient_attention_xformers(args.xformers)
@@ -271,7 +271,7 @@ def train(args):
    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # experimental feature: perform fp16 training including gradients (make the whole model fp16)
    # experimental feature: perform fp16/bf16 training including gradients (make the whole model fp16)
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
@@ -39,18 +39,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
        tokenizer = sdxl_train_util.load_tokenizers(args)
        return tokenizer

    def assert_token_string(self, token_string, tokenizers):
        # tokenizer 1 seems to be ok

        # count words for token string: regular expression from open_clip
        pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
        words = regex.findall(pat, token_string)
        word_count = len(words)
        assert word_count == 1, (
            f"token string {token_string} contains {word_count} words, please don't use digits, punctuation, or special characters"
            + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
        )

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        input_ids1 = batch["input_ids"]
        input_ids2 = batch["input_ids2"]
@@ -8,6 +8,7 @@ from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util

import library.train_util as train_util
@@ -92,7 +93,7 @@ class TextualInversionTrainer:
        tokenizer = train_util.load_tokenizer(args)
        return tokenizer

    def assert_token_string(self, token_string, tokenizers):
    def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
        pass

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
@@ -200,19 +201,13 @@ class TextualInversionTrainer:
    init_token_ids_list = [None] * len(tokenizers)

    # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
    # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
    # add new word to tokenizer, count is num_vectors_per_token

    # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
    # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした

    # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
    # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
    # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added

    self.assert_token_string(args.token_string, tokenizers)

    token_strings = [args.token_string] + [
        f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
    ]
    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
    token_ids_list = []
    token_embeds_list = []
    for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
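Aside (not part of the commit): with HuggingFace tokenizers on both sides, the numeric suffixes are added as whole tokens via add_tokens, which is why the trainer can go back to the "hoge1", "hoge2", ... naming. A small sketch with a hypothetical token name:

    # hypothetical add_tokens / convert_tokens_to_ids flow for textual inversion
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    token_string = "hoge"   # hypothetical embedding name
    num_vectors_per_token = 3
    token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]

    num_added = tokenizer.add_tokens(token_strings)
    assert num_added == num_vectors_per_token, "a token with this name already exists"

    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
    print(token_strings, token_ids)  # e.g. ['hoge', 'hoge1', 'hoge2'] mapped to three new ids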