Fixes and Refactoring and Cleanup, Oh My...

Remove annoying message from bnb import, fix missing os import.
Move db_concept to its own thing so we can preview prompts before training.
Fix dumb issues when loading/saving configs.
Add setting for saving class txt prompts, versus just doing it.
V2 conversion fixes.
Move methods not specifically related to dreambooth training to utils class.
Add automatic setter for UI details when switching models.
Loading model params won't overwrite v2/ema/revision settings.
Cleanup installer, better logging, etc.
Use github diffusers version, for now.
pull/422/head
d8ahazard 2022-12-05 17:50:24 -06:00
parent 48f5811792
commit 9f578527e3
17 changed files with 992 additions and 801 deletions

0
__init__.py Normal file
View File

View File

@ -111,12 +111,7 @@ def get_compute_capability(cuda):
def evaluate_cuda_setup():
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
if os.name == "NT":
if os.name == "nt":
print("Using magick windows DLL!")
return "libbitsandbytes_cudaall.dll" # $$$

View File

@ -0,0 +1,126 @@
import errno
import os
from pathlib import Path
from typing import Set, Union
from warnings import warn
from .env_vars import get_potentially_lib_path_containing_env_vars
# File name of the CUDA runtime shared library to search for on this platform:
# the versioned cudart DLL on Windows, libcudart.so everywhere else.
CUDA_RUNTIME_LIB: str = "libcudart.so" if os.name != "nt" else "cudart64_110.dll"
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
    """Split a colon-separated path list into a set of Paths, dropping empty entries."""
    candidates: Set[Path] = set()
    for entry in paths_list_candidate.split(":"):
        if entry:
            candidates.add(Path(entry))
    return candidates
def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
    """Return only the candidate paths that exist on disk, warning about the rest.

    An existence check that fails with ENAMETOOLONG is treated as "does not
    exist" (some env vars hold values that are not really paths); any other
    OSError is re-raised.
    """
    existing: Set[Path] = set()
    for candidate in candidate_paths:
        try:
            found = candidate.exists()
        except OSError as exc:
            if exc.errno != errno.ENAMETOOLONG:
                raise
            found = False
        if found:
            existing.add(candidate)

    missing: Set[Path] = candidate_paths - existing
    if missing:
        warn(
            "WARNING: The following directories listed in your path were found to "
            f"be non-existent: {missing}"
        )
    return existing
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
    """Return full paths to CUDA_RUNTIME_LIB for each candidate dir that contains it."""
    found: Set[Path] = set()
    for directory in candidate_paths:
        lib_path = directory / CUDA_RUNTIME_LIB
        if lib_path.is_file():
            found.add(lib_path)
    return found
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
    """
    Searches a given environmental var for the CUDA runtime library,
    i.e. `libcudart.so`.
    """
    candidates = extract_candidate_paths(paths_list_candidate)
    return remove_non_existent_dirs(candidates)
def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
    """Resolve an env-var value to the set of CUDA runtime libraries it contains."""
    existing_dirs = resolve_paths_list(paths_list_candidate)
    return get_cuda_runtime_lib_paths(existing_dirs)
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
    """Emit a warning when more than one copy of the CUDA runtime library was found."""
    if len(results_paths) <= 1:
        return
    warn(
        f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
        "We'll flip a coin and try one of these, in order to fail forward.\n"
        "Either way, this might cause trouble in the future:\n"
        "If you get `CUDA error: invalid device function` errors, the above "
        "might be the cause and the solution is to make sure only one "
        f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env."
    )
def determine_cuda_runtime_lib_path() -> Union[Path, None]:
    """
    Searches for a cuda installations, in the following order of priority:
        1. active conda env
        2. LD_LIBRARY_PATH
        3. any other env vars, while ignoring those that
            - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
            - don't contain the path separator `/`

    If multiple libraries are found in part 3, we optimistically try one,
    while giving a warning message.
    """
    candidate_env_vars = get_potentially_lib_path_containing_env_vars()

    if "CONDA_PREFIX" in candidate_env_vars:
        conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"

        conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
        warn_in_case_of_duplicates(conda_cuda_libs)

        if conda_cuda_libs:
            return next(iter(conda_cuda_libs))

        warn(
            f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    if "LD_LIBRARY_PATH" in candidate_env_vars:
        lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
        # BUG FIX: this duplicate check previously ran *after* the early
        # return below, so it could never fire when LD_LIBRARY_PATH yielded a
        # hit. Warn first, exactly like the CONDA_PREFIX branch above.
        warn_in_case_of_duplicates(lib_ld_cuda_libs)

        if lib_ld_cuda_libs:
            return next(iter(lib_ld_cuda_libs))

        warn(
            f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    # Fall through: scan every remaining candidate env var for the library.
    remaining_candidate_env_vars = {
        env_var: value for env_var, value in candidate_env_vars.items()
        if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
    }

    cuda_runtime_libs = set()
    for env_var, value in remaining_candidate_env_vars.items():
        cuda_runtime_libs.update(find_cuda_lib_in(value))

    if len(cuda_runtime_libs) == 0:
        print('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
        cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))

    warn_in_case_of_duplicates(cuda_runtime_libs)

    # Arbitrary pick among whatever survived the search; None when nothing did.
    return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None

View File

@ -7,8 +7,8 @@ from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from extensions.sd_dreambooth_extension.dreambooth.db_config import Concept
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import is_image, list_features
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter
from modules import images

0
dreambooth/__init__.py Normal file
View File

80
dreambooth/db_concept.py Normal file
View File

@ -0,0 +1,80 @@
import os
class Concept(dict):
    """A single dreambooth training concept.

    Subclasses ``dict`` so the concept serializes directly (e.g. to JSON)
    while also exposing every field as an attribute. Field values come from
    the keyword arguments, or — when ``input_dict`` is supplied — from that
    dict, with missing keys falling back to the defaults.
    """

    def __init__(self, max_steps: int = -1, instance_data_dir: str = "", class_data_dir: str = "",
                 file_prompt_contents: str = "Description", instance_prompt: str = "", class_prompt: str = "",
                 save_sample_prompt: str = "", save_sample_template: str = "", instance_token: str = "",
                 class_token: str = "", num_class_images: int = 0, class_negative_prompt: str = "",
                 class_guidance_scale: float = 7.5, class_infer_steps: int = 60, save_sample_negative_prompt: str = "",
                 n_save_sample: int = 1, sample_seed: int = -1, save_guidance_scale: float = 7.5,
                 save_infer_steps: int = 60, input_dict=None):
        # Field name -> value from the keyword arguments (which carry the
        # canonical defaults).
        fields = {
            "max_steps": max_steps,
            "instance_data_dir": instance_data_dir,
            "class_data_dir": class_data_dir,
            "file_prompt_contents": file_prompt_contents,
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "save_sample_prompt": save_sample_prompt,
            "save_sample_template": save_sample_template,
            "instance_token": instance_token,
            "class_token": class_token,
            "num_class_images": num_class_images,
            "class_negative_prompt": class_negative_prompt,
            "class_guidance_scale": class_guidance_scale,
            "class_infer_steps": class_infer_steps,
            "save_sample_negative_prompt": save_sample_negative_prompt,
            "n_save_sample": n_save_sample,
            "sample_seed": sample_seed,
            "save_guidance_scale": save_guidance_scale,
            "save_infer_steps": save_infer_steps,
        }
        if input_dict is not None:
            # BUG FIX: the previous version tested the misspelled keys
            # "class_negative_promt" and "n_save_samples" when deciding
            # whether to read the stored value, silently discarding
            # "class_negative_prompt" and "n_save_sample" from saved configs.
            fields = {key: input_dict.get(key, default) for key, default in fields.items()}

        for key, value in fields.items():
            setattr(self, key, value)
        # Mirror the attributes on the dict side so self["..."] works too.
        super().__init__(fields)

    def is_valid(self):
        """Return True when instance_data_dir is set and exists on disk."""
        return bool(self.instance_data_dir) and os.path.exists(self.instance_data_dir)

View File

@ -2,13 +2,19 @@ import json
import os
import traceback
from extensions.sd_dreambooth_extension.dreambooth.utils import sanitize_name
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from modules import images, shared
def sanitize_name(name):
    """Strip characters unsafe for file names, keeping alphanumerics and '._- '."""
    allowed_extra = "._- "
    kept = [ch for ch in name if ch.isalnum() or ch in allowed_extra]
    return "".join(kept)
class DreamboothConfig:
v2 = False
scheduler = "ddim"
lifetime_revision = 0
initial_revision = 0
def __init__(self,
model_name: str = "",
@ -22,6 +28,7 @@ class DreamboothConfig:
gradient_accumulation_steps: int = 1,
gradient_checkpointing: bool = True,
half_model: bool = False,
has_ema: bool = False,
hflip: bool = False,
learning_rate: float = 0.00000172,
lr_scheduler: str = 'constant',
@ -39,9 +46,11 @@ class DreamboothConfig:
resolution: int = 512,
revision: int = 0,
sample_batch_size: int = 1,
save_class_txt: bool = False,
save_embedding_every: int = 500,
save_preview_every: int = 500,
scale_lr: bool = False,
scheduler: str = "ddim",
src: str = "",
train_batch_size: int = 1,
train_text_encoder: bool = True,
@ -49,6 +58,7 @@ class DreamboothConfig:
use_concepts: bool = False,
use_cpu: bool = False,
use_ema: bool = True,
v2: bool = False,
c1_class_data_dir: str = "",
c1_class_guidance_scale: float = 7.5,
c1_class_infer_steps: int = 60,
@ -106,11 +116,7 @@ class DreamboothConfig:
c3_save_sample_negative_prompt: str = "",
c3_save_sample_prompt: str = "",
c3_save_sample_template: str = "",
pretrained_model_name_or_path="",
concepts_list=None,
v2=None,
scheduler= None,
has_ema=False,
**kwargs
):
if revision == "" or revision is None:
@ -153,6 +159,7 @@ class DreamboothConfig:
self.resolution = resolution
self.revision = int(revision)
self.sample_batch_size = sample_batch_size
self.save_class_txt = save_class_txt
self.save_embedding_every = save_embedding_every
self.save_preview_every = save_preview_every
self.scale_lr = scale_lr
@ -237,6 +244,7 @@ class DreamboothConfig:
"""
Save the config file
"""
self.lifetime_revision = self.initial_revision + self.revision
models_path = shared.cmd_opts.dreambooth_models_path
if models_path == "" or models_path is None:
models_path = os.path.join(shared.models_path, "dreambooth")
@ -323,80 +331,3 @@ def from_file(model_name):
return None
class Concept(dict):
    """A single dreambooth training concept.

    Subclasses ``dict`` so the concept serializes directly (e.g. to JSON)
    while also exposing every field as an attribute. Field values come from
    the keyword arguments, or — when ``input_dict`` is supplied — from that
    dict, with missing keys falling back to the defaults.
    """

    def __init__(self, max_steps: int = -1, instance_data_dir: str = "", class_data_dir: str = "",
                 file_prompt_contents: str = "Description", instance_prompt: str = "", class_prompt: str = "",
                 save_sample_prompt: str = "", save_sample_template: str = "", instance_token: str = "",
                 class_token: str = "", num_class_images: int = 0, class_negative_prompt: str = "",
                 class_guidance_scale: float = 7.5, class_infer_steps: int = 60, save_sample_negative_prompt: str = "",
                 n_save_sample: int = 1, sample_seed: int = -1, save_guidance_scale: float = 7.5,
                 save_infer_steps: int = 60, input_dict=None):
        # Field name -> value from the keyword arguments (which carry the
        # canonical defaults).
        fields = {
            "max_steps": max_steps,
            "instance_data_dir": instance_data_dir,
            "class_data_dir": class_data_dir,
            "file_prompt_contents": file_prompt_contents,
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "save_sample_prompt": save_sample_prompt,
            "save_sample_template": save_sample_template,
            "instance_token": instance_token,
            "class_token": class_token,
            "num_class_images": num_class_images,
            "class_negative_prompt": class_negative_prompt,
            "class_guidance_scale": class_guidance_scale,
            "class_infer_steps": class_infer_steps,
            "save_sample_negative_prompt": save_sample_negative_prompt,
            "n_save_sample": n_save_sample,
            "sample_seed": sample_seed,
            "save_guidance_scale": save_guidance_scale,
            "save_infer_steps": save_infer_steps,
        }
        if input_dict is not None:
            # BUG FIX: the previous version tested the misspelled keys
            # "class_negative_promt" and "n_save_samples" when deciding
            # whether to read the stored value, silently discarding
            # "class_negative_prompt" and "n_save_sample" from saved configs.
            fields = {key: input_dict.get(key, default) for key, default in fields.items()}

        for key, value in fields.items():
            setattr(self, key, value)
        # Mirror the attributes on the dict side so self["..."] works too.
        super().__init__(fields)

    def is_valid(self):
        """Return True when instance_data_dir is set and exists on disk."""
        return bool(self.instance_data_dir) and os.path.exists(self.instance_data_dir)

View File

@ -4,7 +4,9 @@
import os
import os.path as osp
import re
import shutil
import traceback
import torch
@ -13,10 +15,6 @@ from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, printi,
reload_system_models
from modules import shared
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
@ -42,37 +40,8 @@ unet_conversion_map_resnet = [
]
unet_conversion_map_layer = []
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 32 # unused
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
# Default UNet hyper-parameters used during checkpoint conversion
# (presumably the SD v1.x architecture — TODO confirm against the model config).
unet_params = {
    "model_channels": 320,
    "channel_mult": [1, 2, 4, 4],
    "attention_resolutions": [4, 2, 1],
    "image_size": 32,  # unused
    "in_channels": 4,
    "out_channels": 4,
    "num_res_blocks": 2,
    "context_dim": 768,
    "num_heads": 8
}

# v2 variant: per-block head counts / head dims and a larger context dim.
unet_v2_params = unet_params.copy()
unet_v2_params["num_heads"] = [5, 10, 20, 20]
unet_v2_params["attention_head_dim"] = [5, 10, 20, 20]
unet_v2_params["context_dim"] = 1024
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
@ -121,15 +90,6 @@ for j in range(2):
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def conv_transformer_to_linear(checkpoint):
    """Squeeze 1x1-conv projection weights down to 2-D linear weights, in place."""
    tf_keys = ("proj_in.weight", "proj_out.weight")
    for key in list(checkpoint.keys()):
        # Only the last two dotted components identify a projection weight.
        if ".".join(key.split(".")[-2:]) not in tf_keys:
            continue
        weight = checkpoint[key]
        if weight.ndim > 2:
            checkpoint[key] = weight[:, :, 0, 0]
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
@ -227,139 +187,85 @@ def convert_vae_state_dict(vae_state_dict):
# =========================#
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
conversion_map_layer = []
for q in range(4):
for r in range(2):
hfd_res_prefix = f"down_blocks.{q}.resnets.{r}."
sdd_res_prefix = f"input_blocks.{3 * q + r + 1}.0."
conversion_map_layer.append((sdd_res_prefix, hfd_res_prefix))
if q < 3:
hfd_atn_prefix = f"down_blocks.{q}.attentions.{r}."
sdd_atn_prefix = f"input_blocks.{3 * q + r + 1}.1."
conversion_map_layer.append((sdd_atn_prefix, hfd_atn_prefix))
for r in range(3):
hfu_res_prefix = f"up_blocks.{q}.resnets.{r}."
sdu_res_prefix = f"output_blocks.{3 * q + r}.0."
conversion_map_layer.append((sdu_res_prefix, hfu_res_prefix))
if q > 0:
hfu_atn_prefix = f"up_blocks.{q}.attentions.{r}."
sdu_atn_prefix = f"output_blocks.{3 * q + r}.1."
conversion_map_layer.append((sdu_atn_prefix, hfu_atn_prefix))
if q < 3:
# no downsample in down_blocks.3
hfd_prefix = f"down_blocks.{q}.downsamplers.0.conv."
sdd_prefix = f"input_blocks.{3 * (q + 1)}.0.op."
conversion_map_layer.append((sdd_prefix, hfd_prefix))
# Text Encoder Conversion #
# =========================#
# no upsample in up_blocks.3
hfu_prefix = f"up_blocks.{q}.upsamplers.0."
sdu_prefix = f"output_blocks.{3 * q + 2}.{1 if q == 0 else 2}."
conversion_map_layer.append((sdu_prefix, hfu_prefix))
hfm_atn_prefix = "mid_block.attentions.0."
sdm_atn_prefix = "middle_block.1."
conversion_map_layer.append((sdm_atn_prefix, hfm_atn_prefix))
for r in range(2):
hfm_res_prefix = f"mid_block.resnets.{r}."
sdm_res_prefix = f"middle_block.{2 * r}."
conversion_map_layer.append((sdm_res_prefix, hfm_res_prefix))
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
('resblocks.', 'text_model.encoder.layers.'),
('ln_1', 'layer_norm1'),
('ln_2', 'layer_norm2'),
('.c_fc.', '.fc1.'),
('.c_proj.', '.fc2.'),
('.attn', '.self_attn'),
('ln_final.', 'transformer.text_model.final_layer_norm.'),
('token_embedding.weight', 'transformer.text_model.embeddings.token_embedding.weight'),
('positional_embedding', 'transformer.text_model.embeddings.position_embedding.weight')
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {'q': 0, 'k': 1, 'v': 2}
if v2:
conv_transformer_to_linear(new_state_dict)
def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]):
    """Convert an HF CLIP text-encoder state dict to the SD v2 (OpenCLIP) key layout.

    Separate q/k/v projection tensors are fused into single ``in_proj_weight`` /
    ``in_proj_bias`` tensors; every other key is renamed via the module-level
    ``textenc_pattern`` / ``protected`` mapping.

    Raises ``Exception`` when any of the three q/k/v tensors for a layer is missing.
    """
    new_state_dict = {}
    # Partially collected q/k/v tensors, keyed by layer prefix; each value is a
    # [q, k, v] list filled in as the matching keys are encountered.
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if k.endswith('.self_attn.q_proj.weight') or k.endswith('.self_attn.k_proj.weight') or k.endswith(
                '.self_attn.v_proj.weight'):
            k_pre = k[:-len('.q_proj.weight')]
            # Single character 'q', 'k' or 'v' -> slot index via code2idx.
            k_code = k[-len('q_proj.weight')]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if k.endswith('.self_attn.q_proj.bias') or k.endswith('.self_attn.k_proj.bias') or k.endswith(
                '.self_attn.v_proj.bias'):
            k_pre = k[:-len('.q_proj.bias')]
            k_code = k[-len('q_proj.bias')]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        # Non-attention keys: straightforward regex rename.
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        # if relabelled_key != k:
        #     print(f"{k} -> {relabelled_key}")
        new_state_dict[relabelled_key] = v

    # Emit the fused in_proj tensors; a missing q/k/v slot means a corrupt model.
    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + '.in_proj_weight'] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + '.in_proj_bias'] = torch.cat(tensors)

    return new_state_dict
# Text Encoder Conversion #
# =========================#
# pretty much a no-op
def convert_text_enc_state_dict(text_enc_dict: dict[str, torch.Tensor]):
    """Identity conversion: v1 CLIP text-encoder keys already match the SD layout."""
    # FIX: the diff left two consecutive `def` signatures for this function;
    # keep only the annotated one.
    return text_enc_dict
def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
    """Map an HF CLIP text-encoder state dict onto SD v2 (OpenCLIP) key names.

    Separate q/k/v projections are concatenated into single ``in_proj_`` tensors;
    the ``position_ids`` buffer is dropped.
    """

    def _rename(conv_key):
        # position_ids is a buffer, not a weight: drop it.
        if ".position_ids" in conv_key:
            return None
        # common
        conv_key = conv_key.replace("text_model.encoder.", "transformer.")
        conv_key = conv_key.replace("text_model.", "")
        if "layers" in conv_key:
            # resblocks conversion
            conv_key = conv_key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in conv_key:
                conv_key = conv_key.replace(".layer_norm", ".ln_")
            elif ".mlp." in conv_key:
                conv_key = conv_key.replace(".fc1.", ".c_fc.")
                conv_key = conv_key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in conv_key:
                conv_key = conv_key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in conv_key:
                # q/k/v projections are fused separately below.
                conv_key = None
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {conv_key}")
        elif '.position_embedding' in conv_key:
            conv_key = conv_key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in conv_key:
            conv_key = conv_key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in conv_key:
            conv_key = conv_key.replace("final_layer_norm", "ln_final")
        return conv_key

    source_keys = list(checkpoint.keys())
    converted = {}
    for original_key in source_keys:
        mapped = _rename(original_key)
        if mapped is not None:
            converted[mapped] = checkpoint[original_key]

    # Fuse q/k/v projections into the single in_proj_{weight,bias} tensors
    # used by the v2 checkpoint layout.
    for original_key in source_keys:
        if 'layers' in original_key and 'q_proj' in original_key:
            fused = torch.cat([
                checkpoint[original_key],
                checkpoint[original_key.replace("q_proj", "k_proj")],
                checkpoint[original_key.replace("q_proj", "v_proj")],
            ])
            fused_key = original_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            fused_key = fused_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            converted[fused_key] = fused

    return converted
def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, reload_models=False):
def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, reload_models=True):
"""
@param model_name: The model name to compile
@param half: Use FP16 when compiling the model
@param use_subdir: The model will be saved to a subdirectory of the checkpoints folder
@param reload_models: Whether to reload the system list of checkpoints.
@return: status: What happened, path: Checkpoint path
"""
unload_system_models()
@ -376,57 +282,76 @@ def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, re
models_path = ckpt_dir
config = from_file(model_name)
try:
if "use_subdir" in config.__dict__:
use_subdir = config["use_subdir"]
except:
print("Yeah, we can't use dict to find config values.")
if "use_subdir" in config.__dict__:
use_subdir = config["use_subdir"]
v2 = config.v2
total_steps = config.revision
if use_subdir:
os.makedirs(os.path.join(models_path, model_name))
out_file = os.path.join(models_path, model_name, f"{model_name}_{total_steps}.ckpt")
checkpoint_path = os.path.join(models_path, model_name, f"{model_name}_{total_steps}.ckpt")
else:
out_file = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
checkpoint_path = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
model_path = config.pretrained_model_name_or_path
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
printi("Converting unet...")
# Convert the UNet model
unet_state_dict = torch.load(unet_path, map_location="cpu")
unet_state_dict = convert_unet_state_dict(unet_state_dict)
#unet_state_dict = convert_unet_state_dict_to_sd(v2, unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
printi("Converting vae...")
# Convert the VAE model
vae_state_dict = torch.load(vae_path, map_location="cpu")
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
printi("Converting text encoder...")
# Convert the text encoder model
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
#text_enc_dict = convert_text_enc_state_dict(text_enc_dict) if not v2 else convert_text_encoder_state_dict_to_sd_v2(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
printi("Compiling new state dict...")
# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
if half:
state_dict = {k: v.half() for k, v in state_dict.items()}
try:
printi("Converting unet...")
# Convert the UNet model
unet_state_dict = torch.load(unet_path, map_location="cpu")
unet_state_dict = convert_unet_state_dict(unet_state_dict)
# unet_state_dict = convert_unet_state_dict_to_sd(v2, unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
printi("Converting vae...")
# Convert the VAE model
vae_state_dict = torch.load(vae_path, map_location="cpu")
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
printi("Converting text encoder...")
# Convert the text encoder model
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
print("Converting text enc dict for V2 model.")
# Need to add the tag 'transformer' in advance, so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
if not config.v2:
config.v2 = True
config.save()
v2 = True
else:
print("Converting text enc dict for V1 model.")
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
if half:
print("Halving model.")
state_dict = {k: v.half() for k, v in state_dict.items()}
state_dict = {"global_step": config.revision, "state_dict": state_dict}
printi(f"Saving checkpoint to {checkpoint_path}...")
torch.save(state_dict, checkpoint_path)
if v2:
cfg_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "v2-inference-v.yaml")
cfg_dest = checkpoint_path.replace(".ckpt", ".yaml")
print(f"Copying config file to {cfg_dest}")
shutil.copyfile(cfg_file, cfg_dest)
except Exception as e:
print("Exception compiling checkpoint!")
traceback.print_exc()
return f"Exception compiling: {e}", ""
state_dict = {"state_dict": state_dict}
new_ckpt = {'global_step': config.revision, 'state_dict': state_dict}
printi(f"Saving checkpoint to {out_file}...")
torch.save(new_ckpt, out_file)
if v2:
cfg_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "v2-inference-v.yaml")
cfg_dest = out_file.replace(".ckpt", ".yaml")
print(f"Copying config file to {cfg_dest}")
shutil.copyfile(cfg_file, cfg_dest)
try:
del unet_state_dict
del vae_state_dict

View File

@ -2,24 +2,17 @@ import gc
import logging
import math
import os
import random
import traceback
from pathlib import Path
from typing import Optional
import torch
import torch.utils.checkpoint
from PIL import features
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.utils import logging as dl
from huggingface_hub import HfFolder, whoami
from six import StringIO
from transformers import CLIPTextModel
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file, Concept
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.utils import reload_system_models, unload_system_models, printm
from modules import paths, shared, devices, sd_models
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.utils import reload_system_models, unload_system_models, printm, \
isset, list_features, is_image
from modules import shared, devices
try:
cmd_dreambooth_models_path = shared.cmd_opts.dreambooth_models_path
@ -27,7 +20,6 @@ except:
cmd_dreambooth_models_path = None
logger = logging.getLogger(__name__)
# define a Handler which writes DEBUG messages or higher to the sys.stderr
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logger.addHandler(console)
@ -37,83 +29,13 @@ dl.set_verbosity_error()
mem_record = {}
def generate_sample_img(model_dir: str, save_sample_prompt: str, seed: str):
print("Gensample?")
if model_dir is None or model_dir == "":
return "Please select a model."
config = from_file(model_dir)
unload_system_models()
model_path = config.pretrained_model_name_or_path
image = None
if not os.path.exists(config.pretrained_model_name_or_path):
print(f"Model path '{config.pretrained_model_name_or_path}' doesn't exist.")
return f"Can't find diffusers model at {config.pretrained_model_name_or_path}.", None
try:
print(f"Loading model from {model_path}.")
text_enc_model = CLIPTextModel.from_pretrained(config.pretrained_model_name_or_path,
subfolder="text_encoder", revision=config.revision)
pipeline = StableDiffusionPipeline.from_pretrained(
config.pretrained_model_name_or_path,
text_encoder=text_enc_model,
torch_dtype=torch.float16,
revision=config.revision,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
)
pipeline = pipeline.to(shared.device)
pil_features = list_features()
save_dir = os.path.join(shared.sd_path, "outputs", "dreambooth")
db_model_path = config.model_dir
if save_sample_prompt is None:
msg = "Please provide a sample prompt."
print(msg)
return msg, None
shared.state.textinfo = f"Generating preview image for model {db_model_path}..."
# I feel like this might not actually be necessary...but what the heck.
if seed is None or seed == '' or seed == -1:
seed = int(random.randrange(21474836147))
g_cuda = torch.Generator(device=shared.device).manual_seed(seed)
sample_dir = os.path.join(save_dir, "samples")
os.makedirs(sample_dir, exist_ok=True)
file_count = 0
for x in Path(sample_dir).iterdir():
if is_image(x, pil_features):
file_count += 1
shared.state.job_count = 1
with torch.autocast("cuda"), torch.inference_mode():
image = pipeline(save_sample_prompt,
num_inference_steps=60,
guidance_scale=7.5,
scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085,
beta_end=0.012),
width=config.resolution,
height=config.resolution,
generator=g_cuda).images[0]
except:
print("Exception generating sample!")
traceback.print_exc()
reload_system_models()
return "Sample generated.", image
# Borrowed from https://wandb.ai/psuraj/dreambooth/reports/Training-Stable-Diffusion-with-Dreambooth
# --VmlldzoyNzk0NDc3#tl,dr; and https://www.reddit.com/r/StableDiffusion/comments/ybxv7h/good_dreambooth_formula/
def training_wizard_person(
model_dir
):
def training_wizard_person(model_dir):
return training_wizard(
model_dir,
is_person=True)
def training_wizard(
model_dir,
is_person=False
):
def training_wizard(model_dir, is_person=False):
"""
Calculate the number of steps based on our learning rate, return the following:
status,
@ -285,41 +207,6 @@ def performance_wizard():
train_batch_size, train_text_encoder, use_8bit_adam, use_cpu, use_ema
def dumb_safety(images, clip_input):
return images, False
def isset(val: str):
return val is not None and val != "" and val != "*"
def list_features():
# Create buffer for pilinfo() to write into rather than stdout
buffer = StringIO()
features.pilinfo(out=buffer)
pil_features = []
# Parse and analyse lines
for line in buffer.getvalue().splitlines():
if "Extensions:" in line:
ext_list = line.split(": ")[1]
extensions = ext_list.split(", ")
for extension in extensions:
if extension not in pil_features:
pil_features.append(extension)
return pil_features
def is_image(path: Path, feats=None):
if feats is None:
feats = []
if not len(feats):
feats = list_features()
is_img = path.is_file() and path.suffix.lower() in feats
return is_img
def load_params(model_dir):
data = from_file(model_dir)
msg = ""
@ -352,15 +239,41 @@ def load_params(model_dir):
ui_dict[f"c{c_idx}_{key}"] = ui_concept[key]
c_idx += 1
ui_dict["db_status"] = msg
ui_keys = ["db_adam_beta1", "db_adam_beta2", "db_adam_epsilon", "db_adam_weight_decay", "db_attention",
"db_center_crop", "db_concepts_path", "db_gradient_accumulation_steps", "db_gradient_checkpointing",
"db_half_model", "db_has_ema", "db_hflip", "db_learning_rate", "db_lr_scheduler", "db_lr_warmup_steps",
"db_max_grad_norm", "db_max_token_length", "db_max_train_steps", "db_mixed_precision", "db_model_path",
"db_not_cache_latents", "db_num_train_epochs", "db_pad_tokens", "db_pretrained_vae_name_or_path",
"db_prior_loss_weight", "db_resolution", "db_revision", "db_sample_batch_size",
"db_save_embedding_every", "db_save_preview_every", "db_scale_lr", "db_scheduler", "db_src",
"db_train_batch_size", "db_train_text_encoder", "db_use_8bit_adam", "db_use_concepts", "db_use_cpu",
"db_use_ema", "db_v2", "c1_class_data_dir", "c1_class_guidance_scale", "c1_class_infer_steps",
ui_keys = ["db_adam_beta1",
"db_adam_beta2",
"db_adam_epsilon",
"db_adam_weight_decay",
"db_attention",
"db_center_crop",
"db_concepts_path",
"db_gradient_accumulation_steps",
"db_gradient_checkpointing",
"db_half_model",
"db_hflip",
"db_learning_rate",
"db_lr_scheduler",
"db_lr_warmup_steps",
"db_max_grad_norm",
"db_max_token_length",
"db_max_train_steps",
"db_mixed_precision",
"db_not_cache_latents",
"db_num_train_epochs",
"db_pad_tokens",
"db_pretrained_vae_name_or_path",
"db_prior_loss_weight",
"db_resolution",
"db_sample_batch_size",
"db_save_class_txt",
"db_save_embedding_every",
"db_save_preview_every",
"db_scale_lr",
"db_train_batch_size",
"db_train_text_encoder",
"db_use_8bit_adam",
"db_use_concepts",
"db_use_cpu",
"db_use_ema", "c1_class_data_dir", "c1_class_guidance_scale", "c1_class_infer_steps",
"c1_class_negative_prompt", "c1_class_prompt", "c1_class_token", "c1_file_prompt_contents",
"c1_instance_data_dir", "c1_instance_prompt", "c1_instance_token", "c1_max_steps", "c1_n_save_sample",
"c1_num_class_images", "c1_sample_seed", "c1_save_guidance_scale", "c1_save_infer_steps",
@ -389,7 +302,24 @@ def load_params(model_dir):
return output
def start_training(model_dir: str, imagic_only: bool):
def load_model_params(model_dir):
data = from_file(model_dir)
msg = ""
if data is None:
print("Can't load config!")
msg = f"Error loading config for model '{model_dir}'."
return "", "", "", "", "", "", msg
else:
return data.model_dir, \
data.revision, \
"True" if data.v2 else "False", \
"True" if data.has_ema else "False", \
data.src, \
data.scheduler, \
""
def start_training(model_dir: str, imagic_only: bool, use_subdir: bool):
global mem_record
if model_dir == "" or model_dir is None:
print("Invalid model name.")
@ -421,55 +351,37 @@ def start_training(model_dir: str, imagic_only: bool):
if msg:
shared.state.textinfo = msg
print(msg)
return msg, msg, 0, ""
return msg, msg, 0, msg
# Clear memory and do "stuff" only after we've ensured all the things are right
print("Starting Dreambooth training...")
unload_system_models()
total_steps = config.revision
if imagic_only:
shared.state.textinfo = "Initializing imagic training..."
print(shared.state.textinfo)
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic
mem_record = train_imagic(config, mem_record)
else:
shared.state.textinfo = "Initializing dreambooth training..."
print(shared.state.textinfo)
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
config, mem_record, msg = main(config, mem_record)
if config.revision != total_steps:
config.save()
total_steps = config.revision
try:
if imagic_only:
shared.state.textinfo = "Initializing imagic training..."
print(shared.state.textinfo)
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic
mem_record = train_imagic(config, mem_record)
else:
shared.state.textinfo = "Initializing dreambooth training..."
print(shared.state.textinfo)
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
config, mem_record, msg = main(config, mem_record, use_subdir=use_subdir)
if config.revision != total_steps:
config.save()
total_steps = config.revision
res = f"Training {'interrupted' if shared.state.interrupted else 'finished'}. " \
f"Total lifetime steps: {total_steps} \n"
except Exception as e:
res = f"Exception training model: {e}"
pass
devices.torch_gc()
gc.collect()
printm("Training completed, reloading SD Model.")
print(f'Memory output: {mem_record}')
reload_system_models()
res = f"Training {'interrupted' if shared.state.interrupted else 'finished'}. " \
f"Total lifetime steps: {total_steps} \n"
print(f"Returning result: {res}")
return res, "", total_steps, ""
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def save_checkpoint(model_name: str, total_steps: int):
print(f"Successfully trained model for a total of {total_steps} steps, converting to ckpt.")
ckpt_dir = shared.cmd_opts.ckpt_dir
models_path = os.path.join(paths.models_path, "Stable-diffusion")
if ckpt_dir is not None:
models_path = ckpt_dir
src_path = os.path.join(
os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path, "dreambooth",
model_name, "working")
out_file = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
compile_checkpoint(model_name)
sd_models.list_models()
return res, res, total_steps, res

View File

@ -3,6 +3,7 @@ import re
import torch
import torch.utils.checkpoint
from torch.utils.data import Dataset
from transformers import CLIPTextModel
from modules import shared
@ -71,6 +72,41 @@ class FilenameTextGetter:
return output
class PromptDataset(Dataset):
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
def __init__(self, prompt: str, num_samples: int, filename_texts, file_prompt_contents: str, class_token: str,
instance_token: str):
self.prompt = prompt
self.instance_token = instance_token
self.class_token = class_token
self.num_samples = num_samples
self.filename_texts = filename_texts
self.file_prompt_contents = file_prompt_contents
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {"filename_text": self.filename_texts[index % len(self.filename_texts)] if len(
self.filename_texts) > 0 else ""}
prompt = example["filename_text"]
if "Instance" in self.file_prompt_contents:
class_token = self.class_token
# If the token is already in the prompt, just remove the instance token, don't swap it
class_tokens = [f"a {class_token}", f"the {class_token}", f"an {class_token}", class_token]
for token in class_tokens:
if token in prompt:
prompt = prompt.replace(self.instance_token, "")
else:
prompt = prompt.replace(self.instance_token, self.class_token)
prompt = self.prompt.replace("[filewords]", prompt)
example["prompt"] = prompt
example["index"] = index
return example
# Implementation from https://github.com/bmaltais/kohya_ss
def encode_hidden_state(text_encoder: CLIPTextModel, input_ids, pad_tokens, b_size, max_token_length,
tokenizer_max_length):

View File

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """
import gc
import gradio as gr
import os
import re
import shutil
import gradio as gr
import torch
import modules.sd_models
@ -40,7 +40,7 @@ from diffusers import (
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
# HeunDiscreteScheduler, - Add after main diffusers is bumped on pypi.org
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
@ -466,7 +466,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
# From Bmalthais
# if v2:
# linear_transformer_to_conv(new_checkpoint)
# linear_transformer_to_conv(new_checkpoint)
return new_checkpoint, has_ema
@ -639,29 +639,30 @@ def convert_ldm_clip_checkpoint(checkpoint):
return text_model
import re
textenc_conversion_lst = [
('cond_stage_model.model.positional_embedding',
"text_model.embeddings.position_embedding.weight"),
"text_model.embeddings.position_embedding.weight"),
('cond_stage_model.model.token_embedding.weight',
"text_model.embeddings.token_embedding.weight"),
"text_model.embeddings.token_embedding.weight"),
('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'),
('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias')
]
textenc_conversion_map = {x[0]:x[1] for x in textenc_conversion_lst}
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
textenc_transformer_conversion_lst = [
('resblocks.','text_model.encoder.layers.'),
('ln_1','layer_norm1'),
('ln_2','layer_norm2'),
('.c_fc.','.fc1.'),
('.c_proj.','.fc2.'),
('.attn','.self_attn'),
('ln_final.','transformer.text_model.final_layer_norm.'),
('token_embedding.weight','transformer.text_model.embeddings.token_embedding.weight'),
('positional_embedding','transformer.text_model.embeddings.position_embedding.weight')
('resblocks.', 'text_model.encoder.layers.'),
('ln_1', 'layer_norm1'),
('ln_2', 'layer_norm2'),
('.c_fc.', '.fc1.'),
('.c_proj.', '.fc2.'),
('.attn', '.self_attn'),
('ln_final.', 'transformer.text_model.final_layer_norm.'),
('token_embedding.weight', 'transformer.text_model.embeddings.token_embedding.weight'),
('positional_embedding', 'transformer.text_model.embeddings.position_embedding.weight')
]
protected = {re.escape(x[0]):x[1] for x in textenc_transformer_conversion_lst}
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
@ -669,30 +670,30 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict = {}
d_model = int( checkpoint['cond_stage_model.model.text_projection'].shape[0] )
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
text_model_dict["text_model.embeddings.position_ids"] = \
text_model.text_model.embeddings.get_buffer('position_ids')
text_model.text_model.embeddings.get_buffer('position_ids')
for key in keys:
if 'resblocks.23' in key: # Diffusers drops the final layer and only uses the penultimate layer
if 'resblocks.23' in key: # Diffusers drops the final layer and only uses the penultimate layer
continue
if key in textenc_conversion_map:
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
if key.startswith("cond_stage_model.model.transformer."):
new_key = key[len("cond_stage_model.model.transformer.") :]
new_key = key[len("cond_stage_model.model.transformer."):]
if new_key.endswith(".in_proj_weight"):
new_key = new_key[:-len(".in_proj_weight")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key+'.q_proj.weight'] = checkpoint[key][:d_model,:]
text_model_dict[new_key+'.k_proj.weight'] = checkpoint[key][d_model:d_model*2,:]
text_model_dict[new_key+'.v_proj.weight'] = checkpoint[key][d_model*2:,:]
text_model_dict[new_key + '.q_proj.weight'] = checkpoint[key][:d_model, :]
text_model_dict[new_key + '.k_proj.weight'] = checkpoint[key][d_model:d_model * 2, :]
text_model_dict[new_key + '.v_proj.weight'] = checkpoint[key][d_model * 2:, :]
elif new_key.endswith(".in_proj_bias"):
new_key = new_key[:-len(".in_proj_bias")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key+'.q_proj.bias'] = checkpoint[key][:d_model]
text_model_dict[new_key+'.k_proj.bias'] = checkpoint[key][d_model:d_model*2]
text_model_dict[new_key+'.v_proj.bias'] = checkpoint[key][d_model*2:]
text_model_dict[new_key + '.q_proj.bias'] = checkpoint[key][:d_model]
text_model_dict[new_key + '.k_proj.bias'] = checkpoint[key][d_model:d_model * 2]
text_model_dict[new_key + '.v_proj.bias'] = checkpoint[key][d_model * 2:]
else:
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
@ -731,180 +732,220 @@ def extract_checkpoint(new_model_name: str, ckpt_path: str, scheduler_type="ddim
db_config.src: The source checkpoint, if not from hub.
db_has_ema: Whether the model had EMA weights and they were extracted. If weights were not present or
you did not extract them and they were, this will be false.
db_resolution: The resolution the model trains at.
db_v2: Is this a V2 Model?
status
"""
new_model_name = sanitize_name(new_model_name)
print(f"new model URL is {new_model_url}")
shared.state.job_no = 0
checkpoint = None
map_location = shared.device
if shared.cmd_opts.ckptfix or shared.cmd_opts.medvram or shared.cmd_opts.lowvram:
printm(f"Using CPU for extraction.")
map_location = torch.device('cpu')
db_config = None
status = ""
has_ema = False
v2 = False
model_dir = ""
scheduler = ""
src = ""
revision = 0
# Try to determine if v1 or v2 model
if not from_hub:
checkpoint_info = modules.sd_models.get_closet_checkpoint_match(ckpt_path)
# Needed for V2 models so we can create the right text encoder.
if checkpoint_info is None:
print("Unable to find checkpoint file!")
shared.state.job_no = 8
return "", "", "", "", "", "", "", "Unable to find base checkpoint.", ""
reset_safe = False
if not shared.cmd_opts.disable_safe_unpickle:
reset_safe = True
shared.cmd_opts.disable_safe_unpickle = True
if not os.path.exists(checkpoint_info.filename):
print("Unable to find checkpoint file!")
shared.state.job_no = 8
return "", "", "", "", "", "", "", "Unable to find base checkpoint.", ""
try:
new_model_name = sanitize_name(new_model_name)
shared.state.job_no = 0
checkpoint = None
map_location = shared.device
if shared.cmd_opts.ckptfix or shared.cmd_opts.medvram or shared.cmd_opts.lowvram:
printm(f"Using CPU for extraction.")
map_location = torch.device('cpu')
ckpt_path = checkpoint_info[0]
# Try to determine if v1 or v2 model
if not from_hub:
printi("Loading model from checkpoint.")
checkpoint_info = modules.sd_models.get_closet_checkpoint_match(ckpt_path)
printi("Loading checkpoint...")
checkpoint = torch.load(ckpt_path)
# Todo: Decide if we should store this separately in the db_config and append it when re-compiling models.
# global_step = checkpoint["global_step"]
checkpoint = checkpoint["state_dict"]
if checkpoint_info is None:
print("Unable to find checkpoint file!")
shared.state.job_no = 8
return "", "", 0, "", "", "", "", 512, "", "Unable to find base checkpoint."
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if not os.path.exists(checkpoint_info.filename):
print("Unable to find checkpoint file!")
shared.state.job_no = 8
return "", "", 0, "", "", "", "", 512, "", "Unable to find base checkpoint."
ckpt_path = checkpoint_info[0]
printi("Loading checkpoint...")
checkpoint = torch.load(ckpt_path)
# Todo: Decide if we should store this separately in the db_config and append it when re-compiling models.
# global_step = checkpoint["global_step"]
checkpoint = checkpoint["state_dict"]
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
v2 = True
else:
v2 = False
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
v2 = True
else:
v2 = False
if new_model_token == "" or new_model_token is None:
msg = "Please provide a token to load models from the hub."
print(msg)
return "", "", 0, "", "", "", "", 512, "", msg
printi("Loading model from hub.")
v2 = new_model_url == "stabilityai/stable-diffusion-2"
else:
if new_model_token == "" or new_model_token is None:
msg = "Please provide a token to load models from the hub."
return msg
# Todo: Find out if there's a better way to do this
v2 = new_model_url == "stabilityai/stable-diffusion-2"
if v2:
prediction_type = "v_prediction"
image_size = 768
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
"v2-inference-v.yaml")
else:
prediction_type = "epsilon"
image_size = 512
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
"v1-inference.yaml")
db_config = DreamboothConfig(model_name=new_model_name, scheduler=scheduler_type, v2=v2,
src=ckpt_path if not from_hub else new_model_url)
print(f"{'v2' if v2 else 'v1'} model loaded.")
original_config = OmegaConf.load(original_config_file)
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
)
if v2:
# All of the 2.0 models use OpenCLIP and all use DDPM scheduler by default.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
scheduler_type = "ddim"
if v2:
prediction_type = "v_prediction"
image_size = 768
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
"v2-inference-v.yaml")
else:
scheduler_type = "pndm"
if scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
# elif scheduler_type == "heun":
# scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
prediction_type = "epsilon"
image_size = 512
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
"v1-inference.yaml")
if from_hub:
print(f"Trying to create {new_model_name} from huggingface.co/{new_model_url}")
printi("Loading model from hub.")
pipe = DiffusionPipeline.from_pretrained(new_model_url, use_auth_token=new_model_token,
scheduler=scheduler)
printi("Model loaded.")
shared.state.job_no = 7
db_config = DreamboothConfig(model_name=new_model_name, scheduler=scheduler_type, v2=v2,
src=ckpt_path if not from_hub else new_model_url, resolution=768 if v2 else 512)
print(f"{'v2' if v2 else 'v1'} model loaded.")
else:
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet = UNet2DConditionModel(**unet_config)
original_config = OmegaConf.load(original_config_file)
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=ckpt_path, extract_ema=extract_ema
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
)
db_config.has_ema = has_ema
db_config.save()
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
elif text_model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
print("Creating scheduler...")
if v2:
# All of the 2.0 models use OpenCLIP and all use DDPM scheduler by default.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
scheduler_type = "ddim"
else:
scheduler_type = "pndm"
db_config.scheduler = scheduler_type
db_config.save()
if scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
if pipe is None:
print("Unable to create da pipe!")
status = "Unable to create pipeline for extraction."
if from_hub:
print(f"Trying to create {new_model_name} from huggingface.co/{new_model_url}")
printi("Loading model from hub.")
pipe = DiffusionPipeline.from_pretrained(new_model_url, use_auth_token=new_model_token,
scheduler=scheduler, device_map=map_location)
printi("Model loaded.")
shared.state.job_no = 7
else:
printi("Converting unet...")
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet = UNet2DConditionModel(**unet_config)
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=ckpt_path, extract_ema=extract_ema
)
db_config.has_ema = has_ema
db_config.save()
unet.load_state_dict(converted_unet_checkpoint)
printi("Converting vae...")
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
printi("Converting text encoder...")
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
elif text_model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler)
except Exception as e:
pipe = None
status = f"Exception while extracting model: {e}"
if pipe is None or db_config is None:
print("Pipeline or config is not set, unable to continue.")
else:
printi("Saving diffusers model...")
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
model_dir = db_config.model_dir
revision = db_config.revision
scheduler = db_config.scheduler
src = db_config.src
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
for req_dir in required_dirs:
full_path = os.path.join(db_config.pretrained_model_name_or_path, req_dir)
if not os.path.exists(full_path):
status = f"Missing model directory, removing model: {full_path}"
shutil.rmtree(db_config.model_dir, ignore_errors=False, onerror=None)
break
if reset_safe:
shared.cmd_opts.disable_safe_unpickle = False
reload_system_models()
printm(status, True)
@ -912,10 +953,11 @@ def extract_checkpoint(new_model_name: str, ckpt_path: str, scheduler_type="ddim
dirs = get_db_models()
return gr.Dropdown.update(choices=sorted(dirs), value=new_model_name), \
db_config.model_dir, \
db_config.revision, \
db_config.scheduler, \
db_config.src, \
"True" if db_config.has_ema else "False", \
"True" if db_config.v2 else "False", \
model_dir, \
revision, \
scheduler, \
src, \
"True" if has_ema else "False", \
"True" if v2 else "False", \
db_config.resolution, \
status

View File

@ -1,4 +1,3 @@
# From shivam shiaro's repo, with "minimal" modification to hopefully allow for smoother updating?
import argparse
import hashlib
import itertools
@ -15,7 +14,7 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, \
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler, \
DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
@ -28,10 +27,11 @@ from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from extensions.sd_dreambooth_extension.dreambooth import xattention
from extensions.sd_dreambooth_extension.dreambooth.SuperDataset import SuperDataset
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import save_checkpoint, list_features, \
is_image, printm
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, encode_hidden_state
from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, sanitize_name
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import printm
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, encode_hidden_state, \
PromptDataset
from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, sanitize_name, list_features, is_image
from modules import shared
# Custom stuff
@ -60,14 +60,15 @@ dl.set_verbosity_error()
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "text_encoder_cls":
from transformers import text_encoder_cls
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return text_encoder_cls
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
@ -128,54 +129,6 @@ def parse_args(input_args=None):
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--save_sample_prompt",
type=str,
default=None,
help="The prompt used to generate sample outputs to save.",
)
parser.add_argument(
"--save_sample_negative_prompt",
type=str,
default=None,
help="The negative prompt used to generate sample outputs to save.",
)
parser.add_argument(
"--n_save_sample",
type=int,
default=4,
help="The number of samples to save.",
)
parser.add_argument(
"--save_guidance_scale",
type=float,
default=7.5,
help="CFG for save sample.",
)
parser.add_argument(
"--save_infer_steps",
type=int,
default=50,
help="The number of inference steps for save sample.",
)
parser.add_argument(
"--class_negative_prompt",
type=str,
default=None,
help="The negative prompt used to generate sample outputs to save.",
)
parser.add_argument(
"--class_guidance_scale",
type=float,
default=7.5,
help="CFG for save sample.",
)
parser.add_argument(
"--class_infer_steps",
type=int,
default=50,
help="The number of inference steps for save sample.",
)
parser.add_argument(
"--pad_tokens",
default=False,
@ -194,8 +147,8 @@ def parse_args(input_args=None):
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
@ -231,6 +184,7 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
@ -291,18 +245,15 @@ def parse_args(input_args=None):
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--not_cache_latents", action="store_true",
@ -369,44 +320,20 @@ def parse_args(input_args=None):
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
if args.class_data_dir is not None:
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
return args
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt: str, num_samples: int, filename_texts, file_prompt_contents: str, class_token: str,
                 instance_token: str):
        # prompt: class-prompt template; may contain the "[filewords]" placeholder.
        self.prompt = prompt
        self.instance_token = instance_token
        self.class_token = class_token
        # num_samples: number of class images to generate; defines dataset length.
        self.num_samples = num_samples
        # filename_texts: captions read from the instance images' companion .txt files.
        self.filename_texts = filename_texts
        # file_prompt_contents: describes what the filename texts contain
        # (checked below for the substring "Instance").
        self.file_prompt_contents = file_prompt_contents

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Cycle through the available filename texts; fall back to "" when there are none.
        example = {"filename_text": self.filename_texts[index % len(self.filename_texts)] if len(
            self.filename_texts) > 0 else ""}
        prompt = example["filename_text"]
        if "Instance" in self.file_prompt_contents:
            class_token = self.class_token
            # If the token is already in the prompt, just remove the instance token, don't swap it
            class_tokens = [f"a {class_token}", f"the {class_token}", f"an {class_token}", class_token]
            for token in class_tokens:
                if token in prompt:
                    prompt = prompt.replace(self.instance_token, "")
        else:
            # Filename text has no class token: swap the instance token for it.
            prompt = prompt.replace(self.instance_token, self.class_token)
        # Substitute the processed filename text into the prompt template.
        prompt = self.prompt.replace("[filewords]", prompt)
        example["prompt"] = prompt
        example["index"] = index
        return example
class LatentsDataset(Dataset):
def __init__(self, latents_cache, text_encoder_cache, concepts_cache):
self.latents_cache = latents_cache
@ -449,7 +376,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"
def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict, str]:
def main(args: DreamboothConfig, memory_record, use_subdir) -> tuple[DreamboothConfig, dict, str]:
global with_prior
text_encoder = None
args.tokenizer_name = None
@ -540,7 +467,6 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
shared.state.job_count = num_new_images
shared.state.job_no = 0
save_txt = "[filewords]" in concept.class_prompt
filename_texts = [text_getter.read_text(x) for x in Path(concept.instance_data_dir).iterdir() if
is_image(x, pil_features)]
sample_dataset = PromptDataset(concept.class_prompt, num_new_images, filename_texts,
@ -570,15 +496,17 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
logger.debug("Generation canceled.")
shared.state.textinfo = "Training canceled."
return args, mem_record, "Training canceled."
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
if args.save_class_txt:
image_base = hashlib.sha1(image.tobytes()).hexdigest()
else:
image_base = example["prompt"]
image_filename = class_images_dir / f"{generated_images + cur_class_images}-" \
f"{hash_image}.jpg"
f"{image_base}.jpg"
image.save(image_filename)
if save_txt:
txt_filename = class_images_dir / f"{generated_images + cur_class_images}-" \
f"{hash_image}.txt "
if args.save_class_txt:
txt_filename = image_filename.replace(".jpg", ".txt")
with open(txt_filename, "w", encoding="utf8") as file:
file.write(example["filename_text"][i] + "\n")
file.write(example["prompt"])
shared.state.job_no += 1
generated_images += 1
shared.state.textinfo = f"Class image {generated_images}/{num_new_images}, " \
@ -597,15 +525,20 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
use_fast=False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
os.path.join(args.pretrained_model_name_or_path, "tokenizer"),
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
@ -658,6 +591,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
use_adam = True
except Exception as a:
logger.warning(f"Exception importing 8bit adam: {a}")
traceback.print_exc()
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
@ -670,7 +604,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
eps=args.adam_epsilon,
)
noise_scheduler = DDIMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
def cleanup_memory():
try:
@ -748,6 +682,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
return output
train_dataloader = torch.utils.data.DataLoader(
#train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
)
# Move text_encoder and VAE to GPU.
@ -890,7 +825,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
pred_type = "epsilon"
if args.v2:
pred_type = "v_prediction"
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", steps_offset=1,
clip_sample=False, set_alpha_to_one=False, prediction_type=pred_type)
s_pipeline = DiffusionPipeline.from_pretrained(
@ -911,7 +846,8 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
shared.state.textinfo = f"Saving checkpoint at step {args.revision}..."
try:
s_pipeline.save_pretrained(args.pretrained_model_name_or_path)
save_checkpoint(args.model_name, args.revision)
compile_checkpoint(args.model_name, half=args.half_model, use_subdir=use_subdir,
reload_models=False)
except Exception as e:
logger.debug(f"Exception saving checkpoint/model: {e}")
traceback.print_exc()
@ -965,6 +901,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
training_complete = False
msg = ""
weights_saved = False
for epoch in range(args.num_train_epochs):
if training_complete:
break
@ -973,6 +910,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
if args.train_text_encoder and text_encoder is not None:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
weights_saved = False
with accelerator.accumulate(unet):
# Convert images to latent space
with torch.no_grad():
@ -1052,12 +990,8 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
progress_bar.update(1)
global_step += 1
args.revision += 1
shared.state.job_no = global_step
training_complete = global_step >= args.max_train_steps or shared.state.interrupted
if global_step > 0:
save_img = args.save_preview_every and not global_step % args.save_preview_every
save_model = args.save_embedding_every and not global_step % args.save_embedding_every
@ -1067,6 +1001,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
if save_img or save_model:
save_weights()
args.save()
weights_saved = True
shared.state.job_count = args.max_train_steps
if shared.state.interrupted:
training_complete = True
@ -1086,11 +1021,22 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
f" total."
break
progress_bar.update(1)
global_step += 1
args.revision += 1
shared.state.job_no = global_step
training_complete = global_step >= args.max_train_steps or shared.state.interrupted
accelerator.wait_for_everyone()
if not args.not_cache_latents:
train_dataset, train_dataloader = cache_latents(enc_vae=vae, orig_dataset=gen_dataset)
if training_complete:
if not weights_saved:
save_img = True
save_model = True
save_weights()
args.save()
msg = f"Training completed, total steps: {args.revision}"
break
except Exception as m:

View File

@ -14,7 +14,8 @@ from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import list_features, is_image, printm
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import printm
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import AverageMeter
from modules import shared

View File

@ -1,8 +1,20 @@
import gc
import json
import os
import random
import traceback
from io import StringIO
from pathlib import Path
from typing import Optional
import torch
from PIL import features
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import HfFolder, whoami
from transformers import AutoTokenizer, CLIPTextModel
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, PromptDataset
from modules import shared, paths
try:
@ -92,3 +104,160 @@ def reload_system_models():
if shared.sd_model is not None:
shared.sd_model.to(shared.device)
printm("Restored system models.")
def debug_prompts(model_dir):
    """Build a JSON report of every prompt a training run for *model_dir* would use.

    Collects the instance prompts, the class prompts paired with already-existing
    class images, the prompts that WOULD be used to generate missing class images,
    and the sample/preview prompts, so they can all be inspected before training.

    Returns a JSON string on success, or a plain error message when no model
    is selected.
    """
    from extensions.sd_dreambooth_extension.dreambooth.SuperDataset import SuperDataset
    if model_dir is None or model_dir == "":
        return "Please select a model."
    config = from_file(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.join(config.pretrained_model_name_or_path, "tokenizer"),
        revision=config.revision,
        use_fast=False,
    )
    train_dataset = SuperDataset(
        concepts_list=config.concepts_list,
        tokenizer=tokenizer,
        size=config.resolution,
        center_crop=config.center_crop,
        lifetime_steps=config.revision,
        pad_tokens=config.pad_tokens,
        hflip=config.hflip,
        max_token_length=config.max_token_length
    )
    output = {"instance_prompts": [], "existing_class_prompts": [], "new_class_prompts": [], "sample_prompts": []}
    for i in range(len(train_dataset)):
        item = train_dataset[i]
        output["instance_prompts"].append(item["instance_prompt"])
        output["existing_class_prompts"].append(item["class_prompt"])
    output["sample_prompts"] = train_dataset.get_sample_prompts()

    def _concept_val(concept, key):
        # Concepts appear both as plain dicts and as objects in this codebase
        # (the old code mixed concept["..."] and concept.xxx); accept either.
        return concept[key] if isinstance(concept, dict) else getattr(concept, key)

    pil_features = list_features()
    for c_idx, concept in enumerate(config.concepts_list):
        text_getter = FilenameTextGetter()
        class_data_dir = _concept_val(concept, "class_data_dir")
        # Fall back to a per-concept folder under the model dir when unset.
        if not class_data_dir or class_data_dir == shared.script_path:
            class_data_dir = os.path.join(config.model_dir, f"classifiers_{c_idx}")
            print(f"Class image dir is not set, defaulting to {class_data_dir}")
        # Old code reassigned this to a str and then called .mkdir() on it;
        # always go through Path so mkdir works.
        class_images_dir = Path(class_data_dir)
        class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = sum(1 for x in class_images_dir.iterdir() if is_image(x, pil_features))
        num_class_images = _concept_val(concept, "num_class_images")
        if cur_class_images < num_class_images:
            num_new_images = num_class_images - cur_class_images
            filename_texts = [text_getter.read_text(x) for x in
                              Path(_concept_val(concept, "instance_data_dir")).iterdir() if
                              is_image(x, pil_features)]
            sample_dataset = PromptDataset(_concept_val(concept, "class_prompt"), num_new_images, filename_texts,
                                           _concept_val(concept, "file_prompt_contents"),
                                           _concept_val(concept, "class_token"),
                                           _concept_val(concept, "instance_token"))
            # range(sample_dataset) in the old code raised TypeError;
            # iterate over the number of prompts the dataset will yield.
            for j in range(len(sample_dataset)):
                output["new_class_prompts"].append(sample_dataset[j]["prompt"])
    return json.dumps(output)
def generate_sample_img(model_dir: str, save_sample_prompt: str, seed: str):
    """Generate one preview image from the diffusers weights of a trained model.

    Args:
        model_dir: name/dir of the dreambooth model to load.
        save_sample_prompt: prompt text for the preview image.
        seed: seed as text from the UI; "", None, or "-1" means random.

    Returns:
        A (status_message, image) tuple; image is None when generation failed
        or never ran. (The old code sometimes returned a single value, which
        mismatched the two Gradio outputs.)
    """
    image = None
    if model_dir is None or model_dir == "":
        return "Please select a model.", image
    if save_sample_prompt is None or save_sample_prompt == "":
        msg = "Please provide a sample prompt."
        print(msg)
        return msg, image
    config = from_file(model_dir)
    unload_system_models()
    model_path = config.pretrained_model_name_or_path
    if not os.path.exists(model_path):
        print(f"Model path '{model_path}' doesn't exist.")
        # Re-load the webui models we just unloaded before bailing out.
        reload_system_models()
        return f"Can't find diffusers model at {model_path}.", image
    status = "Sample generated."
    try:
        print(f"Loading model from {model_path}.")
        text_enc_model = CLIPTextModel.from_pretrained(model_path,
                                                       subfolder="text_encoder", revision=config.revision)
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_path,
            text_encoder=text_enc_model,
            torch_dtype=torch.float16,
            revision=config.revision,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )
        pipeline = pipeline.to(shared.device)
        db_model_path = config.model_dir
        shared.state.textinfo = f"Generating preview image for model {db_model_path}..."
        # The seed arrives as text from a Textbox; normalize it to an int and
        # pick a random one when unset or -1 (old code compared str to int and
        # passed the raw string to manual_seed, which raises TypeError).
        if seed is None or seed == '' or str(seed) == "-1":
            seed = random.randrange(2147483647)
        g_cuda = torch.Generator(device=shared.device).manual_seed(int(seed))
        sample_dir = os.path.join(shared.sd_path, "outputs", "dreambooth", "samples")
        os.makedirs(sample_dir, exist_ok=True)
        shared.state.job_count = 1
        with torch.autocast("cuda"), torch.inference_mode():
            image = pipeline(save_sample_prompt,
                             num_inference_steps=60,
                             guidance_scale=7.5,
                             scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085,
                                                                       beta_end=0.012),
                             width=config.resolution,
                             height=config.resolution,
                             generator=g_cuda).images[0]
    except Exception:
        print("Exception generating sample!")
        traceback.print_exc()
        # Old code reported success even after an exception; surface it.
        status = "Exception generating sample."
    reload_system_models()
    return status, image
def isset(val: str):
    """Return True when *val* holds a real value: not None, not empty, not the "*" wildcard."""
    return val not in (None, "", "*")
def list_features():
    """Return the file extensions PIL reports support for, in first-seen order."""
    # pilinfo() normally prints to stdout; capture it in a buffer instead.
    report = StringIO()
    features.pilinfo(out=report)
    extensions = []
    for report_line in report.getvalue().splitlines():
        if "Extensions:" not in report_line:
            continue
        # Line looks like "Extensions: .jpg, .jpeg, ..."; dedupe while keeping order.
        for ext in report_line.split(": ")[1].split(", "):
            if ext not in extensions:
                extensions.append(ext)
    return extensions
def is_image(path: Path, feats=None):
    """Return True when *path* is an existing file whose suffix is a known image extension.

    Args:
        path: filesystem path to test.
        feats: optional list of lowercase extensions; when None or empty,
            PIL is queried via list_features().
    """
    if not feats:
        feats = list_features()
    return path.is_file() and path.suffix.lower() in feats
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the fully-qualified HuggingFace Hub repo name "owner/model_id".

    When *organization* is omitted, the owner is the username behind *token*
    (falling back to the locally cached token when *token* is None).
    """
    if token is None:
        token = HfFolder.get_token()
    owner = whoami(token)["name"] if organization is None else organization
    return f"{owner}/{model_id}"

View File

@ -6,9 +6,8 @@ import sys
import sysconfig
import git
from launch import run
from modules.paths import script_path
from launch import run
if sys.version_info < (3, 8):
import importlib_metadata
@ -27,19 +26,21 @@ def check_versions():
splits = line.split("==")
if len(splits) == 2:
key = splits[0]
if "torch" not in key:
if "diffusers" in key:
key = "diffusers"
reqs_dict[key] = splits[1].replace("\n", "").strip()
checks = ["bitsandbytes", "diffusers", "transformers", "xformers"]
reqs_dict[key] = splits[1].replace("\n", "").strip()
# print(f"Reqs dict: {reqs_dict}")
reqs_dict["diffusers[torch]"] = "0.10.0.dev0"
checks = ["bitsandbytes", "diffusers[torch]", "transformers", "xformers", "torch", "torchvision"]
for check in checks:
if check == "diffusers[torch]":
il_check = "diffusers"
else:
il_check = check
check_ver = "N/A"
status = "[ ]"
try:
check_available = importlib.util.find_spec(check) is not None
check_available = importlib.util.find_spec(il_check) is not None
if check_available:
check_ver = importlib_metadata.version(check)
check_ver = importlib_metadata.version(il_check)
if check in reqs_dict:
req_version = reqs_dict[check]
if str(check_ver) == str(req_version):
@ -55,31 +56,41 @@ def check_versions():
print(f"{status} {check} version {check_ver} installed.")
dreambooth_skip_install = os.environ.get('DREAMBOOTH_SKIP_INSTALL', False)
if not dreambooth_skip_install:
name = "Dreambooth"
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
f"Couldn't install {name} requirements.")
base_dir = os.path.dirname(os.path.realpath(__file__))
repo = git.Repo(base_dir)
revision = repo.rev_parse("HEAD")
print(f"Dreambooth revision is {revision}")
check_versions()
app_repo = git.Repo(os.path.join(base_dir, "..", ".."))
app_revision = app_repo.rev_parse("HEAD")
print("#######################################################################################################")
print("Initializing Dreambooth")
print("If submitting an issue on github, please provide the below text for debugging purposes:")
print("")
print(f"Python revision: {sys.version}")
print(f"Dreambooth revision: {revision}")
print(f"SD-WebUI revision: {app_revision}")
print("")
dreambooth_skip_install = os.environ.get('DREAMBOOTH_SKIP_INSTALL', False)
if not dreambooth_skip_install:
name = "Dreambooth"
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
f"Couldn't install {name} requirements.")
# Check for "different" B&B Files and copy only if necessary
if os.name == "nt":
python = sys.executable
bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bitsandbytes_windows")
bnb_dest = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
printed = False
filecmp.clear_cache()
copied = False
for file in os.listdir(bnb_src):
src_file = os.path.join(bnb_src, file)
if file == "main.py":
if file == "main.py" or file == "paths.py":
dest = os.path.join(bnb_dest, "cuda_setup")
else:
dest = bnb_dest
dest_file = os.path.join(dest, file)
shutil.copy2(src_file, dest)
check_versions()
print("#######################################################################################################")

View File

@ -5,7 +5,7 @@ albumentations==1.3.0
basicsr==1.4.2
bitsandbytes==0.35.0
clean-fid==0.1.29
diffusers[torch]==0.9.0
git+https://github.com/huggingface/diffusers#egg=diffusers[torch]
einops==0.4.1
fairscale==0.4.9
fastapi==0.87.0

View File

@ -2,11 +2,12 @@ import gradio as gr
from extensions.sd_dreambooth_extension.dreambooth import dreambooth
from extensions.sd_dreambooth_extension.dreambooth.db_config import save_config
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd_bmalthais import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, \
training_wizard, training_wizard_person, generate_sample_img
training_wizard, training_wizard_person, load_model_params
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.utils import get_db_models, log_memory
from extensions.sd_dreambooth_extension.dreambooth.utils import get_db_models, log_memory, generate_sample_img, \
debug_prompts
from modules import script_callbacks, sd_models, shared
from modules.ui import setup_progressbar, gr_show, wrap_gradio_call, create_refresh_button
from webui import wrap_gradio_gpu_call
@ -67,7 +68,7 @@ def on_ui_tabs():
db_new_model_extract_ema = gr.Checkbox(label='Extract EMA Weights', value=False)
db_new_model_scheduler = gr.Dropdown(label='Scheduler', choices=["pndm", "lms", "euler",
"euler-ancestral", "dpm", "ddim"],
value="euler-ancestral")
value="ddim")
with gr.Row():
with gr.Column(scale=3):
@ -99,10 +100,11 @@ def on_ui_tabs():
db_lr_warmup_steps = gr.Number(label="Learning Rate Warmup Steps", precision=0, value=0)
with gr.Column():
gr.HTML(value="Instance Image Processing")
gr.HTML(value="Image Processing")
db_resolution = gr.Number(label="Resolution", precision=0, value=512)
db_center_crop = gr.Checkbox(label="Center Crop", value=False)
db_hflip = gr.Checkbox(label="Apply Horizontal Flip", value=True)
db_save_class_txt = gr.Checkbox(label="Save Class Captions to txt", value=False)
db_pretrained_vae_name_or_path = gr.Textbox(label='Pretrained VAE Name or Path',
placeholder="Leave blank to use base model VAE.",
value="")
@ -183,6 +185,7 @@ def on_ui_tabs():
with gr.Tab("Debugging"):
with gr.Column():
db_debug_prompts = gr.Button(value="Preview Prompts")
db_generate_sample = gr.Button(value="Generate Sample Image")
db_sample_prompt = gr.Textbox(label="Sample Prompt")
db_sample_seed = gr.Textbox(label="Sample Seed")
@ -201,6 +204,13 @@ def on_ui_tabs():
"one": "one",
"two": "two"
}
db_debug_prompts.click(
fn=debug_prompts,
inputs=[db_model_name],
outputs=[db_status]
)
db_save_params.click(
fn=save_config,
inputs=[
@ -215,6 +225,7 @@ def on_ui_tabs():
db_gradient_accumulation_steps,
db_gradient_checkpointing,
db_half_model,
db_has_ema,
db_hflip,
db_learning_rate,
db_lr_scheduler,
@ -232,9 +243,11 @@ def on_ui_tabs():
db_resolution,
db_revision,
db_sample_batch_size,
db_save_class_txt,
db_save_embedding_every,
db_save_preview_every,
db_scale_lr,
db_scheduler,
db_src,
db_train_batch_size,
db_train_text_encoder,
@ -242,6 +255,7 @@ def on_ui_tabs():
db_use_concepts,
db_use_cpu,
db_use_ema,
db_v2,
c1_class_data_dir,
c1_class_guidance_scale,
c1_class_infer_steps,
@ -318,7 +332,6 @@ def on_ui_tabs():
db_gradient_accumulation_steps,
db_gradient_checkpointing,
db_half_model,
db_has_ema,
db_hflip,
db_learning_rate,
db_lr_scheduler,
@ -327,27 +340,23 @@ def on_ui_tabs():
db_max_token_length,
db_max_train_steps,
db_mixed_precision,
db_model_path,
db_not_cache_latents,
db_num_train_epochs,
db_pad_tokens,
db_pretrained_vae_name_or_path,
db_prior_loss_weight,
db_resolution,
db_revision,
db_sample_batch_size,
db_save_class_txt,
db_save_embedding_every,
db_save_preview_every,
db_scale_lr,
db_scheduler,
db_src,
db_train_batch_size,
db_train_text_encoder,
db_use_8bit_adam,
db_use_concepts,
db_use_cpu,
db_use_ema,
db_v2,
c1_class_data_dir,
c1_class_guidance_scale,
c1_class_infer_steps,
@ -427,6 +436,12 @@ def on_ui_tabs():
]
)
db_model_name.change(
fn=load_model_params,
inputs=[db_model_name],
outputs=[db_model_path, db_revision, db_v2, db_has_ema, db_src, db_scheduler, db_status]
)
db_use_concepts.change(
fn=lambda x: {
concept_tab: gr_show(x is True)
@ -506,7 +521,7 @@ def on_ui_tabs():
)
db_generate_checkpoint.click(
fn=wrap_gradio_gpu_call(compile_checkpoint, extra_outputs=[gr.update()]),
fn=wrap_gradio_gpu_call(compile_checkpoint),
_js="db_start_progress",
inputs=[
db_model_name,
@ -532,7 +547,8 @@ def on_ui_tabs():
db_new_model_extract_ema
],
outputs=[
db_model_name, db_model_path, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_status
db_model_name, db_model_path, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution,
db_status
]
)
@ -541,7 +557,8 @@ def on_ui_tabs():
_js="db_save_start_progress",
inputs=[
db_model_name,
db_train_imagic_only
db_train_imagic_only,
db_use_subdir
],
outputs=[
db_status,