Fixes and Refactoring and Cleanup, Oh My...
Remove annoying message from bnb import, fix missing os import. Move db_concept to its own thing so we can preview prompts before training. Fix dumb issues when loading/saving configs. Add setting for saving class txt prompts, versus just doing it. V2 conversion fixes. Move methods not specifically related to dreambooth training to utils class. Add automatic setter for UI details when switching models. Loading model params won't overwrite v2/ema/revision settings. Cleanup installer, better logging, etc. Use github diffusers version, for now. (pull/422/head)
parent
48f5811792
commit
9f578527e3
|
|
@ -111,12 +111,7 @@ def get_compute_capability(cuda):
|
|||
|
||||
|
||||
def evaluate_cuda_setup():
|
||||
print('')
|
||||
print('='*35 + 'BUG REPORT' + '='*35)
|
||||
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
||||
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
||||
print('='*80)
|
||||
if os.name == "NT":
|
||||
if os.name == "nt":
|
||||
print("Using magick windows DLL!")
|
||||
return "libbitsandbytes_cudaall.dll" # $$$
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,126 @@
|
|||
import errno
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Set, Union
|
||||
from warnings import warn
|
||||
|
||||
from .env_vars import get_potentially_lib_path_containing_env_vars
|
||||
|
||||
CUDA_RUNTIME_LIB: str = "libcudart.so" if os.name != "nt" else "cudart64_110.dll"
|
||||
|
||||
|
||||
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
    """Split a colon-separated path list (e.g. an env-var value) into a set
    of Paths, skipping empty segments."""
    candidates: Set[Path] = set()
    for segment in paths_list_candidate.split(":"):
        if segment:
            candidates.add(Path(segment))
    return candidates
|
||||
|
||||
|
||||
def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
    """Filter ``candidate_paths`` down to the paths that exist on disk.

    An ENAMETOOLONG failure from the existence check is treated as
    "does not exist" (some env vars contain garbage values); every other
    OSError is re-raised. A warning lists everything that was dropped.
    """
    existent_directories: Set[Path] = set()
    for candidate in candidate_paths:
        try:
            candidate_exists = candidate.exists()
        except OSError as exc:
            if exc.errno != errno.ENAMETOOLONG:
                raise exc
            candidate_exists = False
        if candidate_exists:
            existent_directories.add(candidate)

    non_existent_directories: Set[Path] = candidate_paths - existent_directories
    if non_existent_directories:
        warn(
            "WARNING: The following directories listed in your path were found to "
            f"be non-existent: {non_existent_directories}"
        )

    return existent_directories
|
||||
|
||||
|
||||
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
    """Return the full paths of every CUDA_RUNTIME_LIB file found directly
    inside the given candidate directories."""
    found: Set[Path] = set()
    for directory in candidate_paths:
        lib_path = directory / CUDA_RUNTIME_LIB
        if lib_path.is_file():
            found.add(lib_path)
    return found
|
||||
|
||||
|
||||
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
    """
    Searches a given environmental var for the CUDA runtime library,
    i.e. `libcudart.so`.
    """
    candidates = extract_candidate_paths(paths_list_candidate)
    return remove_non_existent_dirs(candidates)
|
||||
|
||||
|
||||
def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
    """Resolve an env-var style path list to the set of CUDA runtime
    libraries found under its existing directories."""
    existing_dirs = resolve_paths_list(paths_list_candidate)
    return get_cuda_runtime_lib_paths(existing_dirs)
|
||||
|
||||
|
||||
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
    """Emit a warning when more than one copy of the CUDA runtime library
    was found; zero or one result is silently accepted."""
    if len(results_paths) <= 1:
        return
    warning_msg = (
        f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
        "We'll flip a coin and try one of these, in order to fail forward.\n"
        "Either way, this might cause trouble in the future:\n"
        "If you get `CUDA error: invalid device function` errors, the above "
        "might be the cause and the solution is to make sure only one "
        f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env."
    )
    warn(warning_msg)
|
||||
|
||||
|
||||
def determine_cuda_runtime_lib_path() -> Union[Path, None]:
    """
    Search for a CUDA installation, in the following order of priority:
        1. active conda env
        2. LD_LIBRARY_PATH
        3. any other env vars, while ignoring those that
            - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
            - don't contain the path separator `/`

    If multiple libraries are found in part 3, we optimistically try one,
    while giving a warning message.

    Returns the path of one found CUDA runtime library, or None.
    """
    candidate_env_vars = get_potentially_lib_path_containing_env_vars()

    if "CONDA_PREFIX" in candidate_env_vars:
        conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"

        conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
        warn_in_case_of_duplicates(conda_cuda_libs)

        if conda_cuda_libs:
            return next(iter(conda_cuda_libs))

        warn(
            f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    if "LD_LIBRARY_PATH" in candidate_env_vars:
        lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
        # FIX: the original called warn_in_case_of_duplicates() AFTER the early
        # return below, so the duplicate warning could never fire for
        # LD_LIBRARY_PATH hits. Warn first, matching the CONDA_PREFIX branch.
        warn_in_case_of_duplicates(lib_ld_cuda_libs)

        if lib_ld_cuda_libs:
            return next(iter(lib_ld_cuda_libs))

        warn(
            f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    # Fall back to every remaining candidate env var.
    remaining_candidate_env_vars = {
        env_var: value for env_var, value in candidate_env_vars.items()
        if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
    }

    cuda_runtime_libs = set()
    for env_var, value in remaining_candidate_env_vars.items():
        cuda_runtime_libs.update(find_cuda_lib_in(value))

    # Last resort: the conventional system-wide CUDA install location.
    if len(cuda_runtime_libs) == 0:
        print('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
        cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))

    warn_in_case_of_duplicates(cuda_runtime_libs)

    return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
|
||||
|
|
@ -7,8 +7,8 @@ from torch.utils.data import Dataset
|
|||
from torchvision import transforms
|
||||
from pathlib import Path
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import is_image, list_features
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter
|
||||
from modules import images
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
import os
|
||||
|
||||
|
||||
class Concept(dict):
    """A single dreambooth training concept.

    Acts both as an object (attribute access) and as a dict (the dict payload
    mirrors the attributes), so instances can be serialized to JSON directly.
    Values come from the keyword arguments, or — when ``input_dict`` is
    given — from that dict, with missing keys falling back to the keyword
    defaults.
    """

    def __init__(self, max_steps: int = -1, instance_data_dir: str = "", class_data_dir: str = "",
                 file_prompt_contents: str = "Description", instance_prompt: str = "", class_prompt: str = "",
                 save_sample_prompt: str = "", save_sample_template: str = "", instance_token: str = "",
                 class_token: str = "", num_class_images: int = 0, class_negative_prompt: str = "",
                 class_guidance_scale: float = 7.5, class_infer_steps: int = 60, save_sample_negative_prompt: str = "",
                 n_save_sample: int = 1, sample_seed: int = -1, save_guidance_scale: float = 7.5,
                 save_infer_steps: int = 60, input_dict=None):
        if input_dict is None:
            input_dict = {}
        # Prefer values from input_dict, falling back to the keyword defaults.
        # FIX: the original tested the misspelled keys "class_negative_promt"
        # and "n_save_samples" before reading the correctly spelled keys, so
        # those two fields were never loaded from input_dict (or raised
        # KeyError); dict.get with the correct key fixes both.
        self.max_steps = input_dict.get("max_steps", max_steps)
        self.instance_data_dir = input_dict.get("instance_data_dir", instance_data_dir)
        self.class_data_dir = input_dict.get("class_data_dir", class_data_dir)
        self.file_prompt_contents = input_dict.get("file_prompt_contents", file_prompt_contents)
        self.instance_prompt = input_dict.get("instance_prompt", instance_prompt)
        self.class_prompt = input_dict.get("class_prompt", class_prompt)
        self.save_sample_prompt = input_dict.get("save_sample_prompt", save_sample_prompt)
        self.save_sample_template = input_dict.get("save_sample_template", save_sample_template)
        self.instance_token = input_dict.get("instance_token", instance_token)
        self.class_token = input_dict.get("class_token", class_token)
        self.num_class_images = input_dict.get("num_class_images", num_class_images)
        self.class_negative_prompt = input_dict.get("class_negative_prompt", class_negative_prompt)
        self.class_guidance_scale = input_dict.get("class_guidance_scale", class_guidance_scale)
        self.class_infer_steps = input_dict.get("class_infer_steps", class_infer_steps)
        self.save_sample_negative_prompt = input_dict.get("save_sample_negative_prompt", save_sample_negative_prompt)
        self.n_save_sample = input_dict.get("n_save_sample", n_save_sample)
        self.sample_seed = input_dict.get("sample_seed", sample_seed)
        self.save_guidance_scale = input_dict.get("save_guidance_scale", save_guidance_scale)
        self.save_infer_steps = input_dict.get("save_infer_steps", save_infer_steps)

        # Mirror every attribute into the dict payload so json.dumps(self)
        # and dict-style access keep working.
        self_dict = {
            "max_steps": self.max_steps,
            "instance_data_dir": self.instance_data_dir,
            "class_data_dir": self.class_data_dir,
            "file_prompt_contents": self.file_prompt_contents,
            "instance_prompt": self.instance_prompt,
            "class_prompt": self.class_prompt,
            "save_sample_prompt": self.save_sample_prompt,
            "save_sample_template": self.save_sample_template,
            "instance_token": self.instance_token,
            "class_token": self.class_token,
            "num_class_images": self.num_class_images,
            "class_negative_prompt": self.class_negative_prompt,
            "class_guidance_scale": self.class_guidance_scale,
            "class_infer_steps": self.class_infer_steps,
            "save_sample_negative_prompt": self.save_sample_negative_prompt,
            "n_save_sample": self.n_save_sample,
            "sample_seed": self.sample_seed,
            "save_guidance_scale": self.save_guidance_scale,
            "save_infer_steps": self.save_infer_steps
        }
        super().__init__(self_dict)

    def is_valid(self):
        """Return True when instance_data_dir is set and exists on disk."""
        if self.instance_data_dir is not None and self.instance_data_dir != "":
            if os.path.exists(self.instance_data_dir):
                return True
        return False
|
||||
|
|
@ -2,13 +2,19 @@ import json
|
|||
import os
|
||||
import traceback
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import sanitize_name
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
|
||||
from modules import images, shared
|
||||
|
||||
|
||||
def sanitize_name(name):
    """Strip characters that are unsafe in file/model names, keeping
    alphanumerics plus '.', '_', '-' and spaces."""
    allowed_extra = "._- "
    kept = [ch for ch in name if ch.isalnum() or ch in allowed_extra]
    return "".join(kept)
|
||||
|
||||
|
||||
class DreamboothConfig:
|
||||
v2 = False
|
||||
scheduler = "ddim"
|
||||
lifetime_revision = 0
|
||||
initial_revision = 0
|
||||
|
||||
def __init__(self,
|
||||
model_name: str = "",
|
||||
|
|
@ -22,6 +28,7 @@ class DreamboothConfig:
|
|||
gradient_accumulation_steps: int = 1,
|
||||
gradient_checkpointing: bool = True,
|
||||
half_model: bool = False,
|
||||
has_ema: bool = False,
|
||||
hflip: bool = False,
|
||||
learning_rate: float = 0.00000172,
|
||||
lr_scheduler: str = 'constant',
|
||||
|
|
@ -39,9 +46,11 @@ class DreamboothConfig:
|
|||
resolution: int = 512,
|
||||
revision: int = 0,
|
||||
sample_batch_size: int = 1,
|
||||
save_class_txt: bool = False,
|
||||
save_embedding_every: int = 500,
|
||||
save_preview_every: int = 500,
|
||||
scale_lr: bool = False,
|
||||
scheduler: str = "ddim",
|
||||
src: str = "",
|
||||
train_batch_size: int = 1,
|
||||
train_text_encoder: bool = True,
|
||||
|
|
@ -49,6 +58,7 @@ class DreamboothConfig:
|
|||
use_concepts: bool = False,
|
||||
use_cpu: bool = False,
|
||||
use_ema: bool = True,
|
||||
v2: bool = False,
|
||||
c1_class_data_dir: str = "",
|
||||
c1_class_guidance_scale: float = 7.5,
|
||||
c1_class_infer_steps: int = 60,
|
||||
|
|
@ -106,11 +116,7 @@ class DreamboothConfig:
|
|||
c3_save_sample_negative_prompt: str = "",
|
||||
c3_save_sample_prompt: str = "",
|
||||
c3_save_sample_template: str = "",
|
||||
pretrained_model_name_or_path="",
|
||||
concepts_list=None,
|
||||
v2=None,
|
||||
scheduler= None,
|
||||
has_ema=False,
|
||||
**kwargs
|
||||
):
|
||||
if revision == "" or revision is None:
|
||||
|
|
@ -153,6 +159,7 @@ class DreamboothConfig:
|
|||
self.resolution = resolution
|
||||
self.revision = int(revision)
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.save_class_txt = save_class_txt
|
||||
self.save_embedding_every = save_embedding_every
|
||||
self.save_preview_every = save_preview_every
|
||||
self.scale_lr = scale_lr
|
||||
|
|
@ -237,6 +244,7 @@ class DreamboothConfig:
|
|||
"""
|
||||
Save the config file
|
||||
"""
|
||||
self.lifetime_revision = self.initial_revision + self.revision
|
||||
models_path = shared.cmd_opts.dreambooth_models_path
|
||||
if models_path == "" or models_path is None:
|
||||
models_path = os.path.join(shared.models_path, "dreambooth")
|
||||
|
|
@ -323,80 +331,3 @@ def from_file(model_name):
|
|||
return None
|
||||
|
||||
|
||||
class Concept(dict):
    """A single dreambooth training concept.

    Acts both as an object (attribute access) and as a dict (the dict payload
    mirrors the attributes), so instances can be serialized to JSON directly.
    Values come from the keyword arguments, or — when ``input_dict`` is
    given — from that dict, with missing keys falling back to the keyword
    defaults.
    """

    def __init__(self, max_steps: int = -1, instance_data_dir: str = "", class_data_dir: str = "",
                 file_prompt_contents: str = "Description", instance_prompt: str = "", class_prompt: str = "",
                 save_sample_prompt: str = "", save_sample_template: str = "", instance_token: str = "",
                 class_token: str = "", num_class_images: int = 0, class_negative_prompt: str = "",
                 class_guidance_scale: float = 7.5, class_infer_steps: int = 60, save_sample_negative_prompt: str = "",
                 n_save_sample: int = 1, sample_seed: int = -1, save_guidance_scale: float = 7.5,
                 save_infer_steps: int = 60, input_dict=None):
        if input_dict is None:
            input_dict = {}
        # Prefer values from input_dict, falling back to the keyword defaults.
        # FIX: the original tested the misspelled keys "class_negative_promt"
        # and "n_save_samples" before reading the correctly spelled keys, so
        # those two fields were never loaded from input_dict (or raised
        # KeyError); dict.get with the correct key fixes both.
        self.max_steps = input_dict.get("max_steps", max_steps)
        self.instance_data_dir = input_dict.get("instance_data_dir", instance_data_dir)
        self.class_data_dir = input_dict.get("class_data_dir", class_data_dir)
        self.file_prompt_contents = input_dict.get("file_prompt_contents", file_prompt_contents)
        self.instance_prompt = input_dict.get("instance_prompt", instance_prompt)
        self.class_prompt = input_dict.get("class_prompt", class_prompt)
        self.save_sample_prompt = input_dict.get("save_sample_prompt", save_sample_prompt)
        self.save_sample_template = input_dict.get("save_sample_template", save_sample_template)
        self.instance_token = input_dict.get("instance_token", instance_token)
        self.class_token = input_dict.get("class_token", class_token)
        self.num_class_images = input_dict.get("num_class_images", num_class_images)
        self.class_negative_prompt = input_dict.get("class_negative_prompt", class_negative_prompt)
        self.class_guidance_scale = input_dict.get("class_guidance_scale", class_guidance_scale)
        self.class_infer_steps = input_dict.get("class_infer_steps", class_infer_steps)
        self.save_sample_negative_prompt = input_dict.get("save_sample_negative_prompt", save_sample_negative_prompt)
        self.n_save_sample = input_dict.get("n_save_sample", n_save_sample)
        self.sample_seed = input_dict.get("sample_seed", sample_seed)
        self.save_guidance_scale = input_dict.get("save_guidance_scale", save_guidance_scale)
        self.save_infer_steps = input_dict.get("save_infer_steps", save_infer_steps)

        # Mirror every attribute into the dict payload so json.dumps(self)
        # and dict-style access keep working.
        self_dict = {
            "max_steps": self.max_steps,
            "instance_data_dir": self.instance_data_dir,
            "class_data_dir": self.class_data_dir,
            "file_prompt_contents": self.file_prompt_contents,
            "instance_prompt": self.instance_prompt,
            "class_prompt": self.class_prompt,
            "save_sample_prompt": self.save_sample_prompt,
            "save_sample_template": self.save_sample_template,
            "instance_token": self.instance_token,
            "class_token": self.class_token,
            "num_class_images": self.num_class_images,
            "class_negative_prompt": self.class_negative_prompt,
            "class_guidance_scale": self.class_guidance_scale,
            "class_infer_steps": self.class_infer_steps,
            "save_sample_negative_prompt": self.save_sample_negative_prompt,
            "n_save_sample": self.n_save_sample,
            "sample_seed": self.sample_seed,
            "save_guidance_scale": self.save_guidance_scale,
            "save_infer_steps": self.save_infer_steps
        }
        super().__init__(self_dict)

    def is_valid(self):
        """Return True when instance_data_dir is set and exists on disk."""
        if self.instance_data_dir is not None and self.instance_data_dir != "":
            if os.path.exists(self.instance_data_dir):
                return True
        return False
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@
|
|||
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -13,10 +15,6 @@ from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, printi,
|
|||
reload_system_models
|
||||
from modules import shared
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
|
|
@ -42,37 +40,8 @@ unet_conversion_map_resnet = [
|
|||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
UNET_PARAMS_MODEL_CHANNELS = 320
|
||||
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
||||
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
||||
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
||||
UNET_PARAMS_IN_CHANNELS = 4
|
||||
UNET_PARAMS_OUT_CHANNELS = 4
|
||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||
UNET_PARAMS_CONTEXT_DIM = 768
|
||||
UNET_PARAMS_NUM_HEADS = 8
|
||||
unet_params = {
|
||||
"model_channels": 320,
|
||||
"channel_mult": [1, 2, 4, 4],
|
||||
"attention_resolutions": [4, 2, 1],
|
||||
"image_size": 32, # unused
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_res_blocks": 2,
|
||||
"context_dim": 768,
|
||||
"num_heads": 8
|
||||
}
|
||||
unet_v2_params = unet_params.copy()
|
||||
unet_v2_params["num_heads"] = [5, 10, 20, 20]
|
||||
unet_v2_params["attention_head_dim"] = [5, 10, 20, 20]
|
||||
unet_v2_params["context_dim"] = 1024
|
||||
VAE_PARAMS_Z_CHANNELS = 4
|
||||
VAE_PARAMS_RESOLUTION = 256
|
||||
VAE_PARAMS_IN_CHANNELS = 3
|
||||
VAE_PARAMS_OUT_CH = 3
|
||||
VAE_PARAMS_CH = 128
|
||||
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
||||
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
||||
# hardcoded number of downblocks and resnets/attentions...
|
||||
# would need smarter logic for other networks.
|
||||
for i in range(4):
|
||||
# loop over downblocks/upblocks
|
||||
|
||||
|
|
@ -121,15 +90,6 @@ for j in range(2):
|
|||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
def conv_transformer_to_linear(checkpoint):
    """In-place: squeeze 4-D conv-style proj_in/proj_out weights down to the
    2-D linear layout by dropping the trailing 1x1 spatial dims."""
    transformer_suffixes = ("proj_in.weight", "proj_out.weight")
    for key in list(checkpoint.keys()):
        suffix = ".".join(key.split(".")[-2:])
        if suffix in transformer_suffixes and checkpoint[key].ndim > 2:
            checkpoint[key] = checkpoint[key][:, :, 0, 0]
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
|
|
@ -227,139 +187,85 @@ def convert_vae_state_dict(vae_state_dict):
|
|||
|
||||
|
||||
# =========================#
|
||||
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
conversion_map_layer = []
|
||||
for q in range(4):
|
||||
for r in range(2):
|
||||
hfd_res_prefix = f"down_blocks.{q}.resnets.{r}."
|
||||
sdd_res_prefix = f"input_blocks.{3 * q + r + 1}.0."
|
||||
conversion_map_layer.append((sdd_res_prefix, hfd_res_prefix))
|
||||
if q < 3:
|
||||
hfd_atn_prefix = f"down_blocks.{q}.attentions.{r}."
|
||||
sdd_atn_prefix = f"input_blocks.{3 * q + r + 1}.1."
|
||||
conversion_map_layer.append((sdd_atn_prefix, hfd_atn_prefix))
|
||||
for r in range(3):
|
||||
hfu_res_prefix = f"up_blocks.{q}.resnets.{r}."
|
||||
sdu_res_prefix = f"output_blocks.{3 * q + r}.0."
|
||||
conversion_map_layer.append((sdu_res_prefix, hfu_res_prefix))
|
||||
if q > 0:
|
||||
hfu_atn_prefix = f"up_blocks.{q}.attentions.{r}."
|
||||
sdu_atn_prefix = f"output_blocks.{3 * q + r}.1."
|
||||
conversion_map_layer.append((sdu_atn_prefix, hfu_atn_prefix))
|
||||
if q < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hfd_prefix = f"down_blocks.{q}.downsamplers.0.conv."
|
||||
sdd_prefix = f"input_blocks.{3 * (q + 1)}.0.op."
|
||||
conversion_map_layer.append((sdd_prefix, hfd_prefix))
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hfu_prefix = f"up_blocks.{q}.upsamplers.0."
|
||||
sdu_prefix = f"output_blocks.{3 * q + 2}.{1 if q == 0 else 2}."
|
||||
conversion_map_layer.append((sdu_prefix, hfu_prefix))
|
||||
|
||||
hfm_atn_prefix = "mid_block.attentions.0."
|
||||
sdm_atn_prefix = "middle_block.1."
|
||||
conversion_map_layer.append((sdm_atn_prefix, hfm_atn_prefix))
|
||||
|
||||
for r in range(2):
|
||||
hfm_res_prefix = f"mid_block.resnets.{r}."
|
||||
sdm_res_prefix = f"middle_block.{2 * r}."
|
||||
conversion_map_layer.append((sdm_res_prefix, hfm_res_prefix))
|
||||
textenc_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
('resblocks.', 'text_model.encoder.layers.'),
|
||||
('ln_1', 'layer_norm1'),
|
||||
('ln_2', 'layer_norm2'),
|
||||
('.c_fc.', '.fc1.'),
|
||||
('.c_proj.', '.fc2.'),
|
||||
('.attn', '.self_attn'),
|
||||
('ln_final.', 'transformer.text_model.final_layer_norm.'),
|
||||
('token_embedding.weight', 'transformer.text_model.embeddings.token_embedding.weight'),
|
||||
('positional_embedding', 'transformer.text_model.embeddings.position_embedding.weight')
|
||||
]
|
||||
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
# the exact order in which I have arranged them.
|
||||
mapping = {k: k for k in unet_state_dict.keys()}
|
||||
for sd_name, hf_name in unet_conversion_map:
|
||||
mapping[hf_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {'q': 0, 'k': 1, 'v': 2}
|
||||
|
||||
if v2:
|
||||
conv_transformer_to_linear(new_state_dict)
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]):
    # Convert a diffusers v2.x CLIP text-encoder state dict to the SD
    # checkpoint layout: separate q/k/v projection tensors are fused into
    # single in_proj_weight/in_proj_bias tensors, and remaining keys are
    # renamed via the module-level textenc_pattern/protected tables.
    new_state_dict = {}
    # k_pre -> [q, k, v] tensors collected for later fusion (indexed by code2idx).
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if k.endswith('.self_attn.q_proj.weight') or k.endswith('.self_attn.k_proj.weight') or k.endswith(
                '.self_attn.v_proj.weight'):
            k_pre = k[:-len('.q_proj.weight')]
            # Single character 'q'/'k'/'v' taken from the fixed key suffix.
            k_code = k[-len('q_proj.weight')]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if k.endswith('.self_attn.q_proj.bias') or k.endswith('.self_attn.k_proj.bias') or k.endswith(
                '.self_attn.v_proj.bias'):
            k_pre = k[:-len('.q_proj.bias')]
            k_code = k[-len('q_proj.bias')]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        # Non-attention keys: plain rename through the regex table.
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        # if relabelled_key != k:
        #     print(f"{k} -> {relabelled_key}")

        new_state_dict[relabelled_key] = v

    # Fuse the collected q/k/v weights; all three must be present.
    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + '.in_proj_weight'] = torch.cat(tensors)

    # Same fusion for the biases.
    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + '.in_proj_bias'] = torch.cat(tensors)

    return new_state_dict
|
||||
|
||||
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
# pretty much a no-op
|
||||
|
||||
|
||||
def convert_text_enc_state_dict(text_enc_dict):
    """v1 text-encoder conversion is a no-op — diffusers and SD share the
    same CLIP key layout — so the state dict is returned unchanged."""
    return text_enc_dict
|
||||
|
||||
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
    # Convert a diffusers (HF) v2 CLIP text-encoder state dict to the
    # stable-diffusion checkpoint layout: rename keys and fuse the separate
    # q/k/v projection tensors into single `attn.in_proj_*` tensors.
    def convert_key(conv_key):
        # Map one diffusers key to its SD equivalent; return None for keys
        # that have no SD counterpart and must be dropped.
        if ".position_ids" in conv_key:
            return None

        # common
        conv_key = conv_key.replace("text_model.encoder.", "transformer.")
        conv_key = conv_key.replace("text_model.", "")
        if "layers" in conv_key:
            # resblocks conversion
            conv_key = conv_key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in conv_key:
                conv_key = conv_key.replace(".layer_norm", ".ln_")
            elif ".mlp." in conv_key:
                conv_key = conv_key.replace(".fc1.", ".c_fc.")
                conv_key = conv_key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in conv_key:
                conv_key = conv_key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in conv_key:
                # q/k/v projections are dropped here and handled by the
                # fusion loop below.
                conv_key = None
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {conv_key}")
        elif '.position_embedding' in conv_key:
            conv_key = conv_key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in conv_key:
            conv_key = conv_key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in conv_key:
            conv_key = conv_key.replace("final_layer_norm", "ln_final")
        return conv_key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # Fuse q/k/v projection tensors (weights and biases alike) into the
    # single concatenated `attn.in_proj_*` tensors SD expects.
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            # Concatenation order q, k, v matches the in_proj layout.
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    return new_sd
|
||||
|
||||
|
||||
def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, reload_models=False):
|
||||
def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, reload_models=True):
|
||||
"""
|
||||
|
||||
@param model_name: The model name to compile
|
||||
@param half: Use FP16 when compiling the model
|
||||
@param use_subdir: The model will be saved to a subdirectory of the checkpoints folder
|
||||
@param reload_models: Whether to reload the system list of checkpoints.
|
||||
@return: status: What happened, path: Checkpoint path
|
||||
"""
|
||||
unload_system_models()
|
||||
|
|
@ -376,57 +282,76 @@ def compile_checkpoint(model_name: str, half: bool, use_subdir: bool = False, re
|
|||
models_path = ckpt_dir
|
||||
|
||||
config = from_file(model_name)
|
||||
try:
|
||||
if "use_subdir" in config.__dict__:
|
||||
use_subdir = config["use_subdir"]
|
||||
except:
|
||||
print("Yeah, we can't use dict to find config values.")
|
||||
if "use_subdir" in config.__dict__:
|
||||
use_subdir = config["use_subdir"]
|
||||
|
||||
v2 = config.v2
|
||||
total_steps = config.revision
|
||||
|
||||
if use_subdir:
|
||||
os.makedirs(os.path.join(models_path, model_name))
|
||||
out_file = os.path.join(models_path, model_name, f"{model_name}_{total_steps}.ckpt")
|
||||
checkpoint_path = os.path.join(models_path, model_name, f"{model_name}_{total_steps}.ckpt")
|
||||
else:
|
||||
out_file = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
|
||||
checkpoint_path = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
|
||||
|
||||
model_path = config.pretrained_model_name_or_path
|
||||
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
printi("Converting unet...")
|
||||
# Convert the UNet model
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
#unet_state_dict = convert_unet_state_dict_to_sd(v2, unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
printi("Converting vae...")
|
||||
# Convert the VAE model
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
printi("Converting text encoder...")
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||
#text_enc_dict = convert_text_enc_state_dict(text_enc_dict) if not v2 else convert_text_encoder_state_dict_to_sd_v2(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
printi("Compiling new state dict...")
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
if half:
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
try:
|
||||
printi("Converting unet...")
|
||||
# Convert the UNet model
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
# unet_state_dict = convert_unet_state_dict_to_sd(v2, unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
printi("Converting vae...")
|
||||
# Convert the VAE model
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
printi("Converting text encoder...")
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
||||
if is_v20_model:
|
||||
print("Converting text enc dict for V2 model.")
|
||||
# Need to add the tag 'transformer' in advance, so we can knock it out from the final layer-norm
|
||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||
if not config.v2:
|
||||
config.v2 = True
|
||||
config.save()
|
||||
v2 = True
|
||||
else:
|
||||
print("Converting text enc dict for V1 model.")
|
||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
if half:
|
||||
print("Halving model.")
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
|
||||
state_dict = {"global_step": config.revision, "state_dict": state_dict}
|
||||
printi(f"Saving checkpoint to {checkpoint_path}...")
|
||||
torch.save(state_dict, checkpoint_path)
|
||||
if v2:
|
||||
cfg_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "v2-inference-v.yaml")
|
||||
cfg_dest = checkpoint_path.replace(".ckpt", ".yaml")
|
||||
print(f"Copying config file to {cfg_dest}")
|
||||
shutil.copyfile(cfg_file, cfg_dest)
|
||||
except Exception as e:
|
||||
print("Exception compiling checkpoint!")
|
||||
traceback.print_exc()
|
||||
return f"Exception compiling: {e}", ""
|
||||
|
||||
state_dict = {"state_dict": state_dict}
|
||||
new_ckpt = {'global_step': config.revision, 'state_dict': state_dict}
|
||||
printi(f"Saving checkpoint to {out_file}...")
|
||||
torch.save(new_ckpt, out_file)
|
||||
if v2:
|
||||
cfg_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "v2-inference-v.yaml")
|
||||
cfg_dest = out_file.replace(".ckpt", ".yaml")
|
||||
print(f"Copying config file to {cfg_dest}")
|
||||
shutil.copyfile(cfg_file, cfg_dest)
|
||||
try:
|
||||
del unet_state_dict
|
||||
del vae_state_dict
|
||||
|
|
|
|||
|
|
@ -2,24 +2,17 @@ import gc
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from PIL import features
|
||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from diffusers.utils import logging as dl
|
||||
from huggingface_hub import HfFolder, whoami
|
||||
from six import StringIO
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file, Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import reload_system_models, unload_system_models, printm
|
||||
from modules import paths, shared, devices, sd_models
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import reload_system_models, unload_system_models, printm, \
|
||||
isset, list_features, is_image
|
||||
from modules import shared, devices
|
||||
|
||||
try:
|
||||
cmd_dreambooth_models_path = shared.cmd_opts.dreambooth_models_path
|
||||
|
|
@ -27,7 +20,6 @@ except:
|
|||
cmd_dreambooth_models_path = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# define a Handler which writes DEBUG messages or higher to the sys.stderr
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console)
|
||||
|
|
@ -37,83 +29,13 @@ dl.set_verbosity_error()
|
|||
mem_record = {}
|
||||
|
||||
|
||||
def generate_sample_img(model_dir: str, save_sample_prompt: str, seed: str):
|
||||
print("Gensample?")
|
||||
if model_dir is None or model_dir == "":
|
||||
return "Please select a model."
|
||||
config = from_file(model_dir)
|
||||
unload_system_models()
|
||||
model_path = config.pretrained_model_name_or_path
|
||||
image = None
|
||||
if not os.path.exists(config.pretrained_model_name_or_path):
|
||||
print(f"Model path '{config.pretrained_model_name_or_path}' doesn't exist.")
|
||||
return f"Can't find diffusers model at {config.pretrained_model_name_or_path}.", None
|
||||
try:
|
||||
print(f"Loading model from {model_path}.")
|
||||
text_enc_model = CLIPTextModel.from_pretrained(config.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder", revision=config.revision)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
config.pretrained_model_name_or_path,
|
||||
text_encoder=text_enc_model,
|
||||
torch_dtype=torch.float16,
|
||||
revision=config.revision,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False
|
||||
)
|
||||
pipeline = pipeline.to(shared.device)
|
||||
pil_features = list_features()
|
||||
save_dir = os.path.join(shared.sd_path, "outputs", "dreambooth")
|
||||
db_model_path = config.model_dir
|
||||
if save_sample_prompt is None:
|
||||
msg = "Please provide a sample prompt."
|
||||
print(msg)
|
||||
return msg, None
|
||||
shared.state.textinfo = f"Generating preview image for model {db_model_path}..."
|
||||
# I feel like this might not actually be necessary...but what the heck.
|
||||
if seed is None or seed == '' or seed == -1:
|
||||
seed = int(random.randrange(21474836147))
|
||||
g_cuda = torch.Generator(device=shared.device).manual_seed(seed)
|
||||
sample_dir = os.path.join(save_dir, "samples")
|
||||
os.makedirs(sample_dir, exist_ok=True)
|
||||
file_count = 0
|
||||
for x in Path(sample_dir).iterdir():
|
||||
if is_image(x, pil_features):
|
||||
file_count += 1
|
||||
shared.state.job_count = 1
|
||||
with torch.autocast("cuda"), torch.inference_mode():
|
||||
image = pipeline(save_sample_prompt,
|
||||
num_inference_steps=60,
|
||||
guidance_scale=7.5,
|
||||
scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085,
|
||||
beta_end=0.012),
|
||||
width=config.resolution,
|
||||
height=config.resolution,
|
||||
generator=g_cuda).images[0]
|
||||
|
||||
except:
|
||||
print("Exception generating sample!")
|
||||
traceback.print_exc()
|
||||
reload_system_models()
|
||||
return "Sample generated.", image
|
||||
|
||||
|
||||
|
||||
|
||||
# Borrowed from https://wandb.ai/psuraj/dreambooth/reports/Training-Stable-Diffusion-with-Dreambooth
|
||||
# --VmlldzoyNzk0NDc3#tl,dr; and https://www.reddit.com/r/StableDiffusion/comments/ybxv7h/good_dreambooth_formula/
|
||||
def training_wizard_person(
|
||||
model_dir
|
||||
):
|
||||
def training_wizard_person(model_dir):
|
||||
return training_wizard(
|
||||
model_dir,
|
||||
is_person=True)
|
||||
|
||||
|
||||
def training_wizard(
|
||||
model_dir,
|
||||
is_person=False
|
||||
):
|
||||
def training_wizard(model_dir, is_person=False):
|
||||
"""
|
||||
Calculate the number of steps based on our learning rate, return the following:
|
||||
status,
|
||||
|
|
@ -285,41 +207,6 @@ def performance_wizard():
|
|||
train_batch_size, train_text_encoder, use_8bit_adam, use_cpu, use_ema
|
||||
|
||||
|
||||
|
||||
|
||||
def dumb_safety(images, clip_input):
|
||||
return images, False
|
||||
|
||||
|
||||
def isset(val: str):
|
||||
return val is not None and val != "" and val != "*"
|
||||
|
||||
|
||||
def list_features():
|
||||
# Create buffer for pilinfo() to write into rather than stdout
|
||||
buffer = StringIO()
|
||||
features.pilinfo(out=buffer)
|
||||
pil_features = []
|
||||
# Parse and analyse lines
|
||||
for line in buffer.getvalue().splitlines():
|
||||
if "Extensions:" in line:
|
||||
ext_list = line.split(": ")[1]
|
||||
extensions = ext_list.split(", ")
|
||||
for extension in extensions:
|
||||
if extension not in pil_features:
|
||||
pil_features.append(extension)
|
||||
return pil_features
|
||||
|
||||
|
||||
def is_image(path: Path, feats=None):
|
||||
if feats is None:
|
||||
feats = []
|
||||
if not len(feats):
|
||||
feats = list_features()
|
||||
is_img = path.is_file() and path.suffix.lower() in feats
|
||||
return is_img
|
||||
|
||||
|
||||
def load_params(model_dir):
|
||||
data = from_file(model_dir)
|
||||
msg = ""
|
||||
|
|
@ -352,15 +239,41 @@ def load_params(model_dir):
|
|||
ui_dict[f"c{c_idx}_{key}"] = ui_concept[key]
|
||||
c_idx += 1
|
||||
ui_dict["db_status"] = msg
|
||||
ui_keys = ["db_adam_beta1", "db_adam_beta2", "db_adam_epsilon", "db_adam_weight_decay", "db_attention",
|
||||
"db_center_crop", "db_concepts_path", "db_gradient_accumulation_steps", "db_gradient_checkpointing",
|
||||
"db_half_model", "db_has_ema", "db_hflip", "db_learning_rate", "db_lr_scheduler", "db_lr_warmup_steps",
|
||||
"db_max_grad_norm", "db_max_token_length", "db_max_train_steps", "db_mixed_precision", "db_model_path",
|
||||
"db_not_cache_latents", "db_num_train_epochs", "db_pad_tokens", "db_pretrained_vae_name_or_path",
|
||||
"db_prior_loss_weight", "db_resolution", "db_revision", "db_sample_batch_size",
|
||||
"db_save_embedding_every", "db_save_preview_every", "db_scale_lr", "db_scheduler", "db_src",
|
||||
"db_train_batch_size", "db_train_text_encoder", "db_use_8bit_adam", "db_use_concepts", "db_use_cpu",
|
||||
"db_use_ema", "db_v2", "c1_class_data_dir", "c1_class_guidance_scale", "c1_class_infer_steps",
|
||||
ui_keys = ["db_adam_beta1",
|
||||
"db_adam_beta2",
|
||||
"db_adam_epsilon",
|
||||
"db_adam_weight_decay",
|
||||
"db_attention",
|
||||
"db_center_crop",
|
||||
"db_concepts_path",
|
||||
"db_gradient_accumulation_steps",
|
||||
"db_gradient_checkpointing",
|
||||
"db_half_model",
|
||||
"db_hflip",
|
||||
"db_learning_rate",
|
||||
"db_lr_scheduler",
|
||||
"db_lr_warmup_steps",
|
||||
"db_max_grad_norm",
|
||||
"db_max_token_length",
|
||||
"db_max_train_steps",
|
||||
"db_mixed_precision",
|
||||
"db_not_cache_latents",
|
||||
"db_num_train_epochs",
|
||||
"db_pad_tokens",
|
||||
"db_pretrained_vae_name_or_path",
|
||||
"db_prior_loss_weight",
|
||||
"db_resolution",
|
||||
"db_sample_batch_size",
|
||||
"db_save_class_txt",
|
||||
"db_save_embedding_every",
|
||||
"db_save_preview_every",
|
||||
"db_scale_lr",
|
||||
"db_train_batch_size",
|
||||
"db_train_text_encoder",
|
||||
"db_use_8bit_adam",
|
||||
"db_use_concepts",
|
||||
"db_use_cpu",
|
||||
"db_use_ema", "c1_class_data_dir", "c1_class_guidance_scale", "c1_class_infer_steps",
|
||||
"c1_class_negative_prompt", "c1_class_prompt", "c1_class_token", "c1_file_prompt_contents",
|
||||
"c1_instance_data_dir", "c1_instance_prompt", "c1_instance_token", "c1_max_steps", "c1_n_save_sample",
|
||||
"c1_num_class_images", "c1_sample_seed", "c1_save_guidance_scale", "c1_save_infer_steps",
|
||||
|
|
@ -389,7 +302,24 @@ def load_params(model_dir):
|
|||
return output
|
||||
|
||||
|
||||
def start_training(model_dir: str, imagic_only: bool):
|
||||
def load_model_params(model_dir):
|
||||
data = from_file(model_dir)
|
||||
msg = ""
|
||||
if data is None:
|
||||
print("Can't load config!")
|
||||
msg = f"Error loading config for model '{model_dir}'."
|
||||
return "", "", "", "", "", "", msg
|
||||
else:
|
||||
return data.model_dir, \
|
||||
data.revision, \
|
||||
"True" if data.v2 else "False", \
|
||||
"True" if data.has_ema else "False", \
|
||||
data.src, \
|
||||
data.scheduler, \
|
||||
""
|
||||
|
||||
|
||||
def start_training(model_dir: str, imagic_only: bool, use_subdir: bool):
|
||||
global mem_record
|
||||
if model_dir == "" or model_dir is None:
|
||||
print("Invalid model name.")
|
||||
|
|
@ -421,55 +351,37 @@ def start_training(model_dir: str, imagic_only: bool):
|
|||
if msg:
|
||||
shared.state.textinfo = msg
|
||||
print(msg)
|
||||
return msg, msg, 0, ""
|
||||
return msg, msg, 0, msg
|
||||
|
||||
# Clear memory and do "stuff" only after we've ensured all the things are right
|
||||
print("Starting Dreambooth training...")
|
||||
unload_system_models()
|
||||
total_steps = config.revision
|
||||
if imagic_only:
|
||||
shared.state.textinfo = "Initializing imagic training..."
|
||||
print(shared.state.textinfo)
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic
|
||||
mem_record = train_imagic(config, mem_record)
|
||||
else:
|
||||
shared.state.textinfo = "Initializing dreambooth training..."
|
||||
print(shared.state.textinfo)
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
|
||||
config, mem_record, msg = main(config, mem_record)
|
||||
if config.revision != total_steps:
|
||||
config.save()
|
||||
total_steps = config.revision
|
||||
|
||||
try:
|
||||
if imagic_only:
|
||||
shared.state.textinfo = "Initializing imagic training..."
|
||||
print(shared.state.textinfo)
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic
|
||||
mem_record = train_imagic(config, mem_record)
|
||||
else:
|
||||
shared.state.textinfo = "Initializing dreambooth training..."
|
||||
print(shared.state.textinfo)
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
|
||||
config, mem_record, msg = main(config, mem_record, use_subdir=use_subdir)
|
||||
if config.revision != total_steps:
|
||||
config.save()
|
||||
total_steps = config.revision
|
||||
res = f"Training {'interrupted' if shared.state.interrupted else 'finished'}. " \
|
||||
f"Total lifetime steps: {total_steps} \n"
|
||||
except Exception as e:
|
||||
res = f"Exception training model: {e}"
|
||||
pass
|
||||
|
||||
devices.torch_gc()
|
||||
gc.collect()
|
||||
printm("Training completed, reloading SD Model.")
|
||||
print(f'Memory output: {mem_record}')
|
||||
reload_system_models()
|
||||
res = f"Training {'interrupted' if shared.state.interrupted else 'finished'}. " \
|
||||
f"Total lifetime steps: {total_steps} \n"
|
||||
print(f"Returning result: {res}")
|
||||
return res, "", total_steps, ""
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def save_checkpoint(model_name: str, total_steps: int):
|
||||
print(f"Successfully trained model for a total of {total_steps} steps, converting to ckpt.")
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir
|
||||
models_path = os.path.join(paths.models_path, "Stable-diffusion")
|
||||
if ckpt_dir is not None:
|
||||
models_path = ckpt_dir
|
||||
src_path = os.path.join(
|
||||
os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path, "dreambooth",
|
||||
model_name, "working")
|
||||
out_file = os.path.join(models_path, f"{model_name}_{total_steps}.ckpt")
|
||||
compile_checkpoint(model_name)
|
||||
sd_models.list_models()
|
||||
return res, res, total_steps, res
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import re
|
|||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from modules import shared
|
||||
|
|
@ -71,6 +72,41 @@ class FilenameTextGetter:
|
|||
return output
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
|
||||
|
||||
def __init__(self, prompt: str, num_samples: int, filename_texts, file_prompt_contents: str, class_token: str,
|
||||
instance_token: str):
|
||||
self.prompt = prompt
|
||||
self.instance_token = instance_token
|
||||
self.class_token = class_token
|
||||
self.num_samples = num_samples
|
||||
self.filename_texts = filename_texts
|
||||
self.file_prompt_contents = file_prompt_contents
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {"filename_text": self.filename_texts[index % len(self.filename_texts)] if len(
|
||||
self.filename_texts) > 0 else ""}
|
||||
prompt = example["filename_text"]
|
||||
if "Instance" in self.file_prompt_contents:
|
||||
class_token = self.class_token
|
||||
# If the token is already in the prompt, just remove the instance token, don't swap it
|
||||
class_tokens = [f"a {class_token}", f"the {class_token}", f"an {class_token}", class_token]
|
||||
for token in class_tokens:
|
||||
if token in prompt:
|
||||
prompt = prompt.replace(self.instance_token, "")
|
||||
else:
|
||||
prompt = prompt.replace(self.instance_token, self.class_token)
|
||||
|
||||
prompt = self.prompt.replace("[filewords]", prompt)
|
||||
example["prompt"] = prompt
|
||||
example["index"] = index
|
||||
return example
|
||||
|
||||
|
||||
# Implementation from https://github.com/bmaltais/kohya_ss
|
||||
def encode_hidden_state(text_encoder: CLIPTextModel, input_ids, pad_tokens, b_size, max_token_length,
|
||||
tokenizer_max_length):
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
import gc
|
||||
import gradio as gr
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
import modules.sd_models
|
||||
|
|
@ -40,7 +40,7 @@ from diffusers import (
|
|||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
# HeunDiscreteScheduler, - Add after main diffusers is bumped on pypi.org
|
||||
HeunDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
|
|
@ -466,7 +466,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
|
||||
# From Bmalthais
|
||||
# if v2:
|
||||
# linear_transformer_to_conv(new_checkpoint)
|
||||
# linear_transformer_to_conv(new_checkpoint)
|
||||
return new_checkpoint, has_ema
|
||||
|
||||
|
||||
|
|
@ -639,29 +639,30 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
|||
return text_model
|
||||
|
||||
|
||||
import re
|
||||
textenc_conversion_lst = [
|
||||
('cond_stage_model.model.positional_embedding',
|
||||
"text_model.embeddings.position_embedding.weight"),
|
||||
"text_model.embeddings.position_embedding.weight"),
|
||||
('cond_stage_model.model.token_embedding.weight',
|
||||
"text_model.embeddings.token_embedding.weight"),
|
||||
"text_model.embeddings.token_embedding.weight"),
|
||||
('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'),
|
||||
('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias')
|
||||
]
|
||||
textenc_conversion_map = {x[0]:x[1] for x in textenc_conversion_lst}
|
||||
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
||||
textenc_transformer_conversion_lst = [
|
||||
('resblocks.','text_model.encoder.layers.'),
|
||||
('ln_1','layer_norm1'),
|
||||
('ln_2','layer_norm2'),
|
||||
('.c_fc.','.fc1.'),
|
||||
('.c_proj.','.fc2.'),
|
||||
('.attn','.self_attn'),
|
||||
('ln_final.','transformer.text_model.final_layer_norm.'),
|
||||
('token_embedding.weight','transformer.text_model.embeddings.token_embedding.weight'),
|
||||
('positional_embedding','transformer.text_model.embeddings.position_embedding.weight')
|
||||
('resblocks.', 'text_model.encoder.layers.'),
|
||||
('ln_1', 'layer_norm1'),
|
||||
('ln_2', 'layer_norm2'),
|
||||
('.c_fc.', '.fc1.'),
|
||||
('.c_proj.', '.fc2.'),
|
||||
('.attn', '.self_attn'),
|
||||
('ln_final.', 'transformer.text_model.final_layer_norm.'),
|
||||
('token_embedding.weight', 'transformer.text_model.embeddings.token_embedding.weight'),
|
||||
('positional_embedding', 'transformer.text_model.embeddings.position_embedding.weight')
|
||||
]
|
||||
protected = {re.escape(x[0]):x[1] for x in textenc_transformer_conversion_lst}
|
||||
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||
|
||||
|
|
@ -669,30 +670,30 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||
|
||||
text_model_dict = {}
|
||||
|
||||
d_model = int( checkpoint['cond_stage_model.model.text_projection'].shape[0] )
|
||||
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
|
||||
|
||||
text_model_dict["text_model.embeddings.position_ids"] = \
|
||||
text_model.text_model.embeddings.get_buffer('position_ids')
|
||||
text_model.text_model.embeddings.get_buffer('position_ids')
|
||||
|
||||
for key in keys:
|
||||
if 'resblocks.23' in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||
if 'resblocks.23' in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||
continue
|
||||
if key in textenc_conversion_map:
|
||||
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||
if key.startswith("cond_stage_model.model.transformer."):
|
||||
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||
new_key = key[len("cond_stage_model.model.transformer."):]
|
||||
if new_key.endswith(".in_proj_weight"):
|
||||
new_key = new_key[:-len(".in_proj_weight")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key+'.q_proj.weight'] = checkpoint[key][:d_model,:]
|
||||
text_model_dict[new_key+'.k_proj.weight'] = checkpoint[key][d_model:d_model*2,:]
|
||||
text_model_dict[new_key+'.v_proj.weight'] = checkpoint[key][d_model*2:,:]
|
||||
text_model_dict[new_key + '.q_proj.weight'] = checkpoint[key][:d_model, :]
|
||||
text_model_dict[new_key + '.k_proj.weight'] = checkpoint[key][d_model:d_model * 2, :]
|
||||
text_model_dict[new_key + '.v_proj.weight'] = checkpoint[key][d_model * 2:, :]
|
||||
elif new_key.endswith(".in_proj_bias"):
|
||||
new_key = new_key[:-len(".in_proj_bias")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key+'.q_proj.bias'] = checkpoint[key][:d_model]
|
||||
text_model_dict[new_key+'.k_proj.bias'] = checkpoint[key][d_model:d_model*2]
|
||||
text_model_dict[new_key+'.v_proj.bias'] = checkpoint[key][d_model*2:]
|
||||
text_model_dict[new_key + '.q_proj.bias'] = checkpoint[key][:d_model]
|
||||
text_model_dict[new_key + '.k_proj.bias'] = checkpoint[key][d_model:d_model * 2]
|
||||
text_model_dict[new_key + '.v_proj.bias'] = checkpoint[key][d_model * 2:]
|
||||
else:
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
|
||||
|
|
@ -731,180 +732,220 @@ def extract_checkpoint(new_model_name: str, ckpt_path: str, scheduler_type="ddim
|
|||
db_config.src: The source checkpoint, if not from hub.
|
||||
db_has_ema: Whether the model had EMA weights and they were extracted. If weights were not present or
|
||||
you did not extract them and they were, this will be false.
|
||||
db_resolution: The resolution the model trains at.
|
||||
db_v2: Is this a V2 Model?
|
||||
status
|
||||
"""
|
||||
new_model_name = sanitize_name(new_model_name)
|
||||
print(f"new model URL is {new_model_url}")
|
||||
shared.state.job_no = 0
|
||||
checkpoint = None
|
||||
map_location = shared.device
|
||||
if shared.cmd_opts.ckptfix or shared.cmd_opts.medvram or shared.cmd_opts.lowvram:
|
||||
printm(f"Using CPU for extraction.")
|
||||
map_location = torch.device('cpu')
|
||||
db_config = None
|
||||
status = ""
|
||||
has_ema = False
|
||||
v2 = False
|
||||
model_dir = ""
|
||||
scheduler = ""
|
||||
src = ""
|
||||
revision = 0
|
||||
|
||||
# Try to determine if v1 or v2 model
|
||||
if not from_hub:
|
||||
checkpoint_info = modules.sd_models.get_closet_checkpoint_match(ckpt_path)
|
||||
# Needed for V2 models so we can create the right text encoder.
|
||||
|
||||
if checkpoint_info is None:
|
||||
print("Unable to find checkpoint file!")
|
||||
shared.state.job_no = 8
|
||||
return "", "", "", "", "", "", "", "Unable to find base checkpoint.", ""
|
||||
reset_safe = False
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
reset_safe = True
|
||||
shared.cmd_opts.disable_safe_unpickle = True
|
||||
|
||||
if not os.path.exists(checkpoint_info.filename):
|
||||
print("Unable to find checkpoint file!")
|
||||
shared.state.job_no = 8
|
||||
return "", "", "", "", "", "", "", "Unable to find base checkpoint.", ""
|
||||
try:
|
||||
new_model_name = sanitize_name(new_model_name)
|
||||
shared.state.job_no = 0
|
||||
checkpoint = None
|
||||
map_location = shared.device
|
||||
if shared.cmd_opts.ckptfix or shared.cmd_opts.medvram or shared.cmd_opts.lowvram:
|
||||
printm(f"Using CPU for extraction.")
|
||||
map_location = torch.device('cpu')
|
||||
|
||||
ckpt_path = checkpoint_info[0]
|
||||
# Try to determine if v1 or v2 model
|
||||
if not from_hub:
|
||||
printi("Loading model from checkpoint.")
|
||||
checkpoint_info = modules.sd_models.get_closet_checkpoint_match(ckpt_path)
|
||||
|
||||
printi("Loading checkpoint...")
|
||||
checkpoint = torch.load(ckpt_path)
|
||||
# Todo: Decide if we should store this separately in the db_config and append it when re-compiling models.
|
||||
# global_step = checkpoint["global_step"]
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
if checkpoint_info is None:
|
||||
print("Unable to find checkpoint file!")
|
||||
shared.state.job_no = 8
|
||||
return "", "", 0, "", "", "", "", 512, "", "Unable to find base checkpoint."
|
||||
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if not os.path.exists(checkpoint_info.filename):
|
||||
print("Unable to find checkpoint file!")
|
||||
shared.state.job_no = 8
|
||||
return "", "", 0, "", "", "", "", 512, "", "Unable to find base checkpoint."
|
||||
|
||||
ckpt_path = checkpoint_info[0]
|
||||
|
||||
printi("Loading checkpoint...")
|
||||
checkpoint = torch.load(ckpt_path)
|
||||
# Todo: Decide if we should store this separately in the db_config and append it when re-compiling models.
|
||||
# global_step = checkpoint["global_step"]
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
v2 = True
|
||||
else:
|
||||
v2 = False
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
v2 = True
|
||||
else:
|
||||
v2 = False
|
||||
if new_model_token == "" or new_model_token is None:
|
||||
msg = "Please provide a token to load models from the hub."
|
||||
print(msg)
|
||||
return "", "", 0, "", "", "", "", 512, "", msg
|
||||
printi("Loading model from hub.")
|
||||
v2 = new_model_url == "stabilityai/stable-diffusion-2"
|
||||
|
||||
else:
|
||||
if new_model_token == "" or new_model_token is None:
|
||||
msg = "Please provide a token to load models from the hub."
|
||||
return msg
|
||||
# Todo: Find out if there's a better way to do this
|
||||
v2 = new_model_url == "stabilityai/stable-diffusion-2"
|
||||
|
||||
if v2:
|
||||
prediction_type = "v_prediction"
|
||||
image_size = 768
|
||||
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
|
||||
"v2-inference-v.yaml")
|
||||
else:
|
||||
prediction_type = "epsilon"
|
||||
image_size = 512
|
||||
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
|
||||
"v1-inference.yaml")
|
||||
|
||||
db_config = DreamboothConfig(model_name=new_model_name, scheduler=scheduler_type, v2=v2,
|
||||
src=ckpt_path if not from_hub else new_model_url)
|
||||
print(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
|
||||
if v2:
|
||||
# All of the 2.0 models use OpenCLIP and all use DDPM scheduler by default.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
scheduler_type = "ddim"
|
||||
if v2:
|
||||
prediction_type = "v_prediction"
|
||||
image_size = 768
|
||||
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
|
||||
"v2-inference-v.yaml")
|
||||
else:
|
||||
scheduler_type = "pndm"
|
||||
if scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
scheduler = PNDMScheduler.from_config(config)
|
||||
elif scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
# elif scheduler_type == "heun":
|
||||
# scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
prediction_type = "epsilon"
|
||||
image_size = 512
|
||||
original_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs",
|
||||
"v1-inference.yaml")
|
||||
|
||||
if from_hub:
|
||||
print(f"Trying to create {new_model_name} from huggingface.co/{new_model_url}")
|
||||
printi("Loading model from hub.")
|
||||
pipe = DiffusionPipeline.from_pretrained(new_model_url, use_auth_token=new_model_token,
|
||||
scheduler=scheduler)
|
||||
printi("Model loaded.")
|
||||
shared.state.job_no = 7
|
||||
db_config = DreamboothConfig(model_name=new_model_name, scheduler=scheduler_type, v2=v2,
|
||||
src=ckpt_path if not from_hub else new_model_url, resolution=768 if v2 else 512)
|
||||
print(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||
|
||||
else:
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=ckpt_path, extract_ema=extract_ema
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
db_config.has_ema = has_ema
|
||||
db_config.save()
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
# Convert the text model.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
elif text_model_type == "FrozenCLIPEmbedder":
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
print("Creating scheduler...")
|
||||
if v2:
|
||||
# All of the 2.0 models use OpenCLIP and all use DDPM scheduler by default.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
scheduler_type = "ddim"
|
||||
else:
|
||||
scheduler_type = "pndm"
|
||||
db_config.scheduler = scheduler_type
|
||||
db_config.save()
|
||||
if scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
scheduler = PNDMScheduler.from_config(config)
|
||||
elif scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "heun":
|
||||
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
if pipe is None:
|
||||
print("Unable to create da pipe!")
|
||||
status = "Unable to create pipeline for extraction."
|
||||
if from_hub:
|
||||
print(f"Trying to create {new_model_name} from huggingface.co/{new_model_url}")
|
||||
printi("Loading model from hub.")
|
||||
pipe = DiffusionPipeline.from_pretrained(new_model_url, use_auth_token=new_model_token,
|
||||
scheduler=scheduler, device_map=map_location)
|
||||
printi("Model loaded.")
|
||||
shared.state.job_no = 7
|
||||
|
||||
else:
|
||||
printi("Converting unet...")
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=ckpt_path, extract_ema=extract_ema
|
||||
)
|
||||
db_config.has_ema = has_ema
|
||||
db_config.save()
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
printi("Converting vae...")
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
printi("Converting text encoder...")
|
||||
# Convert the text model.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
elif text_model_type == "FrozenCLIPEmbedder":
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||
scheduler=scheduler)
|
||||
except Exception as e:
|
||||
pipe = None
|
||||
status = f"Exception while extracting model: {e}"
|
||||
|
||||
if pipe is None or db_config is None:
|
||||
print("Pipeline or config is not set, unable to continue.")
|
||||
else:
|
||||
printi("Saving diffusers model...")
|
||||
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
|
||||
status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
|
||||
model_dir = db_config.model_dir
|
||||
revision = db_config.revision
|
||||
scheduler = db_config.scheduler
|
||||
src = db_config.src
|
||||
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
||||
for req_dir in required_dirs:
|
||||
full_path = os.path.join(db_config.pretrained_model_name_or_path, req_dir)
|
||||
if not os.path.exists(full_path):
|
||||
status = f"Missing model directory, removing model: {full_path}"
|
||||
shutil.rmtree(db_config.model_dir, ignore_errors=False, onerror=None)
|
||||
break
|
||||
|
||||
if reset_safe:
|
||||
shared.cmd_opts.disable_safe_unpickle = False
|
||||
|
||||
reload_system_models()
|
||||
printm(status, True)
|
||||
|
|
@ -912,10 +953,11 @@ def extract_checkpoint(new_model_name: str, ckpt_path: str, scheduler_type="ddim
|
|||
dirs = get_db_models()
|
||||
|
||||
return gr.Dropdown.update(choices=sorted(dirs), value=new_model_name), \
|
||||
db_config.model_dir, \
|
||||
db_config.revision, \
|
||||
db_config.scheduler, \
|
||||
db_config.src, \
|
||||
"True" if db_config.has_ema else "False", \
|
||||
"True" if db_config.v2 else "False", \
|
||||
model_dir, \
|
||||
revision, \
|
||||
scheduler, \
|
||||
src, \
|
||||
"True" if has_ema else "False", \
|
||||
"True" if v2 else "False", \
|
||||
db_config.resolution, \
|
||||
status
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# From shivam shiaro's repo, with "minimal" modification to hopefully allow for smoother updating?
|
||||
import argparse
|
||||
import hashlib
|
||||
import itertools
|
||||
|
|
@ -15,7 +14,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, \
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler, \
|
||||
DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
|
@ -28,10 +27,11 @@ from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
|
|||
from extensions.sd_dreambooth_extension.dreambooth import xattention
|
||||
from extensions.sd_dreambooth_extension.dreambooth.SuperDataset import SuperDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import save_checkpoint, list_features, \
|
||||
is_image, printm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, encode_hidden_state
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, sanitize_name
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import printm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, encode_hidden_state, \
|
||||
PromptDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, sanitize_name, list_features, is_image
|
||||
from modules import shared
|
||||
|
||||
# Custom stuff
|
||||
|
|
@ -60,14 +60,15 @@ dl.set_verbosity_error()
|
|||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
if model_class == "text_encoder_cls":
|
||||
from transformers import text_encoder_cls
|
||||
if model_class == "CLIPTextModel":
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
return text_encoder_cls
|
||||
return CLIPTextModel
|
||||
elif model_class == "RobertaSeriesModelWithTransformation":
|
||||
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
||||
|
||||
|
|
@ -128,54 +129,6 @@ def parse_args(input_args=None):
|
|||
default=None,
|
||||
help="The prompt to specify images in the same class as provided instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_sample_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt used to generate sample outputs to save.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_sample_negative_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The negative prompt used to generate sample outputs to save.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_save_sample",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The number of samples to save.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="CFG for save sample.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_infer_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="The number of inference steps for save sample.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_negative_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The negative prompt used to generate sample outputs to save.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="CFG for save sample.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_infer_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="The number of inference steps for save sample.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pad_tokens",
|
||||
default=False,
|
||||
|
|
@ -194,8 +147,8 @@ def parse_args(input_args=None):
|
|||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
||||
" sampled with class_prompt."
|
||||
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -231,6 +184,7 @@ def parse_args(input_args=None):
|
|||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
|
|
@ -291,18 +245,15 @@ def parse_args(input_args=None):
|
|||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
|
||||
parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
|
||||
parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--not_cache_latents", action="store_true",
|
||||
|
|
@ -369,44 +320,20 @@ def parse_args(input_args=None):
|
|||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.with_prior_preservation:
|
||||
if args.class_data_dir is None:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
if args.class_data_dir is not None:
|
||||
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
|
||||
|
||||
def __init__(self, prompt: str, num_samples: int, filename_texts, file_prompt_contents: str, class_token: str,
|
||||
instance_token: str):
|
||||
self.prompt = prompt
|
||||
self.instance_token = instance_token
|
||||
self.class_token = class_token
|
||||
self.num_samples = num_samples
|
||||
self.filename_texts = filename_texts
|
||||
self.file_prompt_contents = file_prompt_contents
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {"filename_text": self.filename_texts[index % len(self.filename_texts)] if len(
|
||||
self.filename_texts) > 0 else ""}
|
||||
prompt = example["filename_text"]
|
||||
if "Instance" in self.file_prompt_contents:
|
||||
class_token = self.class_token
|
||||
# If the token is already in the prompt, just remove the instance token, don't swap it
|
||||
class_tokens = [f"a {class_token}", f"the {class_token}", f"an {class_token}", class_token]
|
||||
for token in class_tokens:
|
||||
if token in prompt:
|
||||
prompt = prompt.replace(self.instance_token, "")
|
||||
else:
|
||||
prompt = prompt.replace(self.instance_token, self.class_token)
|
||||
|
||||
prompt = self.prompt.replace("[filewords]", prompt)
|
||||
example["prompt"] = prompt
|
||||
example["index"] = index
|
||||
return example
|
||||
|
||||
|
||||
class LatentsDataset(Dataset):
|
||||
def __init__(self, latents_cache, text_encoder_cache, concepts_cache):
|
||||
self.latents_cache = latents_cache
|
||||
|
|
@ -449,7 +376,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
|||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict, str]:
|
||||
def main(args: DreamboothConfig, memory_record, use_subdir) -> tuple[DreamboothConfig, dict, str]:
|
||||
global with_prior
|
||||
text_encoder = None
|
||||
args.tokenizer_name = None
|
||||
|
|
@ -540,7 +467,6 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
|
||||
shared.state.job_count = num_new_images
|
||||
shared.state.job_no = 0
|
||||
save_txt = "[filewords]" in concept.class_prompt
|
||||
filename_texts = [text_getter.read_text(x) for x in Path(concept.instance_data_dir).iterdir() if
|
||||
is_image(x, pil_features)]
|
||||
sample_dataset = PromptDataset(concept.class_prompt, num_new_images, filename_texts,
|
||||
|
|
@ -570,15 +496,17 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
logger.debug("Generation canceled.")
|
||||
shared.state.textinfo = "Training canceled."
|
||||
return args, mem_record, "Training canceled."
|
||||
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
if args.save_class_txt:
|
||||
image_base = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
else:
|
||||
image_base = example["prompt"]
|
||||
image_filename = class_images_dir / f"{generated_images + cur_class_images}-" \
|
||||
f"{hash_image}.jpg"
|
||||
f"{image_base}.jpg"
|
||||
image.save(image_filename)
|
||||
if save_txt:
|
||||
txt_filename = class_images_dir / f"{generated_images + cur_class_images}-" \
|
||||
f"{hash_image}.txt "
|
||||
if args.save_class_txt:
|
||||
txt_filename = image_filename.replace(".jpg", ".txt")
|
||||
with open(txt_filename, "w", encoding="utf8") as file:
|
||||
file.write(example["filename_text"][i] + "\n")
|
||||
file.write(example["prompt"])
|
||||
shared.state.job_no += 1
|
||||
generated_images += 1
|
||||
shared.state.textinfo = f"Class image {generated_images}/{num_new_images}, " \
|
||||
|
|
@ -597,15 +525,20 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name,
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
os.path.join(args.pretrained_model_name_or_path, "tokenizer"),
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder class
|
||||
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
|
|
@ -658,6 +591,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
use_adam = True
|
||||
except Exception as a:
|
||||
logger.warning(f"Exception importing 8bit adam: {a}")
|
||||
traceback.print_exc()
|
||||
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
|
||||
|
|
@ -670,7 +604,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDIMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
def cleanup_memory():
|
||||
try:
|
||||
|
|
@ -748,6 +682,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
return output
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
#train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
|
||||
)
|
||||
# Move text_encoder and VAE to GPU.
|
||||
|
|
@ -890,7 +825,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
pred_type = "epsilon"
|
||||
if args.v2:
|
||||
pred_type = "v_prediction"
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", steps_offset=1,
|
||||
clip_sample=False, set_alpha_to_one=False, prediction_type=pred_type)
|
||||
|
||||
s_pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
|
@ -911,7 +846,8 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
shared.state.textinfo = f"Saving checkpoint at step {args.revision}..."
|
||||
try:
|
||||
s_pipeline.save_pretrained(args.pretrained_model_name_or_path)
|
||||
save_checkpoint(args.model_name, args.revision)
|
||||
compile_checkpoint(args.model_name, half=args.half_model, use_subdir=use_subdir,
|
||||
reload_models=False)
|
||||
except Exception as e:
|
||||
logger.debug(f"Exception saving checkpoint/model: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
@ -965,6 +901,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
|
||||
training_complete = False
|
||||
msg = ""
|
||||
weights_saved = False
|
||||
for epoch in range(args.num_train_epochs):
|
||||
if training_complete:
|
||||
break
|
||||
|
|
@ -973,6 +910,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
if args.train_text_encoder and text_encoder is not None:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
weights_saved = False
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
|
|
@ -1052,12 +990,8 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
args.revision += 1
|
||||
shared.state.job_no = global_step
|
||||
|
||||
training_complete = global_step >= args.max_train_steps or shared.state.interrupted
|
||||
|
||||
if global_step > 0:
|
||||
save_img = args.save_preview_every and not global_step % args.save_preview_every
|
||||
save_model = args.save_embedding_every and not global_step % args.save_embedding_every
|
||||
|
|
@ -1067,6 +1001,7 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
if save_img or save_model:
|
||||
save_weights()
|
||||
args.save()
|
||||
weights_saved = True
|
||||
shared.state.job_count = args.max_train_steps
|
||||
if shared.state.interrupted:
|
||||
training_complete = True
|
||||
|
|
@ -1086,11 +1021,22 @@ def main(args: DreamboothConfig, memory_record) -> tuple[DreamboothConfig, dict,
|
|||
f" total."
|
||||
|
||||
break
|
||||
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
args.revision += 1
|
||||
shared.state.job_no = global_step
|
||||
|
||||
training_complete = global_step >= args.max_train_steps or shared.state.interrupted
|
||||
accelerator.wait_for_everyone()
|
||||
if not args.not_cache_latents:
|
||||
train_dataset, train_dataloader = cache_latents(enc_vae=vae, orig_dataset=gen_dataset)
|
||||
if training_complete:
|
||||
if not weights_saved:
|
||||
save_img = True
|
||||
save_model = True
|
||||
save_weights()
|
||||
args.save()
|
||||
msg = f"Training completed, total steps: {args.revision}"
|
||||
break
|
||||
except Exception as m:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from tqdm.auto import tqdm
|
|||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import list_features, is_image, printm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import printm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import AverageMeter
|
||||
from modules import shared
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,20 @@
|
|||
import gc
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import features
|
||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from huggingface_hub import HfFolder, whoami
|
||||
from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, PromptDataset
|
||||
from modules import shared, paths
|
||||
|
||||
try:
|
||||
|
|
@ -92,3 +104,160 @@ def reload_system_models():
|
|||
if shared.sd_model is not None:
|
||||
shared.sd_model.to(shared.device)
|
||||
printm("Restored system models.")
|
||||
|
||||
|
||||
def debug_prompts(model_dir):
|
||||
from extensions.sd_dreambooth_extension.dreambooth.SuperDataset import SuperDataset
|
||||
if model_dir is None or model_dir == "":
|
||||
return "Please select a model."
|
||||
config = from_file(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
os.path.join(config.pretrained_model_name_or_path, "tokenizer"),
|
||||
revision=config.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
train_dataset = SuperDataset(
|
||||
concepts_list=config.concepts_list,
|
||||
tokenizer=tokenizer,
|
||||
size=config.resolution,
|
||||
center_crop=config.center_crop,
|
||||
lifetime_steps=config.revision,
|
||||
pad_tokens=config.pad_tokens,
|
||||
hflip=config.hflip,
|
||||
max_token_length=config.max_token_length
|
||||
)
|
||||
|
||||
output = {"instance_prompts": [], "existing_class_prompts": [], "new_class_prompts": [], "sample_prompts": []}
|
||||
|
||||
for i in range(train_dataset.__len__()):
|
||||
item = train_dataset.__getitem__(i)
|
||||
output["instance_prompts"].append(item["instance_prompt"])
|
||||
output["existing_class_prompts"].append(item["class_prompt"])
|
||||
output["sample_prompts"] = train_dataset.get_sample_prompts()
|
||||
|
||||
for concept in config.concepts_list:
|
||||
text_getter = FilenameTextGetter()
|
||||
c_idx = 0
|
||||
class_images_dir = Path(concept["class_data_dir"])
|
||||
if class_images_dir == "" or class_images_dir is None or class_images_dir == shared.script_path:
|
||||
class_images_dir = os.path.join(config.model_dir, f"classifiers_{c_idx}")
|
||||
print(f"Class image dir is not set, defaulting to {class_images_dir}")
|
||||
class_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
cur_class_images = 0
|
||||
iterfiles = 0
|
||||
pil_features = list_features()
|
||||
for x in class_images_dir.iterdir():
|
||||
iterfiles += 1
|
||||
if is_image(x, pil_features):
|
||||
cur_class_images += 1
|
||||
if cur_class_images < concept.num_class_images:
|
||||
num_new_images = concept.num_class_images - cur_class_images
|
||||
filename_texts = [text_getter.read_text(x) for x in Path(concept.instance_data_dir).iterdir() if
|
||||
is_image(x, pil_features)]
|
||||
sample_dataset = PromptDataset(concept.class_prompt, num_new_images, filename_texts,
|
||||
concept.file_prompt_contents, concept.class_token,
|
||||
concept.instance_token)
|
||||
for i in range(sample_dataset):
|
||||
output["new_class_prompts"].append(sample_dataset.__getitem__(i)["prompt"])
|
||||
c_idx += 1
|
||||
|
||||
return json.dumps(output)
|
||||
|
||||
|
||||
def generate_sample_img(model_dir: str, save_sample_prompt: str, seed: str):
    """Generate a single preview image from a trained dreambooth model.

    Args:
        model_dir: Name/dir of the dreambooth model whose config to load.
        save_sample_prompt: Prompt to render; required.
        seed: Seed as a string; blank/None/-1 (or unparseable) picks a random seed.

    Returns:
        Tuple of (status message, PIL image or None). Every path returns a
        two-tuple so UI callers can always unpack both outputs.
    """
    if model_dir is None or model_dir == "":
        # Fixed: previously returned a bare string while all other paths
        # return (msg, image), breaking two-output callers.
        return "Please select a model.", None
    config = from_file(model_dir)
    unload_system_models()
    model_path = config.pretrained_model_name_or_path
    image = None
    status = "Sample generated."
    if not os.path.exists(config.pretrained_model_name_or_path):
        print(f"Model path '{config.pretrained_model_name_or_path}' doesn't exist.")
        return f"Can't find diffusers model at {config.pretrained_model_name_or_path}.", None
    try:
        print(f"Loading model from {model_path}.")
        text_enc_model = CLIPTextModel.from_pretrained(config.pretrained_model_name_or_path,
                                                       subfolder="text_encoder", revision=config.revision)
        pipeline = StableDiffusionPipeline.from_pretrained(
            config.pretrained_model_name_or_path,
            text_encoder=text_enc_model,
            torch_dtype=torch.float16,
            revision=config.revision,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )
        pipeline = pipeline.to(shared.device)
        save_dir = os.path.join(shared.sd_path, "outputs", "dreambooth")
        db_model_path = config.model_dir
        if save_sample_prompt is None:
            msg = "Please provide a sample prompt."
            print(msg)
            return msg, None
        shared.state.textinfo = f"Generating preview image for model {db_model_path}..."
        # Seed comes from a gradio textbox as a string; manual_seed() needs an
        # int. Unparseable/blank/-1 all fall back to a random seed.
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = -1
        if seed == -1:
            seed = int(random.randrange(21474836147))
        g_cuda = torch.Generator(device=shared.device).manual_seed(seed)
        sample_dir = os.path.join(save_dir, "samples")
        os.makedirs(sample_dir, exist_ok=True)
        shared.state.job_count = 1
        with torch.autocast("cuda"), torch.inference_mode():
            image = pipeline(save_sample_prompt,
                             num_inference_steps=60,
                             guidance_scale=7.5,
                             scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085,
                                                                       beta_end=0.012),
                             width=config.resolution,
                             height=config.resolution,
                             generator=g_cuda).images[0]
    except Exception:
        # Fixed: was a bare except that still reported success; narrow it and
        # surface the failure in the returned status message.
        status = "Exception generating sample!"
        print(status)
        traceback.print_exc()
    reload_system_models()
    return status, image
|
||||
|
||||
|
||||
def isset(val: str):
    """Return True when *val* holds a real value: not None, not empty, not the "*" wildcard."""
    return val not in (None, "", "*")
|
||||
|
||||
|
||||
def list_features():
    """Return the image file extensions supported by the installed PIL.

    Pillow only exposes its capabilities via features.pilinfo(), which writes
    human-readable text; capture that output in memory and harvest every
    "Extensions:" line, preserving order and skipping duplicates.
    """
    # Divert pilinfo() output into a buffer instead of stdout.
    info_buffer = StringIO()
    features.pilinfo(out=info_buffer)

    supported = []
    for info_line in info_buffer.getvalue().splitlines():
        if "Extensions:" not in info_line:
            continue
        for ext in info_line.split(": ")[1].split(", "):
            if ext not in supported:
                supported.append(ext)
    return supported
|
||||
|
||||
|
||||
def is_image(path: Path, feats=None):
    """Return True if *path* is an existing file with a supported image extension.

    Args:
        path: Filesystem path to test.
        feats: Optional list of accepted extensions (e.g. [".jpg"]). When
            omitted or empty, the PIL-supported extension list is used.
    """
    if not feats:
        # Collapses the original None-check + len-check: both None and an
        # empty list fall back to PIL's supported extensions.
        feats = list_features()
    return path.is_file() and path.suffix.lower() in feats
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the "owner/model_id" HF Hub repo name.

    Falls back to the stored HF token when none is given, and to the
    token-holder's username when no organization is supplied.
    """
    if token is None:
        token = HfFolder.get_token()
    owner = whoami(token)["name"] if organization is None else organization
    return f"{owner}/{model_id}"
|
||||
|
|
|
|||
55
install.py
55
install.py
|
|
@ -6,9 +6,8 @@ import sys
|
|||
import sysconfig
|
||||
|
||||
import git
|
||||
from launch import run
|
||||
|
||||
from modules.paths import script_path
|
||||
from launch import run
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
|
|
@ -27,19 +26,21 @@ def check_versions():
|
|||
splits = line.split("==")
|
||||
if len(splits) == 2:
|
||||
key = splits[0]
|
||||
if "torch" not in key:
|
||||
if "diffusers" in key:
|
||||
key = "diffusers"
|
||||
reqs_dict[key] = splits[1].replace("\n", "").strip()
|
||||
|
||||
checks = ["bitsandbytes", "diffusers", "transformers", "xformers"]
|
||||
reqs_dict[key] = splits[1].replace("\n", "").strip()
|
||||
# print(f"Reqs dict: {reqs_dict}")
|
||||
reqs_dict["diffusers[torch]"] = "0.10.0.dev0"
|
||||
checks = ["bitsandbytes", "diffusers[torch]", "transformers", "xformers", "torch", "torchvision"]
|
||||
for check in checks:
|
||||
if check == "diffusers[torch]":
|
||||
il_check = "diffusers"
|
||||
else:
|
||||
il_check = check
|
||||
check_ver = "N/A"
|
||||
status = "[ ]"
|
||||
try:
|
||||
check_available = importlib.util.find_spec(check) is not None
|
||||
check_available = importlib.util.find_spec(il_check) is not None
|
||||
if check_available:
|
||||
check_ver = importlib_metadata.version(check)
|
||||
check_ver = importlib_metadata.version(il_check)
|
||||
if check in reqs_dict:
|
||||
req_version = reqs_dict[check]
|
||||
if str(check_ver) == str(req_version):
|
||||
|
|
@ -55,31 +56,41 @@ def check_versions():
|
|||
print(f"{status} {check} version {check_ver} installed.")
|
||||
|
||||
|
||||
dreambooth_skip_install = os.environ.get('DREAMBOOTH_SKIP_INSTALL', False)
|
||||
|
||||
if not dreambooth_skip_install:
|
||||
name = "Dreambooth"
|
||||
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
|
||||
f"Couldn't install {name} requirements.")
|
||||
|
||||
|
||||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
repo = git.Repo(base_dir)
|
||||
revision = repo.rev_parse("HEAD")
|
||||
print(f"Dreambooth revision is {revision}")
|
||||
check_versions()
|
||||
app_repo = git.Repo(os.path.join(base_dir, "..", ".."))
|
||||
app_revision = app_repo.rev_parse("HEAD")
|
||||
print("#######################################################################################################")
|
||||
print("Initializing Dreambooth")
|
||||
print("If submitting an issue on github, please provide the below text for debugging purposes:")
|
||||
print("")
|
||||
print(f"Python revision: {sys.version}")
|
||||
print(f"Dreambooth revision: {revision}")
|
||||
print(f"SD-WebUI revision: {app_revision}")
|
||||
print("")
|
||||
dreambooth_skip_install = os.environ.get('DREAMBOOTH_SKIP_INSTALL', False)
|
||||
|
||||
if not dreambooth_skip_install:
|
||||
name = "Dreambooth"
|
||||
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
|
||||
f"Couldn't install {name} requirements.")
|
||||
|
||||
# Check for "different" B&B Files and copy only if necessary
|
||||
if os.name == "nt":
|
||||
python = sys.executable
|
||||
bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bitsandbytes_windows")
|
||||
bnb_dest = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
|
||||
printed = False
|
||||
filecmp.clear_cache()
|
||||
copied = False
|
||||
for file in os.listdir(bnb_src):
|
||||
src_file = os.path.join(bnb_src, file)
|
||||
if file == "main.py":
|
||||
if file == "main.py" or file == "paths.py":
|
||||
dest = os.path.join(bnb_dest, "cuda_setup")
|
||||
else:
|
||||
dest = bnb_dest
|
||||
dest_file = os.path.join(dest, file)
|
||||
shutil.copy2(src_file, dest)
|
||||
|
||||
check_versions()
|
||||
print("#######################################################################################################")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ albumentations==1.3.0
|
|||
basicsr==1.4.2
|
||||
bitsandbytes==0.35.0
|
||||
clean-fid==0.1.29
|
||||
diffusers[torch]==0.9.0
|
||||
git+https://github.com/huggingface/diffusers#egg=diffusers[torch]
|
||||
einops==0.4.1
|
||||
fairscale==0.4.9
|
||||
fastapi==0.87.0
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@ import gradio as gr
|
|||
|
||||
from extensions.sd_dreambooth_extension.dreambooth import dreambooth
|
||||
from extensions.sd_dreambooth_extension.dreambooth.db_config import save_config
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd_bmalthais import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, \
|
||||
training_wizard, training_wizard_person, generate_sample_img
|
||||
training_wizard, training_wizard_person, load_model_params
|
||||
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import get_db_models, log_memory
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import get_db_models, log_memory, generate_sample_img, \
|
||||
debug_prompts
|
||||
from modules import script_callbacks, sd_models, shared
|
||||
from modules.ui import setup_progressbar, gr_show, wrap_gradio_call, create_refresh_button
|
||||
from webui import wrap_gradio_gpu_call
|
||||
|
|
@ -67,7 +68,7 @@ def on_ui_tabs():
|
|||
db_new_model_extract_ema = gr.Checkbox(label='Extract EMA Weights', value=False)
|
||||
db_new_model_scheduler = gr.Dropdown(label='Scheduler', choices=["pndm", "lms", "euler",
|
||||
"euler-ancestral", "dpm", "ddim"],
|
||||
value="euler-ancestral")
|
||||
value="ddim")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
|
|
@ -99,10 +100,11 @@ def on_ui_tabs():
|
|||
db_lr_warmup_steps = gr.Number(label="Learning Rate Warmup Steps", precision=0, value=0)
|
||||
|
||||
with gr.Column():
|
||||
gr.HTML(value="Instance Image Processing")
|
||||
gr.HTML(value="Image Processing")
|
||||
db_resolution = gr.Number(label="Resolution", precision=0, value=512)
|
||||
db_center_crop = gr.Checkbox(label="Center Crop", value=False)
|
||||
db_hflip = gr.Checkbox(label="Apply Horizontal Flip", value=True)
|
||||
db_save_class_txt = gr.Checkbox(label="Save Class Captions to txt", value=False)
|
||||
db_pretrained_vae_name_or_path = gr.Textbox(label='Pretrained VAE Name or Path',
|
||||
placeholder="Leave blank to use base model VAE.",
|
||||
value="")
|
||||
|
|
@ -183,6 +185,7 @@ def on_ui_tabs():
|
|||
|
||||
with gr.Tab("Debugging"):
|
||||
with gr.Column():
|
||||
db_debug_prompts = gr.Button(value="Preview Prompts")
|
||||
db_generate_sample = gr.Button(value="Generate Sample Image")
|
||||
db_sample_prompt = gr.Textbox(label="Sample Prompt")
|
||||
db_sample_seed = gr.Textbox(label="Sample Seed")
|
||||
|
|
@ -201,6 +204,13 @@ def on_ui_tabs():
|
|||
"one": "one",
|
||||
"two": "two"
|
||||
}
|
||||
|
||||
db_debug_prompts.click(
|
||||
fn=debug_prompts,
|
||||
inputs=[db_model_name],
|
||||
outputs=[db_status]
|
||||
)
|
||||
|
||||
db_save_params.click(
|
||||
fn=save_config,
|
||||
inputs=[
|
||||
|
|
@ -215,6 +225,7 @@ def on_ui_tabs():
|
|||
db_gradient_accumulation_steps,
|
||||
db_gradient_checkpointing,
|
||||
db_half_model,
|
||||
db_has_ema,
|
||||
db_hflip,
|
||||
db_learning_rate,
|
||||
db_lr_scheduler,
|
||||
|
|
@ -232,9 +243,11 @@ def on_ui_tabs():
|
|||
db_resolution,
|
||||
db_revision,
|
||||
db_sample_batch_size,
|
||||
db_save_class_txt,
|
||||
db_save_embedding_every,
|
||||
db_save_preview_every,
|
||||
db_scale_lr,
|
||||
db_scheduler,
|
||||
db_src,
|
||||
db_train_batch_size,
|
||||
db_train_text_encoder,
|
||||
|
|
@ -242,6 +255,7 @@ def on_ui_tabs():
|
|||
db_use_concepts,
|
||||
db_use_cpu,
|
||||
db_use_ema,
|
||||
db_v2,
|
||||
c1_class_data_dir,
|
||||
c1_class_guidance_scale,
|
||||
c1_class_infer_steps,
|
||||
|
|
@ -318,7 +332,6 @@ def on_ui_tabs():
|
|||
db_gradient_accumulation_steps,
|
||||
db_gradient_checkpointing,
|
||||
db_half_model,
|
||||
db_has_ema,
|
||||
db_hflip,
|
||||
db_learning_rate,
|
||||
db_lr_scheduler,
|
||||
|
|
@ -327,27 +340,23 @@ def on_ui_tabs():
|
|||
db_max_token_length,
|
||||
db_max_train_steps,
|
||||
db_mixed_precision,
|
||||
db_model_path,
|
||||
db_not_cache_latents,
|
||||
db_num_train_epochs,
|
||||
db_pad_tokens,
|
||||
db_pretrained_vae_name_or_path,
|
||||
db_prior_loss_weight,
|
||||
db_resolution,
|
||||
db_revision,
|
||||
db_sample_batch_size,
|
||||
db_save_class_txt,
|
||||
db_save_embedding_every,
|
||||
db_save_preview_every,
|
||||
db_scale_lr,
|
||||
db_scheduler,
|
||||
db_src,
|
||||
db_train_batch_size,
|
||||
db_train_text_encoder,
|
||||
db_use_8bit_adam,
|
||||
db_use_concepts,
|
||||
db_use_cpu,
|
||||
db_use_ema,
|
||||
db_v2,
|
||||
c1_class_data_dir,
|
||||
c1_class_guidance_scale,
|
||||
c1_class_infer_steps,
|
||||
|
|
@ -427,6 +436,12 @@ def on_ui_tabs():
|
|||
]
|
||||
)
|
||||
|
||||
db_model_name.change(
|
||||
fn=load_model_params,
|
||||
inputs=[db_model_name],
|
||||
outputs=[db_model_path, db_revision, db_v2, db_has_ema, db_src, db_scheduler, db_status]
|
||||
)
|
||||
|
||||
db_use_concepts.change(
|
||||
fn=lambda x: {
|
||||
concept_tab: gr_show(x is True)
|
||||
|
|
@ -506,7 +521,7 @@ def on_ui_tabs():
|
|||
)
|
||||
|
||||
db_generate_checkpoint.click(
|
||||
fn=wrap_gradio_gpu_call(compile_checkpoint, extra_outputs=[gr.update()]),
|
||||
fn=wrap_gradio_gpu_call(compile_checkpoint),
|
||||
_js="db_start_progress",
|
||||
inputs=[
|
||||
db_model_name,
|
||||
|
|
@ -532,7 +547,8 @@ def on_ui_tabs():
|
|||
db_new_model_extract_ema
|
||||
],
|
||||
outputs=[
|
||||
db_model_name, db_model_path, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_status
|
||||
db_model_name, db_model_path, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution,
|
||||
db_status
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -541,7 +557,8 @@ def on_ui_tabs():
|
|||
_js="db_save_start_progress",
|
||||
inputs=[
|
||||
db_model_name,
|
||||
db_train_imagic_only
|
||||
db_train_imagic_only,
|
||||
db_use_subdir
|
||||
],
|
||||
outputs=[
|
||||
db_status,
|
||||
|
|
|
|||
Loading…
Reference in New Issue