One set of imports to rule them all.

pull/1045/head
d8ahazard 2023-03-08 12:51:24 -06:00
parent 7e8ade8741
commit 3e8ddce4dd
33 changed files with 206 additions and 374 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
.idea/*
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/**

View File

View File

@ -5,15 +5,10 @@ from typing import List, Dict
from pydantic import BaseModel
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dreambooth.utils.image_utils import get_scheduler_names # noqa
from dreambooth import shared # noqa
from dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.utils.image_utils import get_scheduler_names # noqa
from dreambooth.utils.utils import list_attention
# Keys to save, replacing our dumb __init__ method
save_keys = []
@ -168,10 +163,6 @@ class DreamboothConfig(BaseModel):
if "db_" in key:
key = key.replace("db_", "")
if key == "attention" and value == "flash_attention":
try:
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention
except:
from dreambooth.dreambooth.utils.utils import list_attention # noqa
value = list_attention()[-1]
print(f"Replacing flash attention in config to {value}")

View File

@ -1,9 +1,6 @@
from PIL import Image
try:
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
except:
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dataclasses.db_config import DreamboothConfig
class TrainResult:

View File

@ -1,10 +1,7 @@
import random
from typing import Tuple
try:
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
except:
from dreambooth.dreambooth.dataset.db_dataset import DbDataset # noqa
from dreambooth.dataset.db_dataset import DbDataset
class BucketSampler:

View File

@ -3,22 +3,14 @@ import random
from torch.utils.data import Dataset
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter, \
make_bucket_resolutions, \
sort_prompts, get_images
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dreambooth.utils.image_utils import FilenameTextGetter, make_bucket_resolutions, sort_prompts, get_images # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth import shared
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import status
from dreambooth.utils.image_utils import FilenameTextGetter, \
make_bucket_resolutions, \
sort_prompts, get_images
from helpers.mytqdm import mytqdm
class ClassDataset(Dataset):

View File

@ -9,22 +9,13 @@ import torch.utils.data
from torchvision.transforms import transforms
from transformers import CLIPTokenizer
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import make_bucket_resolutions, \
closest_resolution, shuffle_tags, open_and_trim
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import build_strict_tokens
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dreambooth.utils.image_utils import make_bucket_resolutions, closest_resolution, open_and_trim # noqa
from dreambooth.dreambooth.utils.text_utils import build_strict_tokens # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth import shared
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import status
from dreambooth.utils.image_utils import make_bucket_resolutions, \
closest_resolution, shuffle_tags, open_and_trim
from dreambooth.utils.text_utils import build_strict_tokens
from helpers.mytqdm import mytqdm
class DbDataset(torch.utils.data.Dataset):
"""

View File

@ -3,15 +3,10 @@ import random
from PIL import Image
try:
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
closest_resolution, make_bucket_resolutions
except:
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.utils.image_utils import get_images, FilenameTextGetter, closest_resolution, make_bucket_resolutions # noqa
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
closest_resolution, make_bucket_resolutions
class SampleDataset:

View File

@ -14,25 +14,16 @@ import torch
from diffusers import UNet2DConditionModel
from torch import Tensor, nn
try:
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, \
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_lora_to_model
except:
from dreambooth.dreambooth import shared as shared # noqa
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, \
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path # noqa
from dreambooth.dreambooth.utils.utils import printi # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.lora_diffusion.lora import merge_lora_to_model # noqa
from dreambooth import shared as shared
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from dreambooth.shared import status
from dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, \
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
from dreambooth.utils.utils import printi
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import merge_lora_to_model
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)

View File

@ -25,12 +25,8 @@ import traceback
import torch
import torch.backends.cudnn
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.utils.utils import cleanup # noqa
from dreambooth import shared
from dreambooth.utils.utils import cleanup
def should_reduce_batch_size(exception: Exception) -> bool:

View File

@ -29,22 +29,13 @@ from huggingface_hub import HfApi, hf_hub_download
from omegaconf import OmegaConf
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, \
enable_safe_unpickle
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, enable_safe_unpickle # noqa
from dreambooth.dreambooth.utils.utils import printi # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.dreambooth.utils.image_utils import get_scheduler_class # noqa
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, \
enable_safe_unpickle
from dreambooth.utils.utils import printi
from helpers.mytqdm import mytqdm
from dreambooth.utils.image_utils import get_scheduler_class
from diffusers import (
AutoencoderKL,
@ -1038,6 +1029,8 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
status
"""
from dreambooth.ui_functions import gr_update
has_ema = False
v2 = False
revision = 0
@ -1146,10 +1139,7 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
if from_hub:
result_status = "Model fetched from hub."
db_config.save()
try:
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
except:
from dreambooth.dreambooth.ui_functions import gr_update # noqa
return gr_update(choices=sorted(get_db_models()), value=new_model_name), \
db_config.model_dir, \
revision, \
@ -1338,10 +1328,6 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
enable_safe_unpickle()
printi(result_status)
try:
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
except:
from dreambooth.dreambooth.ui_functions import gr_update # noqa
return gr_update(choices=sorted(get_db_models()), value=new_model_name), \
model_dir, \

View File

@ -1,10 +1,7 @@
import os
import secrets
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth import shared
db_path = os.path.join(shared.models_path, "dreambooth")
secret_file = os.path.join(db_path, "secret.txt")

View File

@ -28,53 +28,30 @@ from transformers import AutoTokenizer
try:
from extensions.sd_dreambooth_extension.helpers.log_parser import LogParser
from extensions.sd_dreambooth_extension.dreambooth import xattention, shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.train_result import TrainResult
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
from extensions.sd_dreambooth_extension.dreambooth.dataset.sample_dataset import SampleDataset
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.memory import find_executable_batch_size
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
from extensions.sd_dreambooth_extension.dreambooth.shared import status, load_auto_settings
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image, get_scheduler_class
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle, xformerify, torch2ify
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import encode_hidden_state
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup, printm
from extensions.sd_dreambooth_extension.dreambooth.webhook import send_training_update
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
from extensions.sd_dreambooth_extension.helpers.ema_model import EMAModel
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.lora_diffusion.extra_networks import save_extra_networks
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, \
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity
except:
from dreambooth.helpers.log_parser import LogParser
from dreambooth.dreambooth import xattention, shared # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.dataclasses.train_result import TrainResult # noqa
from dreambooth.dreambooth.dataset.bucket_sampler import BucketSampler # noqa
from dreambooth.dreambooth.dataset.sample_dataset import SampleDataset # noqa
from dreambooth.dreambooth.diff_to_sd import compile_checkpoint # noqa
from dreambooth.dreambooth.memory import find_executable_batch_size # noqa
from dreambooth.dreambooth.optimization import UniversalScheduler # noqa
from dreambooth.dreambooth.shared import status, load_auto_settings # noqa
from dreambooth.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset # noqa
from dreambooth.dreambooth.utils.image_utils import db_save_image, get_scheduler_class # noqa
from dreambooth.dreambooth.utils.model_utils import unload_system_models, import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle # noqa
from dreambooth.dreambooth.utils.text_utils import encode_hidden_state # noqa
from dreambooth.dreambooth.utils.utils import cleanup, printm # noqa
from dreambooth.dreambooth.webhook import send_training_update # noqa
from dreambooth.dreambooth.xattention import optim_to # noqa
from dreambooth.helpers.ema_model import EMAModel # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.lora_diffusion.extra_networks import save_extra_networks # noqa
from dreambooth.lora_diffusion.lora import save_lora_weight, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module # noqa
from helpers.log_parser import LogParser
from dreambooth import xattention, shared
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataclasses.train_result import TrainResult
from dreambooth.dataset.bucket_sampler import BucketSampler
from dreambooth.dataset.sample_dataset import SampleDataset
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.memory import find_executable_batch_size
from dreambooth.optimization import UniversalScheduler
from dreambooth.shared import status, load_auto_settings
from dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
from dreambooth.utils.image_utils import db_save_image, get_scheduler_class
from dreambooth.utils.model_utils import unload_system_models, \
import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle, xformerify, torch2ify
from dreambooth.utils.text_utils import encode_hidden_state
from dreambooth.utils.utils import cleanup, printm
from dreambooth.webhook import send_training_update
from dreambooth.xattention import optim_to
from helpers.ema_model import EMAModel
from helpers.mytqdm import mytqdm
from lora_diffusion.extra_networks import save_extra_networks
from lora_diffusion.lora import save_lora_weight, \
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
from dreambooth.deis_velocity import get_velocity
logger = logging.getLogger(__name__)
# define a Handler which writes DEBUG messages or higher to the sys.stderr

View File

@ -13,16 +13,10 @@ from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dreambooth.utils.image_utils import list_features, is_image # noqa
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.shared import status
from dreambooth.utils.image_utils import list_features, is_image
logger = get_logger(__name__)

View File

@ -15,41 +15,27 @@ import torch.utils.data.dataloader
from accelerate import find_executable_batch_size
from diffusers.utils import logging as dl
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig, \
sanitize_name
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.shared import status, run
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, db_save_image, \
make_bucket_resolutions, get_dim, closest_resolution, open_and_trim
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, get_lora_models, get_checkpoint_match, get_model_snapshots
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm, cleanup
from extensions.sd_dreambooth_extension.helpers.image_builder import ImageBuilder
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses import db_config # noqa
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig, sanitize_name # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.dataset.bucket_sampler import BucketSampler # noqa
from dreambooth.dreambooth.dataset.class_dataset import ClassDataset # noqa
from dreambooth.dreambooth.optimization import UniversalScheduler # noqa
from dreambooth.dreambooth.sd_to_diff import extract_checkpoint # noqa
from dreambooth.dreambooth.shared import status, run # noqa
from dreambooth.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers # noqa
from dreambooth.dreambooth.utils.image_utils import get_images, db_save_image, make_bucket_resolutions, get_dim, closest_resolution, open_and_trim # noqa
from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, get_lora_models, get_checkpoint_match # noqa
from dreambooth.dreambooth.utils.utils import printm, cleanup # noqa
from dreambooth.helpers.image_builder import ImageBuilder # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth import shared
from dreambooth.dataclasses import db_config
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig, \
sanitize_name
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataset.bucket_sampler import BucketSampler
from dreambooth.dataset.class_dataset import ClassDataset
from dreambooth.optimization import UniversalScheduler
from dreambooth.sd_to_diff import extract_checkpoint
from dreambooth.shared import status, run
from dreambooth.train_dreambooth import main
from dreambooth.train_imagic import train_imagic
from dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
from dreambooth.utils.image_utils import get_images, db_save_image, \
make_bucket_resolutions, get_dim, closest_resolution, open_and_trim
from dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, get_lora_models, get_checkpoint_match, get_model_snapshots
from dreambooth.utils.utils import printm, cleanup
from helpers.image_builder import ImageBuilder
from helpers.mytqdm import mytqdm
from postinstall import actual_install
logger = logging.getLogger(__name__)
console = logging.StreamHandler()
@ -665,19 +651,10 @@ def start_training(model_dir: str, use_txt2img: bool = True):
if config.train_imagic:
status.textinfo = "Initializing imagic training..."
print(status.textinfo)
try:
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
except:
from dreambooth.dreambooth.train_imagic import train_imagic # noqa
result = train_imagic(config)
else:
status.textinfo = "Initializing dreambooth training..."
print(status.textinfo)
try:
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
except:
from dreambooth.dreambooth.train_dreambooth import main # noqa
result = main(use_txt2img=use_txt2img)
config = result.config
@ -725,10 +702,6 @@ def reload_extension():
except Exception as e:
print(f"Couldn't import module: {re_add}")
try:
from extensions.sd_dreambooth_extension.postinstall import actual_install # noqa
except:
from dreambooth.postinstall import actual_install # noqa
actual_install()

View File

@ -5,34 +5,24 @@ from typing import List
from accelerate import Accelerator
from transformers import AutoTokenizer
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
from extensions.sd_dreambooth_extension.helpers.image_builder import ImageBuilder
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig, from_file # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.dataset.class_dataset import ClassDataset # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dreambooth.utils.image_utils import db_save_image # noqa
from dreambooth.dreambooth.utils.utils import cleanup # noqa
from dreambooth.helpers.image_builder import ImageBuilder # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataset.class_dataset import ClassDataset
from dreambooth.dataset.db_dataset import DbDataset
from dreambooth.shared import status
from dreambooth.utils.image_utils import db_save_image
from dreambooth.utils.utils import cleanup
from helpers.image_builder import ImageBuilder
from helpers.mytqdm import mytqdm
def generate_dataset(model_name: str, instance_prompts: List[PromptData] = None, class_prompts: List[PromptData] = None,
batch_size=None, tokenizer=None, vae=None, debug=True, model_dir=""):
if debug:
print("Generating dataset.")
from dreambooth.ui_functions import gr_update
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
db_gallery = gr_update(value=None)
db_prompt_list = gr_update(value=None)
db_status = gr_update(value=None)
@ -57,10 +47,6 @@ def generate_dataset(model_name: str, instance_prompts: List[PromptData] = None,
tokens = []
print(f"Found {len(class_prompts)} reg images.")
try:
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
except:
from dreambooth.dreambooth.dataset.db_dataset import DbDataset
print("Preparing dataset...")

View File

@ -20,18 +20,11 @@ import numpy as np
import torch
import torch.utils.checkpoint
try:
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.shared import status
except:
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.dreambooth.shared import status # noqa
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from helpers.mytqdm import mytqdm
from dreambooth import shared
from dreambooth.shared import status
def get_dim(filename, max_res):

View File

@ -8,14 +8,9 @@ import torch
from diffusers.utils import is_xformers_available
from transformers import PretrainedConfig
try:
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup # noqa
except:
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.utils.utils import cleanup # noqa
from dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth import shared # noqa
from dreambooth.utils.utils import cleanup # noqa
checkpoints_list = {}
checkpoint_alisases = {}

View File

@ -11,16 +11,14 @@ from typing import Optional
import importlib_metadata
from packaging import version
from dreambooth import shared
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import torch
from huggingface_hub import HfFolder, whoami
try:
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.dreambooth.shared import status
except:
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.dreambooth.shared import status # noqa
from helpers.mytqdm import mytqdm
from dreambooth.shared import status
def printi(msg, params=None, log=True):
@ -49,7 +47,6 @@ def sanitize_name(name):
def printm(msg=""):
from extensions.sd_dreambooth_extension.dreambooth import shared
if shared.debug:
allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)

View File

@ -6,12 +6,8 @@ from typing import Union, List
import discord_webhook
from PIL import Image
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import image_grid
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.utils.image_utils import image_grid # noqa
from dreambooth import shared
from dreambooth.utils.image_utils import image_grid
class DreamboothWebhookTarget(Enum):

View File

@ -8,29 +8,18 @@ from PIL import Image
from accelerate import Accelerator
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.shared import disable_safe_unpickle
from extensions.sd_dreambooth_extension.dreambooth.utils import image_utils
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_checkpoint_match, \
reload_system_models, \
enable_safe_unpickle, disable_safe_unpickle, unload_system_models, xformerify
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
from extensions.sd_dreambooth_extension.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
get_target_module
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
from dreambooth.dreambooth.shared import disable_safe_unpickle # noqa
from dreambooth.dreambooth.utils import image_utils # noqa
from dreambooth.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class # noqa
from dreambooth.dreambooth.utils.model_utils import get_checkpoint_match, reload_system_models, enable_safe_unpickle, disable_safe_unpickle, unload_system_models # noqa
from dreambooth.helpers.mytqdm import mytqdm # noqa
from dreambooth.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, get_target_module # noqa
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import disable_safe_unpickle
from dreambooth.utils import image_utils
from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
from dreambooth.utils.model_utils import get_checkpoint_match, \
reload_system_models, \
enable_safe_unpickle, disable_safe_unpickle, unload_system_models, xformerify
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
get_target_module
class ImageBuilder:

View File

@ -9,10 +9,7 @@ from pandas import DataFrame
from pandas.plotting._matplotlib.style import get_standard_colors
from tensorboard.compat.proto import event_pb2
try:
from extensions.sd_dreambooth_extension.dreambooth.shared import status
except:
from dreambooth.dreambooth.shared import status
from dreambooth.shared import status
@dataclass
@ -248,9 +245,9 @@ class LogParser:
}
try:
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file # noqa
from dreambooth.dataclasses.db_config import from_file # noqa
except:
from dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
from core.modules.dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
model_config = from_file(model_name)
print(f"Model name: {model_name}")
if model_config is None:

View File

@ -2,10 +2,7 @@ from typing import Iterable
from tqdm import tqdm
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth import shared
class mytqdm(tqdm):

View File

@ -3,10 +3,7 @@ import os
import subprocess
from typing import Union, Dict
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth import shared
store_file = os.path.join(shared.dreambooth_models_path, "revision.txt")
change_file = os.path.join(shared.dreambooth_models_path, "changelog.txt")

View File

@ -1,6 +1,30 @@
import os
import sys
is_auto = False
try:
from extensions.sd_dreambooth_extension.postinstall import actual_install
from modules import shared
is_auto = True
except:
from dreambooth.postinstall import actual_install
pass
if not is_auto:
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
print(f"Base dir (sdplus) is: {base_dir} from {__file__}")
ext_dir = os.path.join(base_dir, 'core', 'modules', 'dreambooth')
else:
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
print(f"Base dir (auto1111) is: {base_dir} from {__file__}")
ext_dir = os.path.join(base_dir, 'extensions', 'sd_dreambooth_extension')
if ext_dir not in sys.path:
print(f"Appending (install): {ext_dir}")
sys.path.insert(0, ext_dir)
else:
print(f"Ext dir already in path? {ext_dir}")
print(sys.path)
from postinstall import actual_install
actual_install()

View File

View File

@ -4,10 +4,7 @@ from typing import Optional, Set
import torch.nn as nn
try:
from extensions.sd_dreambooth_extension.lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d
except:
from dreambooth.lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d # noqa
from lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d
def _find_modules_with_ancestor(

View File

@ -8,11 +8,7 @@ from safetensors.torch import safe_open
from safetensors.torch import save_file as safe_save
from torch import dtype
try:
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import disable_safe_unpickle, \
enable_safe_unpickle
except:
from dreambooth.dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle # noqa
from dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle
class LoraInjectedLinear(nn.Module):

View File

@ -11,6 +11,8 @@ import sysconfig
import git
import requests
from dreambooth import shared
def run(command, desc=None, errdesc=None, custom_env=None, live=True):
if desc:
@ -38,10 +40,6 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=True):
def actual_install():
if os.environ.get("PUBLIC_KEY", None):
print("Docker, returning.")
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
except:
from dreambooth.dreambooth import shared # noqa
shared.launch_error = None
return
if sys.version_info < (3, 8):
@ -257,7 +255,7 @@ def actual_install():
try:
repo = git.Repo(base_dir)
revision = repo.rev_parse("HEAD")
app_repo = git.Repo(os.path.join(base_dir, "..", ".."))
app_repo = git.Repo(os.path.join(base_dir, "../../..", ".."))
app_revision = app_repo.rev_parse("HEAD")
except:
pass

View File

@ -2,17 +2,17 @@ accelerate==0.16.0
albumentations~=1.3.0
bitsandbytes==0.35.4
diffusers==0.13.1
gitpython~=3.1.31
discord-webhook~=1.1.0
fastapi
ftfy~=6.1.1
modelcards~=0.1.6
tensorboard
tensorflow==2.11.0; sys_platform != 'darwin' or platform_machine != 'arm64'
tensorflow-macos==2.11.0; sys_platform == 'darwin' and platform_machine == 'arm64'
gitpython~=3.1.31
lion-pytorch~=0.0.7
mediapipe-silicon; sys_platform == 'darwin'
mediapipe; sys_platform != 'darwin'
modelcards~=0.1.6
tensorboard
tensorflow-macos==2.11.0; sys_platform == 'darwin' and platform_machine == 'arm64'
tensorflow==2.11.0; sys_platform != 'darwin' or platform_machine != 'arm64'
tqdm~=4.64.1
transformers~=4.26.1
discord-webhook~=1.1.0
lion-pytorch~=0.0.7
xformers==0.0.17.dev464

0
scripts/__init__.py Normal file
View File

View File

@ -22,36 +22,26 @@ from starlette import status
from starlette.requests import Request
try:
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
from extensions.sd_dreambooth_extension.dreambooth.shared import DreamState
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import create_model, generate_samples, \
from dreambooth import shared
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.secret import get_secret
from dreambooth.shared import DreamState
from dreambooth.ui_functions import create_model, generate_samples, \
start_training
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, get_lora_models
from dreambooth.utils.gen_utils import generate_classifiers
from dreambooth.utils.image_utils import get_images
from dreambooth.utils.model_utils import get_db_models, get_lora_models
except:
from dreambooth.dreambooth import shared # noqa
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig # noqa
from dreambooth.dreambooth.diff_to_sd import compile_checkpoint # noqa
from dreambooth.dreambooth.secret import get_secret # noqa
from dreambooth.dreambooth.shared import DreamState # noqa
from dreambooth.dreambooth.ui_functions import create_model, generate_samples, start_training # noqa
from dreambooth.dreambooth.utils.gen_utils import generate_classifiers # noqa
from dreambooth.dreambooth.utils.image_utils import get_images # noqa
from dreambooth.dreambooth.utils.model_utils import get_db_models, get_lora_models # noqa
pass
print("Exception importing api")
traceback.print_exc()
if os.environ.get("DEBUG_API", False):
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
print("No, really, API loaded, wtf...")
class InstanceData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file or URL")
@ -156,6 +146,7 @@ def file_to_base64(file_path) -> str:
def dreambooth_api(_, app: FastAPI):
print("API LOAD")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
@ -941,4 +932,5 @@ try:
script_callbacks.on_app_started(dreambooth_api)
logger.debug("SD-Webui API layer loaded")
except:
logger.debug("Unable to import script callbacks.")
pass

View File

@ -4,23 +4,25 @@ from typing import List
import gradio as gr
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import save_config, from_file
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret, create_secret, clear_secret
from extensions.sd_dreambooth_extension.dreambooth.shared import status, get_launch_errors
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import performance_wizard, \
from dreambooth.dataclasses import db_config
from dreambooth.dataclasses.db_config import save_config, from_file
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.secret import get_secret, create_secret, clear_secret
from dreambooth.shared import status, get_launch_errors
from dreambooth.ui_functions import performance_wizard, \
training_wizard, training_wizard_person, load_model_params, ui_classifiers, debug_buckets, create_model, \
generate_samples, load_params, start_training, update_extension, start_crop
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, \
from dreambooth.utils.image_utils import get_scheduler_names
from dreambooth.utils.model_utils import get_db_models, \
get_sorted_lora_models, get_model_snapshots
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention, \
from dreambooth.utils.utils import list_attention, \
list_floats, wrap_gpu_call, printm, list_optimizer
from extensions.sd_dreambooth_extension.dreambooth.webhook import save_and_test_webhook
from extensions.sd_dreambooth_extension.helpers.version_helper import check_updates
from dreambooth.webhook import save_and_test_webhook
from helpers.version_helper import check_updates
from helpers.log_parser import LogParser
from modules import script_callbacks, sd_models
from modules.ui import gr_show, create_refresh_button
from extensions.sd_dreambooth_extension.helpers.log_parser import LogParser
params_to_save = []
params_to_load = []
@ -849,7 +851,6 @@ def on_ui_tabs():
ui_keys.append("db_status")
params_to_load.append(db_status)
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
db_config.save_keys = save_keys
db_config.ui_keys = ui_keys