SDXL CKPT generation, bump xformers
Bump xformers to latest version, 0.0.17 is a bit old. Add proper method to compile sdxl ckpt.401c60e
parent
3a2d5f281d
commit
915cc56042
|
|
@ -192,7 +192,7 @@
|
|||
"# Install Requirements\n",
|
||||
"!pip install -r /content/working/stable-diffusion-webui/requirements_versions.txt\n",
|
||||
"!pip install -r /content/working/stable-diffusion-webui/extensions/sd_dreambooth_extension/requirements.txt\n",
|
||||
"!pip install https://github.com/ArrowM/xformers/releases/download/xformers-0.0.17%2B36e23c5.d20230209-cp310-cu118/xformers-0.0.17+36e23c5.d20230209-cp310-cp310-linux_x86_64.whl\n",
|
||||
"!pip install xformers==0.0.21\n",
|
||||
"!pip install https://download.pytorch.org/whl/nightly/cu118/torch-2.0.0.dev20230209%2Bcu118-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-0.15.0.dev20230209%2Bcu118-cp310-cp310-linux_x86_64.whl\n"
|
||||
]
|
||||
},
|
||||
|
|
@ -208,16 +208,16 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qa5T38izv6CX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!%cd /content/working/stable-diffusion-webui/extensions/sd_dreambooth_extension/\n",
|
||||
"!git fetch && git pull\n",
|
||||
"!%cd /content/working/stable-diffusion-webui/ \n",
|
||||
"!python launch.py --share --xformers --enable-insecure-extension-access --torch2 --skip-install --skip-torch-cuda-test"
|
||||
]
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,418 @@
|
|||
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
||||
# *Only* converts the UNet, VAE, and Text Encoder.
|
||||
# Does not convert optimizer state or any other thing.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import from_file
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.model_utils import unload_system_models, reload_system_models
|
||||
from dreambooth.utils.utils import printi
|
||||
from helpers import mytqdm
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
# the following are for sdxl
|
||||
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
||||
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
||||
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
||||
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
||||
]
|
||||
|
||||
# Renames for sublayers inside a resnet block; applied only to keys that
# contain "resnets". Insertion order is the order the converter applies them.
_resnet_renames = {
    # HF Diffusers -> stable-diffusion
    "norm1": "in_layers.0",
    "conv1": "in_layers.2",
    "norm2": "out_layers.0",
    "conv2": "out_layers.3",
    "time_emb_proj": "emb_layers.1",
    "conv_shortcut": "skip_connection",
}
# Exposed as (stable-diffusion, HF Diffusers) pairs like the other maps.
unet_conversion_map_resnet = [(sd, hf) for hf, sd in _resnet_renames.items()]
|
||||
|
||||
unet_conversion_map_layer = []
# Hardcoded for the SDXL UNet layout (3 down/up blocks, 2 down resnets,
# 4 up resnets); other networks would need smarter logic.
for block in range(3):
    # downblocks: resnets interleaved with attentions (no attentions in block 0)
    for res in range(2):
        unet_conversion_map_layer.append(
            (f"input_blocks.{3 * block + res + 1}.0.", f"down_blocks.{block}.resnets.{res}.")
        )
        if block > 0:
            unet_conversion_map_layer.append(
                (f"input_blocks.{3 * block + res + 1}.1.", f"down_blocks.{block}.attentions.{res}.")
            )

    # upblocks: four resnets each; the last up block has no attention layers
    for res in range(4):
        unet_conversion_map_layer.append(
            (f"output_blocks.{3 * block + res}.0.", f"up_blocks.{block}.resnets.{res}.")
        )
        if block < 2:
            unet_conversion_map_layer.append(
                (f"output_blocks.{3 * block + res}.1.", f"up_blocks.{block}.attentions.{res}.")
            )

    if block < 3:
        # per-block downsampler / upsampler conv renames
        unet_conversion_map_layer.append(
            (f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv.")
        )
        unet_conversion_map_layer.append(
            (f"output_blocks.{3 * block + 2}.{1 if block == 0 else 2}.", f"up_blocks.{block}.upsamplers.0.")
        )
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))

# middle block: one attention sandwiched between two resnets
unet_conversion_map_layer.append(("middle_block.1.", "mid_block.attentions.0."))
for res in range(2):
    unet_conversion_map_layer.append(
        (f"middle_block.{2 * res}.", f"mid_block.resnets.{res}.")
    )
|
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
    """Rename HF Diffusers UNet keys to stable-diffusion checkpoint keys.

    Buyer beware: this is a *brittle* conversion — the three rename passes
    below (whole-key renames, resnet sublayers, block prefixes) must run in
    exactly this order for the output to be correct.
    """
    # Start from an identity key map, then layer the renames on top of it.
    key_map = {key: key for key in unet_state_dict}

    # Pass 1: direct whole-key renames.
    for sd_key, hf_key in unet_conversion_map:
        key_map[hf_key] = sd_key

    # Pass 2: sublayer renames, applied only inside resnet blocks.
    for hf_key, sd_key in key_map.items():
        if "resnets" in hf_key:
            for sd_part, hf_part in unet_conversion_map_resnet:
                sd_key = sd_key.replace(hf_part, sd_part)
            key_map[hf_key] = sd_key

    # Pass 3: per-block prefix renames.
    for hf_key, sd_key in key_map.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            sd_key = sd_key.replace(hf_part, sd_part)
        key_map[hf_key] = sd_key

    return {sd_key: unet_state_dict[hf_key] for hf_key, sd_key in key_map.items()}
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
# (stable-diffusion, HF Diffusers) prefix renames for the VAE.
vae_conversion_map = [
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for block in range(4):
    # encoder down_blocks have two resnets apiece
    for res in range(2):
        vae_conversion_map.append(
            (f"encoder.down.{block}.block.{res}.", f"encoder.down_blocks.{block}.resnets.{res}.")
        )

    if block < 3:
        # the final down block has no downsampler; up blocks mirror this
        vae_conversion_map.append(
            (f"down.{block}.downsample.", f"down_blocks.{block}.downsamplers.0.")
        )
        vae_conversion_map.append(
            (f"up.{3 - block}.upsample.", f"up_blocks.{block}.upsamplers.0.")
        )

    # decoder up_blocks have three resnets apiece; note HF numbers the up
    # blocks in reverse order relative to stable-diffusion
    for res in range(3):
        vae_conversion_map.append(
            (f"decoder.up.{3 - block}.block.{res}.", f"decoder.up_blocks.{block}.resnets.{res}.")
        )

# mid-block resnets share naming between the encoder and the decoder
for res in range(2):
    vae_conversion_map.append((f"mid.block_{res + 1}.", f"mid_block.resnets.{res}."))
|
||||
|
||||
|
||||
# Attention sublayer renames, applied only to keys containing "attentions".
# The to_q/to_k/to_v/to_out names are the newer diffusers (SDXL-era) layout.
_sd_attn_parts = ("norm.", "q.", "k.", "v.", "proj_out.")
_hf_attn_parts = ("group_norm.", "to_q.", "to_k.", "to_v.", "to_out.0.")
vae_conversion_map_attn = list(zip(_sd_attn_parts, _hf_attn_parts))
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
    """Convert an HF linear weight to SD conv2d form by appending two unit dims."""
    return w.reshape(w.shape + (1, 1))
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
    """Rename HF Diffusers VAE keys to stable-diffusion checkpoint keys.

    Applies the general prefix renames, then the attention-only renames, and
    finally reshapes the mid-block attention q/k/v/proj linears into the 4-D
    conv weights the SD checkpoint format stores.
    """
    key_map = {key: key for key in vae_state_dict}

    # General prefix/sublayer renames, applied to every key.
    for hf_key, sd_key in key_map.items():
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        key_map[hf_key] = sd_key

    # Attention-specific renames, applied only to attention keys.
    for hf_key, sd_key in key_map.items():
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
            key_map[hf_key] = sd_key

    new_state_dict = {sd_key: vae_state_dict[hf_key] for hf_key, sd_key in key_map.items()}

    # SD keeps the mid-block attention projections as conv weights, so the
    # four 2-D HF linears need trailing unit dims appended.
    for sd_key, tensor in new_state_dict.items():
        for name in ("q", "k", "v", "proj_out"):
            if f"mid.attn_1.{name}.weight" in sd_key:
                print(f"Reshaping {sd_key} for SD format")
                new_state_dict[sd_key] = reshape_weight_for_sd(tensor)

    return new_state_dict
|
||||
|
||||
|
||||
# =========================#
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
|
||||
|
||||
# (stable-diffusion/open-clip, HF Diffusers) fragment renames for the
# second (open-clip) text encoder.
textenc_conversion_lst = [
    ("transformer.resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "text_model.final_layer_norm."),
    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
]
# Escaped-HF-fragment -> open-clip fragment, plus one alternation pattern
# matching any HF fragment inside a key.
protected = {re.escape(hf): sd for sd, hf in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected))

# Slot order of q/k/v inside torch's fused in_proj tensors.
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_openclip_text_enc_state_dict(text_enc_dict):
    """Rename HF CLIP text-encoder keys to the open-clip checkpoint layout.

    Separate q/k/v projection weights and biases are captured per layer and
    concatenated into the fused ``in_proj_weight`` / ``in_proj_bias`` tensors
    open-clip expects; every other key is renamed via ``textenc_pattern``.

    Raises if any layer is missing one of its q/k/v tensors.
    """
    converted = {}
    qkv_weights = {}
    qkv_biases = {}

    for key, tensor in text_enc_dict.items():
        captured = False
        for kind, bucket in (("weight", qkv_weights), ("bias", qkv_biases)):
            if key.endswith((f".self_attn.q_proj.{kind}",
                             f".self_attn.k_proj.{kind}",
                             f".self_attn.v_proj.{kind}")):
                layer_prefix = key[: -len(f".q_proj.{kind}")]
                # the single character 'q'/'k'/'v' selects the fused slot
                slot = code2idx[key[-len(f"q_proj.{kind}")]]
                bucket.setdefault(layer_prefix, [None, None, None])[slot] = tensor
                captured = True
                break
        if captured:
            continue

        renamed = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key)
        converted[renamed] = tensor

    # Fuse the captured q/k/v tensors, weights first, then biases.
    for fused_suffix, bucket in ((".in_proj_weight", qkv_weights), (".in_proj_bias", qkv_biases)):
        for layer_prefix, tensors in bucket.items():
            if None in tensors:
                raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
            renamed = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], layer_prefix)
            converted[renamed + fused_suffix] = torch.cat(tensors)

    return converted
|
||||
|
||||
|
||||
def convert_openai_text_enc_state_dict(text_enc_dict):
    """Pass through unchanged: OpenAI CLIP keys already match the SD layout."""
    return text_enc_dict
|
||||
|
||||
|
||||
def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_models: bool = True, log: bool = True,
                       snap_rev: str = "", pbar: mytqdm = None):
    """
    Compile the saved SDXL diffusers pipeline into a single SD-style checkpoint.
    *Only* the UNet, VAE, and both text encoders are converted; no optimizer state.

    @param model_name: The model name to compile
    @param reload_models: Whether to reload the system list of checkpoints.
    @param lora_file_name: The path to a lora pt file to merge with the unet. Auto set during training.
    @param log: Whether to print messages to console/UI.
    @param snap_rev: The revision of snapshot to load from
    @param pbar: progress bar
    @return: status: What happened, path: Checkpoint path
    """
    # NOTE(review): snap_rev and pbar are accepted for interface parity with
    # diff_to_sd.compile_checkpoint but are not used in this SDXL path.
    unload_system_models()
    status.textinfo = "Compiling checkpoint."
    status.job_no = 0
    status.job_count = 7
    config = from_file(model_name)
    if lora_file_name is None and config.lora_model_name:
        lora_file_name = config.lora_model_name
    # A custom model name, when set, overrides the checkpoint file name.
    save_model_name = model_name if config.custom_model_name == "" else config.custom_model_name
    if config.custom_model_name == "":
        printi(f"Compiling checkpoint for {model_name}...", log=log)
    else:
        printi(f"Compiling checkpoint for {model_name} with a custom name {config.custom_model_name}", log=log)

    if not model_name:
        msg = "Select a model to compile."
        print(msg)
        return msg

    # Destination: the shared ckpt dir when configured, otherwise the default
    # Stable-diffusion models folder (optionally a per-model subdir).
    ckpt_dir = shared.ckpt_dir
    models_path = os.path.join(shared.models_path, "Stable-diffusion")
    if ckpt_dir is not None:
        models_path = ckpt_dir

    save_safetensors = config.save_safetensors
    lora_diffusers = ""
    total_steps = config.revision
    if config.use_subdir:
        os.makedirs(os.path.join(models_path, save_model_name), exist_ok=True)
        models_path = os.path.join(models_path, save_model_name)
    checkpoint_ext = ".ckpt" if not config.save_safetensors else ".safetensors"
    checkpoint_path = os.path.join(models_path, f"{save_model_name}_{total_steps}{checkpoint_ext}")

    model_path = config.get_pretrained_model_name_or_path()
    # Pre-bind so the best-effort cleanup below can't hit a NameError.
    unet_state_dict = vae_state_dict = text_enc_dict = text_enc_2_dict = state_dict = None
    try:
        # Preferred safetensors locations for each pipeline component.
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
        text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
        text_enc_2_path = osp.join(model_path, "text_encoder_2", "model.safetensors")

        # Load each component from safetensors if it exists, else fall back
        # to the pytorch .bin files.
        if osp.exists(unet_path):
            unet_state_dict = load_file(unet_path, device="cpu")
        else:
            unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
            unet_state_dict = torch.load(unet_path, map_location="cpu")

        if osp.exists(vae_path):
            vae_state_dict = load_file(vae_path, device="cpu")
        else:
            vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
            vae_state_dict = torch.load(vae_path, map_location="cpu")

        if osp.exists(text_enc_path):
            text_enc_dict = load_file(text_enc_path, device="cpu")
        else:
            text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
            text_enc_dict = torch.load(text_enc_path, map_location="cpu")

        if osp.exists(text_enc_2_path):
            text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
        else:
            text_enc_2_path = osp.join(model_path, "text_encoder_2", "pytorch_model.bin")
            text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")

        # Convert the UNet model and prefix for the SD checkpoint layout.
        unet_state_dict = convert_unet_state_dict(unet_state_dict)
        unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

        # Convert the VAE model.
        vae_state_dict = convert_vae_state_dict(vae_state_dict)
        vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

        # Text encoder 1 (OpenAI CLIP): keys pass through unchanged.
        text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}

        # Text encoder 2 (open-clip): q/k/v fused and keys renamed.
        text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
        text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}

        # Put together the new checkpoint.
        state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}

        if config.half_model:
            state_dict = {k: v.half() for k, v in state_dict.items()}

        printi(f"Saving checkpoint to {checkpoint_path}...", log=log)
        if save_safetensors:
            save_file(state_dict, checkpoint_path)
        else:
            state_dict = {"state_dict": state_dict}
            torch.save(state_dict, checkpoint_path)

        # Ship the SDXL inference yaml next to the checkpoint when the model
        # has a config yaml of its own.
        cfg_file = None
        new_name = os.path.join(config.model_dir, f"{config.model_name}.yaml")
        if os.path.exists(new_name):
            cfg_file = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "..",
                "configs",
                "SDXL-inference.yaml"
            )

        if cfg_file is not None:
            cfg_dest = checkpoint_path.replace(checkpoint_ext, ".yaml")
            # Fix: previously logged cfg_dest for both source and destination.
            printi(f"Copying config file from {cfg_file} to {cfg_dest}", log=log)
            shutil.copyfile(cfg_file, cfg_dest)

    except Exception as e:
        msg = f"Exception compiling checkpoint: {e}"
        print(msg)
        traceback.print_exc()
        return msg

    # Best-effort cleanup: release the large state dicts before reloading
    # the UI model list.
    try:
        del unet_state_dict
        del vae_state_dict
        # Fix: delete the text-encoder state dicts themselves (the original
        # deleted text_enc_path, a path string, leaving both dicts alive).
        del text_enc_dict
        del text_enc_2_dict
        del state_dict
        if os.path.exists(lora_diffusers):
            shutil.rmtree(lora_diffusers, True)
    except Exception:
        pass
    # cleanup()
    if reload_models:
        reload_system_models()
    msg = f"Checkpoint compiled successfully: {checkpoint_path}"
    printi(msg, log=log)
    return msg
|
||||
|
|
@ -582,7 +582,20 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
elif optimizer == "Adafactor":
|
||||
from transformers import Adafactor
|
||||
return Adafactor(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
beta1=None,
|
||||
weight_decay=0.0,
|
||||
relative_step=False,
|
||||
scale_parameter=False,
|
||||
warmup_init=False,
|
||||
)
|
||||
elif optimizer == "Lion":
|
||||
from lion_pytorch import Lion
|
||||
return Lion(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from dreambooth.dataset.bucket_sampler import BucketSampler
|
|||
from dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from dreambooth.deis_velocity import get_velocity
|
||||
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
|
||||
from dreambooth.memory import find_executable_batch_size
|
||||
from dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
|
||||
from dreambooth.shared import status
|
||||
|
|
@ -355,7 +356,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.21. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
@ -1035,7 +1036,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
pbar2.set_description("Saving diffusion model")
|
||||
s_pipeline.save_pretrained(
|
||||
os.path.join(args.model_dir, "working"),
|
||||
safe_serialization=True,
|
||||
safe_serialization=False,
|
||||
)
|
||||
if ema_model is not None:
|
||||
ema_model.save_pretrained(
|
||||
|
|
@ -1043,7 +1044,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
args.get_pretrained_model_name_or_path(),
|
||||
"ema_unet",
|
||||
),
|
||||
safe_serialization=True,
|
||||
safe_serialization=False,
|
||||
)
|
||||
pbar2.update()
|
||||
|
||||
|
|
@ -1118,8 +1119,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if export_diffusers:
|
||||
copy_diffusion_model(args.model_name, os.path.join(user_model_dir, "diffusers"))
|
||||
else:
|
||||
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file,
|
||||
log=False, snap_rev=snap_rev, pbar=pbar2)
|
||||
if args.model_type == "SDXL":
|
||||
compile_checkpoint_xl(args.model_name, reload_models=False, lora_file_name=out_file,
|
||||
log=False, snap_rev=snap_rev, pbar=pbar2)
|
||||
else:
|
||||
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file,
|
||||
log=False, snap_rev=snap_rev, pbar=pbar2)
|
||||
printm("Restored, moved to acc.device.")
|
||||
pbar2.update()
|
||||
except Exception as ex:
|
||||
|
|
@ -1133,7 +1138,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
# Get the path to a temporary directory
|
||||
model_path = args.get_pretrained_model_name_or_path()
|
||||
tmp_dir = f"{model_path}_temp"
|
||||
s_pipeline.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
s_pipeline.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
del s_pipeline
|
||||
cleanup()
|
||||
if args.model_type == "SDXL":
|
||||
|
|
|
|||
|
|
@ -747,7 +747,7 @@ def main(args):
|
|||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.21. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -103,8 +103,8 @@ def xformers_check():
|
|||
import torch
|
||||
if version.Version(torch.__version__) < version.Version("1.12"):
|
||||
raise ValueError("PyTorch version must be >= 1.12")
|
||||
if version.Version(_xformers_version) < version.Version("0.0.17.dev"):
|
||||
raise ValueError("Xformers version must be >= 0.0.17.dev")
|
||||
if version.Version(_xformers_version) < version.Version("0.0.21"):
|
||||
raise ValueError("Xformers version must be >= 0.0.21")
|
||||
has_xformers = True
|
||||
except Exception as e:
|
||||
# print(f"Exception importing xformers: {e}")
|
||||
|
|
@ -128,7 +128,11 @@ def list_optimizer():
|
|||
optimizer_list.append("Lion")
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import Adafactor
|
||||
optimizer_list.append("Adafactor")
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
from dadaptation import DAdaptAdam
|
||||
optimizer_list.append("AdamW Dadaptation")
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def check_xformers():
|
|||
"""
|
||||
try:
|
||||
xformers_version = importlib_metadata.version("xformers")
|
||||
xformers_outdated = Version(xformers_version) < Version("0.0.17.dev")
|
||||
xformers_outdated = Version(xformers_version) < Version("0.0.20")
|
||||
if xformers_outdated:
|
||||
try:
|
||||
torch_version = importlib_metadata.version("torch")
|
||||
|
|
@ -143,7 +143,7 @@ def check_versions():
|
|||
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
|
||||
|
||||
dependencies = [
|
||||
Dependency(module="xformers", version="0.0.17.dev", required=False),
|
||||
Dependency(module="xformers", version="0.0.21", required=False),
|
||||
Dependency(module="torch", version="1.13.1" if is_mac else "1.13.1+cu116"),
|
||||
Dependency(module="torchvision", version="0.14.1" if is_mac else "0.14.1+cu116"),
|
||||
Dependency(module="accelerate", version="0.17.1"),
|
||||
|
|
@ -207,7 +207,7 @@ def print_xformers_installation_error(err):
|
|||
print("# XFORMERS ISSUE DETECTED #")
|
||||
print("#######################################################################################################")
|
||||
print("#")
|
||||
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.17.dev built with torch {torch_ver})")
|
||||
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.21 built with torch {torch_ver})")
|
||||
print("# xformers will not be available for Dreambooth. Consider upgrading to Torch 2.")
|
||||
print("#")
|
||||
print("# Xformers installation exception:")
|
||||
|
|
@ -253,8 +253,8 @@ def check_torch_unsafe_load():
|
|||
|
||||
def print_xformers_torch1_instructions(xformers_version):
|
||||
print(f"# Your version of xformers is {xformers_version}.")
|
||||
print("# xformers >= 0.0.17.dev is required to be available on the Dreambooth tab.")
|
||||
print("# Torch 1 wheels of xformers >= 0.0.17.dev are no longer available on PyPI,")
|
||||
print("# xformers >= 0.0.20 is required to be available on the Dreambooth tab.")
|
||||
print("# Torch 1 wheels of xformers >= 0.0.20 are no longer available on PyPI,")
|
||||
print("# but you can manually download them by going to:")
|
||||
print("https://github.com/facebookresearch/xformers/actions")
|
||||
print("# Click on the most recent action tagged with a release (middle column).")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ try:
|
|||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from dreambooth.secret import get_secret
|
||||
from dreambooth.shared import DreamState
|
||||
from dreambooth.ui_functions import create_model, generate_samples, \
|
||||
|
|
@ -221,7 +222,10 @@ def dreambooth_api(_, app: FastAPI):
|
|||
global active
|
||||
shared.status.begin()
|
||||
active = True
|
||||
ckpt_result = compile_checkpoint(model_name, reload_models=False, log=False)
|
||||
if config.model_type == "SDXL":
|
||||
ckpt_result = compile_checkpoint_sdxl(model_name, reload_models=False, log=False)
|
||||
else:
|
||||
ckpt_result = compile_checkpoint(model_name, reload_models=False, log=False)
|
||||
active = False
|
||||
if "Checkpoint compiled successfully" in ckpt_result:
|
||||
path = ckpt_result.replace("Checkpoint compiled successfully:", "").strip()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import gradio as gr
|
|||
|
||||
from dreambooth.dataclasses.db_config import from_file, save_config
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from dreambooth.secret import (
|
||||
get_secret,
|
||||
create_secret,
|
||||
|
|
@ -229,7 +230,10 @@ def ui_gen_ckpt(model_name: str):
|
|||
printm("Config loaded")
|
||||
lora_path = config.lora_model_name
|
||||
print(f"Lora path: {lora_path}")
|
||||
res = compile_checkpoint(model_name, lora_path, True, True, config.snapshot)
|
||||
if config.model_type == "SDXL":
|
||||
res = compile_checkpoint_sdxl(model_name, lora_path, True, False, config.snapshot)
|
||||
else:
|
||||
res = compile_checkpoint(model_name, lora_path, True, True, config.snapshot)
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ set PYTHON=
|
|||
set GIT=
|
||||
set VENV_DIR=
|
||||
set COMMANDLINE_ARGS=
|
||||
set "XFORMERS_PACKAGE=xformers==0.0.17.dev447"
|
||||
set "XFORMERS_PACKAGE=xformers==0.0.21"
|
||||
:: Use the below argument if getting OOM extracting checkpoints
|
||||
:: set COMMANDLINE_ARGS=--ckptfix
|
||||
set "REQS_FILE=.\extensions\sd_dreambooth_extension\requirements.txt"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
|
||||
export COMMANDLINE_ARGS=""
|
||||
export XFORMERS_PACKAGE="xformers==0.0.17.dev447"
|
||||
export XFORMERS_PACKAGE="xformers==0.0.21"
|
||||
|
||||
# python3 executable
|
||||
#python_cmd="python3"
|
||||
|
|
|
|||
Loading…
Reference in New Issue