Refactor, fix install/setup?
parent
1b3257b46b
commit
bae1e87c9b
|
|
@ -8,11 +8,11 @@ from typing import List, Dict
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dreambooth import shared # noqa
|
||||
from dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from dreambooth.dataclasses.ss_model_spec import build_metadata
|
||||
from dreambooth.utils.image_utils import get_scheduler_names # noqa
|
||||
from dreambooth.utils.utils import list_attention, select_precision, select_attention
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.ss_model_spec import build_metadata
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention, select_precision, select_attention
|
||||
|
||||
# Keys to save, replacing our dumb __init__ method
|
||||
save_keys = []
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from PIL import Image
|
||||
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
|
||||
|
||||
class TrainResult:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import random
|
||||
from typing import Tuple
|
||||
|
||||
from dreambooth.dataset.db_dataset import DbDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
|
||||
|
||||
|
||||
class BucketSampler:
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import random
|
|||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import FilenameTextGetter, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter, \
|
||||
make_bucket_resolutions, \
|
||||
sort_prompts, get_images
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ import torch.utils.data
|
|||
from torchvision.transforms import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import make_bucket_resolutions, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import make_bucket_resolutions, \
|
||||
closest_resolution, shuffle_tags, open_and_trim
|
||||
from dreambooth.utils.text_utils import build_strict_tokens
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import build_strict_tokens
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import random
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
|
||||
closest_resolution, make_bucket_resolutions
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ import torch
|
|||
from diffusers import UNet2DConditionModel
|
||||
from torch import Tensor, nn
|
||||
|
||||
from dreambooth import shared as shared
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.model_utils import unload_system_models, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
|
||||
reload_system_models, \
|
||||
safe_unpickle_disabled, import_model_class_from_model_name_or_path
|
||||
from dreambooth.utils.utils import printi
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
|
||||
from helpers.mytqdm import mytqdm
|
||||
from lora_diffusion.lora import merge_lora_to_model
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ import traceback
|
|||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import from_file
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.model_utils import unload_system_models, reload_system_models
|
||||
from dreambooth.utils.utils import printi
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
|
||||
from helpers import mytqdm
|
||||
|
||||
# =================#
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ import traceback
|
|||
import torch
|
||||
import torch.backends.cudnn
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.utils.utils import cleanup
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
|
||||
|
||||
|
||||
def should_reduce_batch_size(exception: Exception) -> bool:
|
||||
|
|
|
|||
|
|
@ -20,12 +20,13 @@ import os
|
|||
import shutil
|
||||
import traceback
|
||||
from typing import Union
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import safe_unpickle_disabled, \
|
||||
unload_system_models, \
|
||||
reload_system_models
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import secrets
|
||||
|
||||
from dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
|
||||
db_path = os.path.join(shared.models_path, "dreambooth")
|
||||
secret_file = os.path.join(db_path, "secret.txt")
|
||||
|
|
|
|||
|
|
@ -40,34 +40,34 @@ from torch.nn.utils.parametrize import register_parametrization, remove_parametr
|
|||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import from_file
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataclasses.train_result import TrainResult
|
||||
from dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from dreambooth.dataset.db_dataset import DbDataset
|
||||
from dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from dreambooth.deis_velocity import get_velocity
|
||||
from dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
|
||||
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
|
||||
from dreambooth.memory import find_executable_batch_size
|
||||
from dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
|
||||
from dreambooth.utils.image_utils import db_save_image, get_scheduler_class
|
||||
from dreambooth.utils.model_utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.train_result import TrainResult
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.sample_dataset import SampleDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
|
||||
from extensions.sd_dreambooth_extension.dreambooth.memory import find_executable_batch_size
|
||||
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image, get_scheduler_class
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
|
||||
unload_system_models,
|
||||
import_model_class_from_model_name_or_path,
|
||||
safe_unpickle_disabled,
|
||||
xformerify,
|
||||
torch2ify
|
||||
)
|
||||
from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
|
||||
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
|
||||
patch_accelerator_for_fp16_training)
|
||||
from dreambooth.webhook import send_training_update
|
||||
from dreambooth.xattention import optim_to
|
||||
from extensions.sd_dreambooth_extension.dreambooth.webhook import send_training_update
|
||||
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
|
||||
from helpers.ema_model import EMAModel
|
||||
from helpers.log_parser import LogParser
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from torchvision import transforms
|
|||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import list_features, is_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,17 +18,17 @@ from accelerate import find_executable_batch_size
|
|||
from diffusers.utils import logging as dl
|
||||
from torch.optim import AdamW
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses import db_config
|
||||
from dreambooth.dataclasses.db_config import from_file, sanitize_name
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from dreambooth.dataset.class_dataset import ClassDataset
|
||||
from dreambooth.optimization import UniversalScheduler
|
||||
from dreambooth.sd_to_diff import extract_checkpoint
|
||||
from dreambooth.shared import status, run
|
||||
from dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
|
||||
from dreambooth.utils.image_utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, sanitize_name
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
|
||||
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status, run
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import (
|
||||
get_images,
|
||||
db_save_image,
|
||||
make_bucket_resolutions,
|
||||
|
|
@ -36,7 +36,7 @@ from dreambooth.utils.image_utils import (
|
|||
closest_resolution,
|
||||
open_and_trim,
|
||||
)
|
||||
from dreambooth.utils.model_utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
|
||||
unload_system_models,
|
||||
reload_system_models,
|
||||
get_lora_models,
|
||||
|
|
@ -44,7 +44,7 @@ from dreambooth.utils.model_utils import (
|
|||
get_model_snapshots,
|
||||
LORA_SHARED_SRC_CREATE, get_db_models,
|
||||
)
|
||||
from dreambooth.utils.utils import printm, cleanup
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm, cleanup
|
||||
from helpers.image_builder import ImageBuilder
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
|
|
@ -720,18 +720,18 @@ def start_training(model_dir: str, class_gen_method: str = "Native Diffusers"):
|
|||
status.textinfo = "Initializing imagic training..."
|
||||
print(status.textinfo)
|
||||
try:
|
||||
from dreambooth.train_imagic import train_imagic # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
|
||||
except:
|
||||
from dreambooth.train_imagic import train_imagic # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
|
||||
|
||||
result = train_imagic(config)
|
||||
else:
|
||||
status.textinfo = "Initializing dreambooth training..."
|
||||
print(status.textinfo)
|
||||
try:
|
||||
from dreambooth.train_dreambooth import main # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
|
||||
except:
|
||||
from dreambooth.train_dreambooth import main # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
|
||||
result = main(class_gen_method=class_gen_method)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -786,7 +786,7 @@ def reload_extension():
|
|||
try:
|
||||
from postinstall import actual_install # noqa
|
||||
except:
|
||||
from dreambooth.postinstall import actual_install # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.postinstall import actual_install # noqa
|
||||
|
||||
actual_install()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ try:
|
|||
from core.handlers.status import StatusHandler
|
||||
except:
|
||||
pass
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.dataset.class_dataset import ClassDataset
|
||||
from dreambooth.dataset.db_dataset import DbDataset
|
||||
from dreambooth.shared import status
|
||||
from dreambooth.utils.image_utils import db_save_image
|
||||
from dreambooth.utils.utils import cleanup
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
|
||||
from helpers.image_builder import ImageBuilder
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ def generate_dataset(
|
|||
data_cache=None):
|
||||
if debug:
|
||||
logger.debug("Generating dataset.")
|
||||
from dreambooth.ui_functions import gr_update
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
|
||||
|
||||
db_gallery = gr_update(value=None)
|
||||
db_prompt_list = gr_update(value=None)
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ from PIL import features, PngImagePlugin, Image, ExifTags
|
|||
from typing import List, Tuple, Dict, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from helpers.mytqdm import mytqdm
|
||||
from dreambooth import shared
|
||||
from dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
|
||||
|
||||
def get_dim(filename, max_res):
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ import torch
|
|||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from dreambooth import shared # noqa
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from dreambooth.utils.utils import cleanup # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup # noqa
|
||||
from modules import hashes
|
||||
from modules.safe import unsafe_torch_load, load
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import List
|
|||
import torch
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
|
||||
|
||||
# Implementation from https://github.com/bmaltais/kohya_ss
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ from typing import Optional
|
|||
import importlib_metadata
|
||||
from packaging import version
|
||||
|
||||
from dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
import torch
|
||||
from huggingface_hub import HfFolder, whoami
|
||||
|
||||
from helpers.mytqdm import mytqdm
|
||||
from dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
|
||||
|
||||
def printi(msg, params=None, log=True):
|
||||
|
|
@ -48,7 +48,7 @@ def sanitize_name(name):
|
|||
|
||||
|
||||
def printm(msg=""):
|
||||
from dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
use_logger = True
|
||||
try:
|
||||
from core.handlers.config import ConfigHandler
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from typing import Union, List
|
|||
import discord_webhook
|
||||
from PIL import Image
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.utils.image_utils import image_grid
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import image_grid
|
||||
|
||||
|
||||
class DreamboothWebhookTarget(Enum):
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ from accelerate import Accelerator
|
|||
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from dreambooth.utils import image_utils
|
||||
from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
|
||||
from dreambooth.utils.model_utils import get_checkpoint_match, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import image_utils
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_checkpoint_match, \
|
||||
reload_system_models, \
|
||||
safe_unpickle_disabled, unload_system_models
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from matplotlib import axes
|
|||
from pandas import DataFrame
|
||||
from pandas.plotting._matplotlib.style import get_standard_colors
|
||||
|
||||
from dreambooth.shared import status
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import status
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -245,7 +245,7 @@ class LogParser:
|
|||
return pd.DataFrame(loss_events), pd.DataFrame(lr_events), pd.DataFrame(ram_events), has_all
|
||||
|
||||
try:
|
||||
from dreambooth.dataclasses.db_config import from_file # noqa
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file # noqa
|
||||
except:
|
||||
from core.modules.dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
|
||||
model_config = from_file(model_name)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Iterable
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
|
||||
|
||||
class mytqdm(tqdm):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import os
|
|||
import subprocess
|
||||
from typing import Union, Dict
|
||||
|
||||
from dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
|
||||
store_file = os.path.join(shared.dreambooth_models_path, "revision.txt")
|
||||
change_file = os.path.join(shared.dreambooth_models_path, "changelog.txt")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from safetensors.torch import safe_open
|
|||
from safetensors.torch import save_file as safe_save
|
||||
from torch import dtype
|
||||
|
||||
from dreambooth.utils.model_utils import safe_unpickle_disabled
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import safe_unpickle_disabled
|
||||
|
||||
|
||||
class LoraInjectedLinear(nn.Module):
|
||||
|
|
|
|||
|
|
@ -17,10 +17,10 @@ from core.modules.base.module_base import BaseModule
|
|||
from fastapi import FastAPI
|
||||
|
||||
import scripts.api
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from dreambooth.sd_to_diff import extract_checkpoint
|
||||
from dreambooth.train_dreambooth import main
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
|
||||
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
|
||||
from module_src.gradio_parser import parse_gr_code
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -56,13 +56,13 @@ class DreamboothModule(BaseModule):
|
|||
|
||||
|
||||
async def _get_db_vars(request):
|
||||
from dreambooth.utils.utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (
|
||||
list_attention,
|
||||
list_precisions,
|
||||
list_optimizer,
|
||||
list_schedulers,
|
||||
)
|
||||
from dreambooth.utils.image_utils import get_scheduler_names
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
|
||||
|
||||
attentions = list_attention()
|
||||
precisions = list_precisions()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,17 @@ from importlib import metadata
|
|||
|
||||
from packaging.version import Version
|
||||
|
||||
from dreambooth import shared as db_shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared as db_shared
|
||||
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
||||
# Detect CUDA version
|
||||
try:
|
||||
import torch
|
||||
cuda_major, cuda_minor = torch.version.cuda.split('.')
|
||||
torch_index_url = f"https://download.pytorch.org/whl/cu{cuda_major}{cuda_minor}"
|
||||
except:
|
||||
pass
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
|
|
@ -96,7 +106,7 @@ def is_installed(pkg: str, version: Optional[str] = None, check_strict: bool = T
|
|||
try:
|
||||
# Retrieve the package version from the installed package metadata
|
||||
installed_version = metadata.version(pkg)
|
||||
print(f"Installed version of {pkg}: {installed_version}")
|
||||
print(f"[Installed version of {pkg}: {installed_version}")
|
||||
# If version is not specified, just return True as the package is installed
|
||||
if version is None:
|
||||
return True
|
||||
|
|
@ -137,7 +147,7 @@ def install_requirements():
|
|||
if os.name == "darwin":
|
||||
reqs.append("tensorboard==2.11.2")
|
||||
else:
|
||||
reqs.append("tensorboard==2.13.0")
|
||||
reqs.append("tensorboard>=2.18.0")
|
||||
|
||||
for line in reqs:
|
||||
try:
|
||||
|
|
@ -173,7 +183,27 @@ def install_requirements():
|
|||
error_msg = grepexc.stdout.decode()
|
||||
print_requirement_installation_error(error_msg)
|
||||
|
||||
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
|
||||
# Try importing scipy and numpy, and if they fail, re-install torch
|
||||
try:
|
||||
import numpy
|
||||
import scipy
|
||||
import pytorch_lightning
|
||||
except ImportError:
|
||||
print("Re-installing torch to ensure pytorch_lightning matches torch are installed.")
|
||||
tc = torch_command + " --force-reinstall"
|
||||
# Remove "pip install " from the command
|
||||
tc = tc.replace("pip install ", "pytorch_lightning ")
|
||||
pip_install(tc)
|
||||
try:
|
||||
import numpy
|
||||
import scipy
|
||||
except ImportError:
|
||||
print("Failed to install numpy and scipy after re-installing torch.")
|
||||
print("Please install numpy and scipy manually.")
|
||||
print("pip install numpy scipy")
|
||||
pass
|
||||
|
||||
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.48.3"):
|
||||
print()
|
||||
print("Does your project take forever to startup?")
|
||||
print("Repetitive dependency installation may be the reason.")
|
||||
|
|
@ -213,15 +243,15 @@ def check_bitsandbytes():
|
|||
bitsandbytes_version = None
|
||||
|
||||
print("Checking bitsandbytes (ALL!)")
|
||||
if bitsandbytes_version is None or "0.43.0" not in bitsandbytes_version:
|
||||
if bitsandbytes_version is None or "0.45.2" not in bitsandbytes_version:
|
||||
try:
|
||||
print("Installing bitsandbytes")
|
||||
pip_install("bitsandbytes==0.43.0", "--prefer-binary")
|
||||
pip_install("bitsandbytes>=0.45.2", "--prefer-binary")
|
||||
except:
|
||||
print("Bitsandbytes 0.43.0 installation failed")
|
||||
print("Bitsandbytes 0.45.2 installation failed")
|
||||
print("Some features such as 8bit optimizers will be unavailable")
|
||||
print("Install manually with")
|
||||
print("'python -m pip install bitsandbytes==0.43.0 --prefer-binary --force-install'")
|
||||
print("'python -m pip install bitsandbytes>=0.45.2 --prefer-binary --force-install'")
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -239,14 +269,12 @@ def check_versions():
|
|||
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
|
||||
|
||||
dependencies = [
|
||||
Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"),
|
||||
Dependency(module="torchvision", version="0.14.1" if is_mac else "0.15.2+cu118"),
|
||||
Dependency(module="accelerate", version="0.21.0"),
|
||||
Dependency(module="diffusers", version="0.23.1")
|
||||
Dependency(module="diffusers", version="0.32.2")
|
||||
]
|
||||
|
||||
if device == "cuda":
|
||||
dependencies.append(Dependency(module="bitsandbytes", version="0.43.0", required=False))
|
||||
dependencies.append(Dependency(module="bitsandbytes", version="0.45.2", required=False))
|
||||
|
||||
if device != "mps":
|
||||
dependencies.append(Dependency(module="xformers", version="0.0.21", required=False))
|
||||
|
|
@ -257,11 +285,17 @@ def check_versions():
|
|||
module = dependency.module
|
||||
|
||||
has_module = importlib.util.find_spec(module) is not None
|
||||
installed_ver = importlib_metadata.version(module) if has_module else None
|
||||
installed_ver = None
|
||||
if has_module:
|
||||
try:
|
||||
installed_ver = importlib_metadata.version(module)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not installed_ver:
|
||||
module_msg = ""
|
||||
if module != "xformers":
|
||||
cmd_args = sys.argv
|
||||
if module != "xformers" and "--xformers" in cmd_args:
|
||||
launch_errors.append(f"{module} not installed.")
|
||||
module_msg = "(Be sure to use the --xformers flag.)"
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Tuple, List, Dict
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from dreambooth.utils.image_utils import FilenameTextGetter
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter
|
||||
|
||||
image_data = []
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ def load_image_data(input_path: str, recurse: bool = False) -> List[Dict[str, st
|
|||
return []
|
||||
global image_data
|
||||
results = []
|
||||
from dreambooth.utils.image_utils import list_features, is_image
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
|
||||
pil_features = list_features()
|
||||
# Get a list from PIL of all the image formats it supports
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
bitsandbytes>=0.43.0
|
||||
bitsandbytes>=0.45.2
|
||||
accelerate>=0.21.0
|
||||
dadaptation>=3.2
|
||||
diffusers>=0.25.0
|
||||
discord-webhook==1.3.0
|
||||
diffusers>=0.32.2
|
||||
discord-webhook==1.3.1
|
||||
fastapi
|
||||
gitpython>=3.1.40
|
||||
pytorch_optimizer==2.12.0
|
||||
gitpython>=3.1.32
|
||||
pytorch_optimizer==3.4.0
|
||||
Pillow
|
||||
transformers>=4.48.3
|
||||
tqdm
|
||||
tomesd>=0.1.2
|
||||
tomesd>=0.1.3
|
||||
xformers
|
||||
|
|
@ -22,18 +22,18 @@ from starlette import status
|
|||
from starlette.requests import Request
|
||||
|
||||
try:
|
||||
from dreambooth import shared
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from dreambooth.secret import get_secret
|
||||
from dreambooth.shared import DreamState
|
||||
from dreambooth.ui_functions import create_model, generate_samples, \
|
||||
from extensions.sd_dreambooth_extension.dreambooth import shared
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import DreamState
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import create_model, generate_samples, \
|
||||
start_training
|
||||
from dreambooth.utils.gen_utils import generate_classifiers
|
||||
from dreambooth.utils.image_utils import get_images
|
||||
from dreambooth.utils.model_utils import get_db_models, get_lora_models
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, get_lora_models
|
||||
except:
|
||||
print("Exception importing api")
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -5,19 +5,19 @@ from typing import List
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from dreambooth.dataclasses.db_config import from_file, save_config
|
||||
from dreambooth.diff_to_sd import compile_checkpoint
|
||||
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from dreambooth.secret import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, save_config
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
||||
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
|
||||
from extensions.sd_dreambooth_extension.dreambooth.secret import (
|
||||
get_secret,
|
||||
create_secret,
|
||||
clear_secret,
|
||||
)
|
||||
from dreambooth.shared import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.shared import (
|
||||
status,
|
||||
get_launch_errors,
|
||||
)
|
||||
from dreambooth.ui_functions import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import (
|
||||
performance_wizard,
|
||||
load_model_params,
|
||||
ui_classifiers,
|
||||
|
|
@ -29,16 +29,16 @@ from dreambooth.ui_functions import (
|
|||
update_extension,
|
||||
start_crop,
|
||||
)
|
||||
from dreambooth.utils.image_utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import (
|
||||
get_scheduler_names,
|
||||
)
|
||||
from dreambooth.utils.model_utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
|
||||
get_db_models,
|
||||
get_sorted_lora_models,
|
||||
get_model_snapshots,
|
||||
get_shared_models,
|
||||
)
|
||||
from dreambooth.utils.utils import (
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (
|
||||
list_attention,
|
||||
list_precisions,
|
||||
wrap_gpu_call,
|
||||
|
|
@ -46,7 +46,7 @@ from dreambooth.utils.utils import (
|
|||
list_optimizer,
|
||||
list_schedulers, select_precision, select_attention,
|
||||
)
|
||||
from dreambooth.webhook import save_and_test_webhook
|
||||
from extensions.sd_dreambooth_extension.dreambooth.webhook import save_and_test_webhook
|
||||
from helpers.log_parser import LogParser
|
||||
from helpers.version_helper import check_updates
|
||||
from modules import script_callbacks, sd_models
|
||||
|
|
@ -1069,7 +1069,7 @@ def on_ui_tabs():
|
|||
with gr.Column() as db_hook_view:
|
||||
gr.HTML(value="Webhooks")
|
||||
# In the future change this to something more generic and list the supported types
|
||||
# from DreamboothWebhookTarget enum; for now, Discord is what I use ;)
|
||||
# from extensions.sd_dreambooth_extension.dreamboothWebhookTarget enum; for now, Discord is what I use ;)
|
||||
# Add options to include notifications on training complete and exceptions that halt training
|
||||
db_notification_webhook_url = gr.Textbox(
|
||||
label="Discord Webhook",
|
||||
|
|
@ -1660,7 +1660,7 @@ def on_ui_tabs():
|
|||
|
||||
ui_keys.append("db_status")
|
||||
params_to_load.append(db_status)
|
||||
from dreambooth.dataclasses import db_config
|
||||
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
|
||||
db_config.save_keys = save_keys
|
||||
db_config.ui_keys = ui_keys
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue