One set of imports to rule them all.
parent
7e8ade8741
commit
3e8ddce4dd
|
|
@ -1,4 +1,4 @@
|
|||
.idea/*
|
||||
.idea
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
**/__pycache__/**
|
||||
|
|
|
|||
|
|
@ -5,15 +5,10 @@ from typing import List, Dict
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
|
||||
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import get_scheduler_names # noqa
|
||||
from dreambooth import shared # noqa
|
||||
from dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.utils.image_utils import get_scheduler_names # noqa
|
||||
from dreambooth.utils.utils import list_attention
|
||||
|
||||
# Keys to save, replacing our dumb __init__ method
|
||||
save_keys = []
|
||||
|
|
@ -168,10 +163,6 @@ class DreamboothConfig(BaseModel):
|
|||
if "db_" in key:
|
||||
key = key.replace("db_", "")
|
||||
if key == "attention" and value == "flash_attention":
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention
|
||||
except:
|
||||
from dreambooth.dreambooth.utils.utils import list_attention # noqa
|
||||
value = list_attention()[-1]
|
||||
print(f"Replacing flash attention in config to {value}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
from PIL import Image
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
except:
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
|
||||
|
||||
class TrainResult:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
import random
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
|
||||
except:
|
||||
from dreambooth.dreambooth.dataset.db_dataset import DbDataset # noqa
|
||||
from dreambooth.dataset.db_dataset import DbDataset
|
||||
|
||||
|
||||
class BucketSampler:
|
||||
|
|
|
|||
|
|
@ -3,22 +3,14 @@ import random
|
|||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter, \
|
||||
make_bucket_resolutions, \
|
||||
sort_prompts, get_images
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import FilenameTextGetter, make_bucket_resolutions, sort_prompts, get_images # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import FilenameTextGetter, \
|
||||
make_bucket_resolutions, \
|
||||
sort_prompts, get_images
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
|
||||
class ClassDataset(Dataset):
|
||||
|
|
|
|||
|
|
@ -9,22 +9,13 @@ import torch.utils.data
|
|||
from torchvision.transforms import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import make_bucket_resolutions, \
|
||||
closest_resolution, shuffle_tags, open_and_trim
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import build_strict_tokens
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import make_bucket_resolutions, closest_resolution, open_and_trim # noqa
|
||||
from dreambooth.dreambooth.utils.text_utils import build_strict_tokens # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import make_bucket_resolutions, \
|
||||
closest_resolution, shuffle_tags, open_and_trim
|
||||
from dreambooth.utils.text_utils import build_strict_tokens
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
class DbDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,15 +3,10 @@ import random
|
|||
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
|
||||
closest_resolution, make_bucket_resolutions
|
||||
except:
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import get_images, FilenameTextGetter, closest_resolution, make_bucket_resolutions # noqa
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
|
||||
closest_resolution, make_bucket_resolutions
|
||||
|
||||
|
||||
class SampleDataset:
|
||||
|
|
|
|||
|
|
@ -14,25 +14,16 @@ import torch
|
|||
from diffusers import UNet2DConditionModel
|
||||
from torch import Tensor, nn
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
|
||||
reload_system_models, \
|
||||
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_lora_to_model
|
||||
except:
|
||||
from dreambooth.dreambooth import shared as shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, \
|
||||
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path # noqa
|
||||
from dreambooth.dreambooth.utils.utils import printi # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.lora_diffusion.lora import merge_lora_to_model # noqa
|
||||
|
||||
from dreambooth import shared as shared
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.model_utils import unload_system_models, \
|
||||
reload_system_models, \
|
||||
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
|
||||
from dreambooth.utils.utils import printi
|
||||
from helpers.mytqdm import mytqdm
|
||||
from lora_diffusion.lora import merge_lora_to_model
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
|
|
|
|||
|
|
@ -25,12 +25,8 @@ import traceback
|
|||
import torch
|
||||
import torch.backends.cudnn
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.utils.utils import cleanup # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.utils.utils import cleanup
|
||||
|
||||
|
||||
def should_reduce_batch_size(exception: Exception) -> bool:
|
||||
|
|
|
|||
|
|
@ -29,22 +29,13 @@ from huggingface_hub import HfApi, hf_hub_download
|
|||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, \
|
||||
enable_safe_unpickle
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
|
||||
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, enable_safe_unpickle # noqa
|
||||
from dreambooth.dreambooth.utils.utils import printi # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import get_scheduler_class # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, \
|
||||
enable_safe_unpickle
|
||||
from dreambooth.utils.utils import printi
|
||||
from helpers.mytqdm import mytqdm
|
||||
from dreambooth.utils.image_utils import get_scheduler_class
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
|
@ -1038,6 +1029,8 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
|
|||
|
||||
status
|
||||
"""
|
||||
from dreambooth.ui_functions import gr_update
|
||||
|
||||
has_ema = False
|
||||
v2 = False
|
||||
revision = 0
|
||||
|
|
@ -1146,10 +1139,7 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
|
|||
if from_hub:
|
||||
result_status = "Model fetched from hub."
|
||||
db_config.save()
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
|
||||
except:
|
||||
from dreambooth.dreambooth.ui_functions import gr_update # noqa
|
||||
|
||||
return gr_update(choices=sorted(get_db_models()), value=new_model_name), \
|
||||
db_config.model_dir, \
|
||||
revision, \
|
||||
|
|
@ -1338,10 +1328,6 @@ def extract_checkpoint(new_model_name: str, checkpoint_file: str, from_hub=False
|
|||
enable_safe_unpickle()
|
||||
printi(result_status)
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
|
||||
except:
|
||||
from dreambooth.dreambooth.ui_functions import gr_update # noqa
|
||||
|
||||
return gr_update(choices=sorted(get_db_models()), value=new_model_name), \
|
||||
model_dir, \
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
import os
|
||||
import secrets
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth import shared
|
||||
|
||||
db_path = os.path.join(shared.models_path, "dreambooth")
|
||||
secret_file = os.path.join(db_path, "secret.txt")
|
||||
|
|
|
|||
|
|
@ -28,53 +28,30 @@ from transformers import AutoTokenizer
|
|||
|
||||
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.helpers.log_parser import LogParser
|
||||
from extensions.sd_dreambooth_extension.dreambooth import xattention, shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.train_result import TrainResult
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.memory import find_executable_batch_size
|
||||
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status, load_auto_settings
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image, get_scheduler_class
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
|
||||
import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle, xformerify, torch2ify
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import encode_hidden_state
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup, printm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.webhook import send_training_update
|
||||
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
|
||||
from extensions.sd_dreambooth_extension.helpers.ema_model import EMAModel
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
from extensions.sd_dreambooth_extension.lora_diffusion.extra_networks import save_extra_networks
|
||||
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, \
|
||||
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
|
||||
from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity
|
||||
except:
|
||||
from dreambooth.helpers.log_parser import LogParser
|
||||
from dreambooth.dreambooth import xattention, shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.dataclasses.train_result import TrainResult # noqa
|
||||
from dreambooth.dreambooth.dataset.bucket_sampler import BucketSampler # noqa
|
||||
from dreambooth.dreambooth.dataset.sample_dataset import SampleDataset # noqa
|
||||
from dreambooth.dreambooth.diff_to_sd import compile_checkpoint # noqa
|
||||
from dreambooth.dreambooth.memory import find_executable_batch_size # noqa
|
||||
from dreambooth.dreambooth.optimization import UniversalScheduler # noqa
|
||||
from dreambooth.dreambooth.shared import status, load_auto_settings # noqa
|
||||
from dreambooth.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import db_save_image, get_scheduler_class # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import unload_system_models, import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle # noqa
|
||||
from dreambooth.dreambooth.utils.text_utils import encode_hidden_state # noqa
|
||||
from dreambooth.dreambooth.utils.utils import cleanup, printm # noqa
|
||||
from dreambooth.dreambooth.webhook import send_training_update # noqa
|
||||
from dreambooth.dreambooth.xattention import optim_to # noqa
|
||||
from dreambooth.helpers.ema_model import EMAModel # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.lora_diffusion.extra_networks import save_extra_networks # noqa
|
||||
from dreambooth.lora_diffusion.lora import save_lora_weight, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module # noqa
|
||||
from helpers.log_parser import LogParser
|
||||
from dreambooth import xattention, shared
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataclasses.train_result import TrainResult
|
||||
from dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.memory import find_executable_batch_size
|
||||
from dreambooth.optimization import UniversalScheduler
|
||||
from dreambooth.shared import status, load_auto_settings
|
||||
from dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
|
||||
from dreambooth.utils.image_utils import db_save_image, get_scheduler_class
|
||||
from dreambooth.utils.model_utils import unload_system_models, \
|
||||
import_model_class_from_model_name_or_path, disable_safe_unpickle, enable_safe_unpickle, xformerify, torch2ify
|
||||
from dreambooth.utils.text_utils import encode_hidden_state
|
||||
from dreambooth.utils.utils import cleanup, printm
|
||||
from dreambooth.webhook import send_training_update
|
||||
from dreambooth.xattention import optim_to
|
||||
from helpers.ema_model import EMAModel
|
||||
from helpers.mytqdm import mytqdm
|
||||
from lora_diffusion.extra_networks import save_extra_networks
|
||||
from lora_diffusion.lora import save_lora_weight, \
|
||||
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
|
||||
from dreambooth.deis_velocity import get_velocity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# define a Handler which writes DEBUG messages or higher to the sys.stderr
|
||||
|
|
|
|||
|
|
@ -13,16 +13,10 @@ from torchvision import transforms
|
|||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import list_features, is_image # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import list_features, is_image
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,41 +15,27 @@ import torch.utils.data.dataloader
|
|||
from accelerate import find_executable_batch_size
|
||||
from diffusers.utils import logging as dl
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig, \
|
||||
sanitize_name
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status, run
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, db_save_image, \
|
||||
make_bucket_resolutions, get_dim, closest_resolution, open_and_trim
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
|
||||
reload_system_models, get_lora_models, get_checkpoint_match, get_model_snapshots
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm, cleanup
|
||||
from extensions.sd_dreambooth_extension.helpers.image_builder import ImageBuilder
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses import db_config # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig, sanitize_name # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.dataset.bucket_sampler import BucketSampler # noqa
|
||||
from dreambooth.dreambooth.dataset.class_dataset import ClassDataset # noqa
|
||||
from dreambooth.dreambooth.optimization import UniversalScheduler # noqa
|
||||
from dreambooth.dreambooth.sd_to_diff import extract_checkpoint # noqa
|
||||
from dreambooth.dreambooth.shared import status, run # noqa
|
||||
from dreambooth.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import get_images, db_save_image, make_bucket_resolutions, get_dim, closest_resolution, open_and_trim # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, get_lora_models, get_checkpoint_match # noqa
|
||||
from dreambooth.dreambooth.utils.utils import printm, cleanup # noqa
|
||||
from dreambooth.helpers.image_builder import ImageBuilder # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses import db_config
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig, \
|
||||
sanitize_name
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from dreambooth.dataset.class_dataset import ClassDataset
|
||||
from dreambooth.optimization import UniversalScheduler
|
||||
from dreambooth.sd_to_diff import extract_checkpoint
|
||||
from dreambooth.shared import status, run
|
||||
from dreambooth.train_dreambooth import main
|
||||
from dreambooth.train_imagic import train_imagic
|
||||
from dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
|
||||
from dreambooth.utils.image_utils import get_images, db_save_image, \
|
||||
make_bucket_resolutions, get_dim, closest_resolution, open_and_trim
|
||||
from dreambooth.utils.model_utils import unload_system_models, \
|
||||
reload_system_models, get_lora_models, get_checkpoint_match, get_model_snapshots
|
||||
from dreambooth.utils.utils import printm, cleanup
|
||||
from helpers.image_builder import ImageBuilder
|
||||
from helpers.mytqdm import mytqdm
|
||||
from postinstall import actual_install
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console = logging.StreamHandler()
|
||||
|
|
@ -665,19 +651,10 @@ def start_training(model_dir: str, use_txt2img: bool = True):
|
|||
if config.train_imagic:
|
||||
status.textinfo = "Initializing imagic training..."
|
||||
print(status.textinfo)
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
|
||||
except:
|
||||
from dreambooth.dreambooth.train_imagic import train_imagic # noqa
|
||||
|
||||
result = train_imagic(config)
|
||||
else:
|
||||
status.textinfo = "Initializing dreambooth training..."
|
||||
print(status.textinfo)
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
|
||||
except:
|
||||
from dreambooth.dreambooth.train_dreambooth import main # noqa
|
||||
result = main(use_txt2img=use_txt2img)
|
||||
|
||||
config = result.config
|
||||
|
|
@ -725,10 +702,6 @@ def reload_extension():
|
|||
|
||||
except Exception as e:
|
||||
print(f"Couldn't import module: {re_add}")
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.postinstall import actual_install # noqa
|
||||
except:
|
||||
from dreambooth.postinstall import actual_install # noqa
|
||||
|
||||
actual_install()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,34 +5,24 @@ from typing import List
|
|||
from accelerate import Accelerator
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
|
||||
from extensions.sd_dreambooth_extension.helpers.image_builder import ImageBuilder
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig, from_file # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.dataset.class_dataset import ClassDataset # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import db_save_image # noqa
|
||||
from dreambooth.dreambooth.utils.utils import cleanup # noqa
|
||||
from dreambooth.helpers.image_builder import ImageBuilder # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataset.class_dataset import ClassDataset
|
||||
from dreambooth.dataset.db_dataset import DbDataset
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import db_save_image
|
||||
from dreambooth.utils.utils import cleanup
|
||||
from helpers.image_builder import ImageBuilder
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
|
||||
def generate_dataset(model_name: str, instance_prompts: List[PromptData] = None, class_prompts: List[PromptData] = None,
|
||||
batch_size=None, tokenizer=None, vae=None, debug=True, model_dir=""):
|
||||
if debug:
|
||||
print("Generating dataset.")
|
||||
from dreambooth.ui_functions import gr_update
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
|
||||
db_gallery = gr_update(value=None)
|
||||
db_prompt_list = gr_update(value=None)
|
||||
db_status = gr_update(value=None)
|
||||
|
|
@ -57,10 +47,6 @@ def generate_dataset(model_name: str, instance_prompts: List[PromptData] = None,
|
|||
tokens = []
|
||||
|
||||
print(f"Found {len(class_prompts)} reg images.")
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
|
||||
except:
|
||||
from dreambooth.dreambooth.dataset.db_dataset import DbDataset
|
||||
|
||||
print("Preparing dataset...")
|
||||
|
||||
|
|
|
|||
|
|
@ -20,18 +20,11 @@ import numpy as np
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
except:
|
||||
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from helpers.mytqdm import mytqdm
|
||||
from dreambooth import shared
|
||||
from dreambooth.shared import status
|
||||
|
||||
|
||||
def get_dim(filename, max_res):
|
||||
|
|
|
|||
|
|
@ -8,14 +8,9 @@ import torch
|
|||
from diffusers.utils import is_xformers_available
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup # noqa
|
||||
except:
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.utils.utils import cleanup # noqa
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth import shared # noqa
|
||||
from dreambooth.utils.utils import cleanup # noqa
|
||||
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
|
|
|
|||
|
|
@ -11,16 +11,14 @@ from typing import Optional
|
|||
import importlib_metadata
|
||||
from packaging import version
|
||||
|
||||
from dreambooth import shared
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import torch
|
||||
from huggingface_hub import HfFolder, whoami
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
except:
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.dreambooth.shared import status # noqa
|
||||
from helpers.mytqdm import mytqdm
|
||||
from dreambooth.shared import status
|
||||
|
||||
|
||||
def printi(msg, params=None, log=True):
|
||||
|
|
@ -49,7 +47,6 @@ def sanitize_name(name):
|
|||
|
||||
|
||||
def printm(msg=""):
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
if shared.debug:
|
||||
allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
|
||||
cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
|
||||
|
|
|
|||
|
|
@ -6,12 +6,8 @@ from typing import Union, List
|
|||
import discord_webhook
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import image_grid
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import image_grid # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.utils.image_utils import image_grid
|
||||
|
||||
|
||||
class DreamboothWebhookTarget(Enum):
|
||||
|
|
|
|||
|
|
@ -8,29 +8,18 @@ from PIL import Image
|
|||
from accelerate import Accelerator
|
||||
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import disable_safe_unpickle
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import image_utils
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_checkpoint_match, \
|
||||
reload_system_models, \
|
||||
enable_safe_unpickle, disable_safe_unpickle, unload_system_models, xformerify
|
||||
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
|
||||
from extensions.sd_dreambooth_extension.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
|
||||
get_target_module
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.dataclasses.prompt_data import PromptData # noqa
|
||||
from dreambooth.dreambooth.shared import disable_safe_unpickle # noqa
|
||||
from dreambooth.dreambooth.utils import image_utils # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import get_checkpoint_match, reload_system_models, enable_safe_unpickle, disable_safe_unpickle, unload_system_models # noqa
|
||||
from dreambooth.helpers.mytqdm import mytqdm # noqa
|
||||
from dreambooth.lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, get_target_module # noqa
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.shared import disable_safe_unpickle
|
||||
from dreambooth.utils import image_utils
|
||||
from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
|
||||
from dreambooth.utils.model_utils import get_checkpoint_match, \
|
||||
reload_system_models, \
|
||||
enable_safe_unpickle, disable_safe_unpickle, unload_system_models, xformerify
|
||||
from helpers.mytqdm import mytqdm
|
||||
from lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
|
||||
get_target_module
|
||||
|
||||
|
||||
class ImageBuilder:
|
||||
|
|
|
|||
|
|
@ -9,10 +9,7 @@ from pandas import DataFrame
|
|||
from pandas.plotting._matplotlib.style import get_standard_colors
|
||||
from tensorboard.compat.proto import event_pb2
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
except:
|
||||
from dreambooth.dreambooth.shared import status
|
||||
from dreambooth.shared import status
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -248,9 +245,9 @@ class LogParser:
|
|||
}
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file # noqa
|
||||
from dreambooth.dataclasses.db_config import from_file # noqa
|
||||
except:
|
||||
from dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
|
||||
from core.modules.dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
|
||||
model_config = from_file(model_name)
|
||||
print(f"Model name: {model_name}")
|
||||
if model_config is None:
|
||||
|
|
|
|||
|
|
@ -2,10 +2,7 @@ from typing import Iterable
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth import shared
|
||||
|
||||
|
||||
class mytqdm(tqdm):
|
||||
|
|
|
|||
|
|
@ -3,10 +3,7 @@ import os
|
|||
import subprocess
|
||||
from typing import Union, Dict
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth import shared
|
||||
|
||||
store_file = os.path.join(shared.dreambooth_models_path, "revision.txt")
|
||||
change_file = os.path.join(shared.dreambooth_models_path, "changelog.txt")
|
||||
|
|
|
|||
28
install.py
28
install.py
|
|
@ -1,6 +1,30 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
is_auto = False
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.postinstall import actual_install
|
||||
from modules import shared
|
||||
is_auto = True
|
||||
except:
|
||||
from dreambooth.postinstall import actual_install
|
||||
pass
|
||||
|
||||
if not is_auto:
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
print(f"Base dir (sdplus) is: {base_dir} from {__file__}")
|
||||
ext_dir = os.path.join(base_dir, 'core', 'modules', 'dreambooth')
|
||||
else:
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
print(f"Base dir (auto1111) is: {base_dir} from {__file__}")
|
||||
ext_dir = os.path.join(base_dir, 'extensions', 'sd_dreambooth_extension')
|
||||
|
||||
if ext_dir not in sys.path:
|
||||
print(f"Appending (install): {ext_dir}")
|
||||
sys.path.insert(0, ext_dir)
|
||||
else:
|
||||
print(f"Ext dir already in path? {ext_dir}")
|
||||
print(sys.path)
|
||||
|
||||
|
||||
from postinstall import actual_install
|
||||
|
||||
actual_install()
|
||||
|
|
|
|||
|
|
@ -4,10 +4,7 @@ from typing import Optional, Set
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d
|
||||
except:
|
||||
from dreambooth.lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d # noqa
|
||||
from lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d
|
||||
|
||||
|
||||
def _find_modules_with_ancestor(
|
||||
|
|
|
|||
|
|
@ -8,11 +8,7 @@ from safetensors.torch import safe_open
|
|||
from safetensors.torch import save_file as safe_save
|
||||
from torch import dtype
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import disable_safe_unpickle, \
|
||||
enable_safe_unpickle
|
||||
except:
|
||||
from dreambooth.dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle # noqa
|
||||
from dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle
|
||||
|
||||
|
||||
class LoraInjectedLinear(nn.Module):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import sysconfig
|
|||
import git
|
||||
import requests
|
||||
|
||||
from dreambooth import shared
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live=True):
|
||||
if desc:
|
||||
|
|
@ -38,10 +40,6 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=True):
|
|||
def actual_install():
|
||||
if os.environ.get("PUBLIC_KEY", None):
|
||||
print("Docker, returning.")
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
shared.launch_error = None
|
||||
return
|
||||
if sys.version_info < (3, 8):
|
||||
|
|
@ -257,7 +255,7 @@ def actual_install():
|
|||
try:
|
||||
repo = git.Repo(base_dir)
|
||||
revision = repo.rev_parse("HEAD")
|
||||
app_repo = git.Repo(os.path.join(base_dir, "..", ".."))
|
||||
app_repo = git.Repo(os.path.join(base_dir, "../../..", ".."))
|
||||
app_revision = app_repo.rev_parse("HEAD")
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -2,17 +2,17 @@ accelerate==0.16.0
|
|||
albumentations~=1.3.0
|
||||
bitsandbytes==0.35.4
|
||||
diffusers==0.13.1
|
||||
gitpython~=3.1.31
|
||||
discord-webhook~=1.1.0
|
||||
fastapi
|
||||
ftfy~=6.1.1
|
||||
modelcards~=0.1.6
|
||||
tensorboard
|
||||
tensorflow==2.11.0; sys_platform != 'darwin' or platform_machine != 'arm64'
|
||||
tensorflow-macos==2.11.0; sys_platform == 'darwin' and platform_machine == 'arm64'
|
||||
gitpython~=3.1.31
|
||||
lion-pytorch~=0.0.7
|
||||
mediapipe-silicon; sys_platform == 'darwin'
|
||||
mediapipe; sys_platform != 'darwin'
|
||||
modelcards~=0.1.6
|
||||
tensorboard
|
||||
tensorflow-macos==2.11.0; sys_platform == 'darwin' and platform_machine == 'arm64'
|
||||
tensorflow==2.11.0; sys_platform != 'darwin' or platform_machine != 'arm64'
|
||||
tqdm~=4.64.1
|
||||
transformers~=4.26.1
|
||||
discord-webhook~=1.1.0
|
||||
lion-pytorch~=0.0.7
|
||||
xformers==0.0.17.dev464
|
||||
|
|
@ -22,36 +22,26 @@ from starlette import status
|
|||
from starlette.requests import Request
|
||||
|
||||
try:
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import DreamState
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import create_model, generate_samples, \
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.secret import get_secret
|
||||
from dreambooth.shared import DreamState
|
||||
from dreambooth.ui_functions import create_model, generate_samples, \
|
||||
start_training
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, get_lora_models
|
||||
from dreambooth.utils.gen_utils import generate_classifiers
|
||||
from dreambooth.utils.image_utils import get_images
|
||||
from dreambooth.utils.model_utils import get_db_models, get_lora_models
|
||||
except:
|
||||
from dreambooth.dreambooth import shared # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig # noqa
|
||||
from dreambooth.dreambooth.diff_to_sd import compile_checkpoint # noqa
|
||||
from dreambooth.dreambooth.secret import get_secret # noqa
|
||||
from dreambooth.dreambooth.shared import DreamState # noqa
|
||||
from dreambooth.dreambooth.ui_functions import create_model, generate_samples, start_training # noqa
|
||||
from dreambooth.dreambooth.utils.gen_utils import generate_classifiers # noqa
|
||||
from dreambooth.dreambooth.utils.image_utils import get_images # noqa
|
||||
from dreambooth.dreambooth.utils.model_utils import get_db_models, get_lora_models # noqa
|
||||
|
||||
pass
|
||||
print("Exception importing api")
|
||||
traceback.print_exc()
|
||||
|
||||
if os.environ.get("DEBUG_API", False):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
print("No, really, API loaded, wtf...")
|
||||
|
||||
class InstanceData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file or URL")
|
||||
|
|
@ -156,6 +146,7 @@ def file_to_base64(file_path) -> str:
|
|||
|
||||
|
||||
def dreambooth_api(_, app: FastAPI):
|
||||
print("API LOAD")
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
return JSONResponse(
|
||||
|
|
@ -941,4 +932,5 @@ try:
|
|||
script_callbacks.on_app_started(dreambooth_api)
|
||||
logger.debug("SD-Webui API layer loaded")
|
||||
except:
|
||||
logger.debug("Unable to import script callbacks.")
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,23 +4,25 @@ from typing import List
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import save_config, from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret, create_secret, clear_secret
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status, get_launch_errors
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import performance_wizard, \
|
||||
from dreambooth.dataclasses import db_config
|
||||
from dreambooth.dataclasses.db_config import save_config, from_file
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.secret import get_secret, create_secret, clear_secret
|
||||
from dreambooth.shared import status, get_launch_errors
|
||||
from dreambooth.ui_functions import performance_wizard, \
|
||||
training_wizard, training_wizard_person, load_model_params, ui_classifiers, debug_buckets, create_model, \
|
||||
generate_samples, load_params, start_training, update_extension, start_crop
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, \
|
||||
from dreambooth.utils.image_utils import get_scheduler_names
|
||||
from dreambooth.utils.model_utils import get_db_models, \
|
||||
get_sorted_lora_models, get_model_snapshots
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention, \
|
||||
from dreambooth.utils.utils import list_attention, \
|
||||
list_floats, wrap_gpu_call, printm, list_optimizer
|
||||
from extensions.sd_dreambooth_extension.dreambooth.webhook import save_and_test_webhook
|
||||
from extensions.sd_dreambooth_extension.helpers.version_helper import check_updates
|
||||
from dreambooth.webhook import save_and_test_webhook
|
||||
from helpers.version_helper import check_updates
|
||||
from helpers.log_parser import LogParser
|
||||
from modules import script_callbacks, sd_models
|
||||
from modules.ui import gr_show, create_refresh_button
|
||||
from extensions.sd_dreambooth_extension.helpers.log_parser import LogParser
|
||||
|
||||
|
||||
params_to_save = []
|
||||
params_to_load = []
|
||||
|
|
@ -849,7 +851,6 @@ def on_ui_tabs():
|
|||
|
||||
ui_keys.append("db_status")
|
||||
params_to_load.append(db_status)
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
|
||||
db_config.save_keys = save_keys
|
||||
db_config.ui_keys = ui_keys
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue