Refactor, fix install/setup?

main
d8ahazard 2025-02-18 09:55:56 -06:00
parent 1b3257b46b
commit bae1e87c9b
31 changed files with 201 additions and 164 deletions

View File

@ -8,11 +8,11 @@ from typing import List, Dict
from pydantic import BaseModel
from dreambooth import shared # noqa
from dreambooth.dataclasses.db_concept import Concept # noqa
from dreambooth.dataclasses.ss_model_spec import build_metadata
from dreambooth.utils.image_utils import get_scheduler_names # noqa
from dreambooth.utils.utils import list_attention, select_precision, select_attention
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept # noqa
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.ss_model_spec import build_metadata
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names # noqa
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention, select_precision, select_attention
# Keys to save, replacing our dumb __init__ method
save_keys = []

View File

@ -1,6 +1,6 @@
from PIL import Image
from dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
class TrainResult:

View File

@ -1,7 +1,7 @@
import random
from typing import Tuple
from dreambooth.dataset.db_dataset import DbDataset
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
class BucketSampler:

View File

@ -3,11 +3,11 @@ import random
from torch.utils.data import Dataset
from dreambooth import shared
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import status
from dreambooth.utils.image_utils import FilenameTextGetter, \
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter, \
make_bucket_resolutions, \
sort_prompts, get_images
from helpers.mytqdm import mytqdm

View File

@ -10,12 +10,12 @@ import torch.utils.data
from torchvision.transforms import transforms
from transformers import CLIPTokenizer
from dreambooth import shared
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import status
from dreambooth.utils.image_utils import make_bucket_resolutions, \
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import make_bucket_resolutions, \
closest_resolution, shuffle_tags, open_and_trim
from dreambooth.utils.text_utils import build_strict_tokens
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import build_strict_tokens
from helpers.mytqdm import mytqdm
logger = logging.getLogger(__name__)

View File

@ -3,9 +3,9 @@ import random
from PIL import Image
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
closest_resolution, make_bucket_resolutions

View File

@ -15,13 +15,13 @@ import torch
from diffusers import UNet2DConditionModel
from torch import Tensor, nn
from dreambooth import shared as shared
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from dreambooth.shared import status
from dreambooth.utils.model_utils import unload_system_models, \
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, \
safe_unpickle_disabled, import_model_class_from_model_name_or_path
from dreambooth.utils.utils import printi
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import merge_lora_to_model

View File

@ -11,11 +11,11 @@ import traceback
import torch
from safetensors.torch import load_file, save_file
from dreambooth import shared
from dreambooth.dataclasses.db_config import from_file
from dreambooth.shared import status
from dreambooth.utils.model_utils import unload_system_models, reload_system_models
from dreambooth.utils.utils import printi
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
from helpers import mytqdm
# =================#

View File

@ -25,8 +25,8 @@ import traceback
import torch
import torch.backends.cudnn
from dreambooth import shared
from dreambooth.utils.utils import cleanup
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
def should_reduce_batch_size(exception: Exception) -> bool:

View File

@ -20,12 +20,13 @@ import os
import shutil
import traceback
from typing import Union
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import safe_unpickle_disabled, \
unload_system_models, \
reload_system_models

View File

@ -1,7 +1,7 @@
import os
import secrets
from dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth import shared
db_path = os.path.join(shared.models_path, "dreambooth")
secret_file = os.path.join(db_path, "secret.txt")

View File

@ -40,34 +40,34 @@ from torch.nn.utils.parametrize import register_parametrization, remove_parametr
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from dreambooth import shared
from dreambooth.dataclasses.db_config import from_file
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataclasses.train_result import TrainResult
from dreambooth.dataset.bucket_sampler import BucketSampler
from dreambooth.dataset.db_dataset import DbDataset
from dreambooth.dataset.sample_dataset import SampleDataset
from dreambooth.deis_velocity import get_velocity
from dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
from dreambooth.memory import find_executable_batch_size
from dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
from dreambooth.shared import status
from dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
from dreambooth.utils.image_utils import db_save_image, get_scheduler_class
from dreambooth.utils.model_utils import (
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.train_result import TrainResult
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
from extensions.sd_dreambooth_extension.dreambooth.dataset.sample_dataset import SampleDataset
from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity
from extensions.sd_dreambooth_extension.dreambooth.diff_lora_to_sd_lora import convert_diffusers_to_kohya_lora
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_xl
from extensions.sd_dreambooth_extension.dreambooth.memory import find_executable_batch_size
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image, get_scheduler_class
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
unload_system_models,
import_model_class_from_model_name_or_path,
safe_unpickle_disabled,
xformerify,
torch2ify
)
from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
patch_accelerator_for_fp16_training)
from dreambooth.webhook import send_training_update
from dreambooth.xattention import optim_to
from extensions.sd_dreambooth_extension.dreambooth.webhook import send_training_update
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
from helpers.ema_model import EMAModel
from helpers.log_parser import LogParser
from helpers.mytqdm import mytqdm

View File

@ -12,10 +12,10 @@ from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.shared import status
from dreambooth.utils.image_utils import list_features, is_image
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
logger = get_logger(__name__)

View File

@ -18,17 +18,17 @@ from accelerate import find_executable_batch_size
from diffusers.utils import logging as dl
from torch.optim import AdamW
from dreambooth import shared
from dreambooth.dataclasses import db_config
from dreambooth.dataclasses.db_config import from_file, sanitize_name
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataset.bucket_sampler import BucketSampler
from dreambooth.dataset.class_dataset import ClassDataset
from dreambooth.optimization import UniversalScheduler
from dreambooth.sd_to_diff import extract_checkpoint
from dreambooth.shared import status, run
from dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
from dreambooth.utils.image_utils import (
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, sanitize_name
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataset.bucket_sampler import BucketSampler
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
from extensions.sd_dreambooth_extension.dreambooth.optimization import UniversalScheduler
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.shared import status, run
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_dataset, generate_classifiers
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import (
get_images,
db_save_image,
make_bucket_resolutions,
@ -36,7 +36,7 @@ from dreambooth.utils.image_utils import (
closest_resolution,
open_and_trim,
)
from dreambooth.utils.model_utils import (
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
unload_system_models,
reload_system_models,
get_lora_models,
@ -44,7 +44,7 @@ from dreambooth.utils.model_utils import (
get_model_snapshots,
LORA_SHARED_SRC_CREATE, get_db_models,
)
from dreambooth.utils.utils import printm, cleanup
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm, cleanup
from helpers.image_builder import ImageBuilder
from helpers.mytqdm import mytqdm
@ -720,18 +720,18 @@ def start_training(model_dir: str, class_gen_method: str = "Native Diffusers"):
status.textinfo = "Initializing imagic training..."
print(status.textinfo)
try:
from dreambooth.train_imagic import train_imagic # noqa
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
except:
from dreambooth.train_imagic import train_imagic # noqa
from extensions.sd_dreambooth_extension.dreambooth.train_imagic import train_imagic # noqa
result = train_imagic(config)
else:
status.textinfo = "Initializing dreambooth training..."
print(status.textinfo)
try:
from dreambooth.train_dreambooth import main # noqa
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
except:
from dreambooth.train_dreambooth import main # noqa
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main # noqa
result = main(class_gen_method=class_gen_method)
if torch.cuda.is_available():
torch.cuda.empty_cache()
@ -786,7 +786,7 @@ def reload_extension():
try:
from postinstall import actual_install # noqa
except:
from dreambooth.postinstall import actual_install # noqa
from extensions.sd_dreambooth_extension.dreambooth.postinstall import actual_install # noqa
actual_install()

View File

@ -10,14 +10,14 @@ try:
from core.handlers.status import StatusHandler
except:
pass
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.dataset.class_dataset import ClassDataset
from dreambooth.dataset.db_dataset import DbDataset
from dreambooth.shared import status
from dreambooth.utils.image_utils import db_save_image
from dreambooth.utils.utils import cleanup
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataset.class_dataset import ClassDataset
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
from extensions.sd_dreambooth_extension.dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import db_save_image
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
from helpers.image_builder import ImageBuilder
from helpers.mytqdm import mytqdm
@ -39,7 +39,7 @@ def generate_dataset(
data_cache=None):
if debug:
logger.debug("Generating dataset.")
from dreambooth.ui_functions import gr_update
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import gr_update
db_gallery = gr_update(value=None)
db_prompt_list = gr_update(value=None)

View File

@ -15,11 +15,11 @@ from PIL import features, PngImagePlugin, Image, ExifTags
from typing import List, Tuple, Dict, Union
import numpy as np
import torch
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from helpers.mytqdm import mytqdm
from dreambooth import shared
from dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.shared import status
def get_dim(filename, max_res):

View File

@ -13,9 +13,9 @@ import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import PretrainedConfig
from dreambooth import shared # noqa
from dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from dreambooth.utils.utils import cleanup # noqa
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup # noqa
from modules import hashes
from modules.safe import unsafe_torch_load, load

View File

@ -6,7 +6,7 @@ from typing import List
import torch
from transformers import CLIPTextModel
from dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
# Implementation from https://github.com/bmaltais/kohya_ss

View File

@ -12,14 +12,14 @@ from typing import Optional
import importlib_metadata
from packaging import version
from dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth import shared
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import torch
from huggingface_hub import HfFolder, whoami
from helpers.mytqdm import mytqdm
from dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.shared import status
def printi(msg, params=None, log=True):
@ -48,7 +48,7 @@ def sanitize_name(name):
def printm(msg=""):
from dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth import shared
use_logger = True
try:
from core.handlers.config import ConfigHandler

View File

@ -6,8 +6,8 @@ from typing import Union, List
import discord_webhook
from PIL import Image
from dreambooth import shared
from dreambooth.utils.image_utils import image_grid
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import image_grid
class DreamboothWebhookTarget(Enum):

View File

@ -11,12 +11,12 @@ from accelerate import Accelerator
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.utils import image_utils
from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
from dreambooth.utils.model_utils import get_checkpoint_match, \
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
from extensions.sd_dreambooth_extension.dreambooth.utils import image_utils
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_checkpoint_match, \
reload_system_models, \
safe_unpickle_disabled, unload_system_models
from helpers.mytqdm import mytqdm

View File

@ -9,7 +9,7 @@ from matplotlib import axes
from pandas import DataFrame
from pandas.plotting._matplotlib.style import get_standard_colors
from dreambooth.shared import status
from extensions.sd_dreambooth_extension.dreambooth.shared import status
@dataclass
@ -245,7 +245,7 @@ class LogParser:
return pd.DataFrame(loss_events), pd.DataFrame(lr_events), pd.DataFrame(ram_events), has_all
try:
from dreambooth.dataclasses.db_config import from_file # noqa
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file # noqa
except:
from core.modules.dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
model_config = from_file(model_name)

View File

@ -2,7 +2,7 @@ from typing import Iterable
from tqdm import tqdm
from dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth import shared
class mytqdm(tqdm):

View File

@ -3,7 +3,7 @@ import os
import subprocess
from typing import Union, Dict
from dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth import shared
store_file = os.path.join(shared.dreambooth_models_path, "revision.txt")
change_file = os.path.join(shared.dreambooth_models_path, "changelog.txt")

View File

@ -8,7 +8,7 @@ from safetensors.torch import safe_open
from safetensors.torch import save_file as safe_save
from torch import dtype
from dreambooth.utils.model_utils import safe_unpickle_disabled
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import safe_unpickle_disabled
class LoraInjectedLinear(nn.Module):

View File

@ -17,10 +17,10 @@ from core.modules.base.module_base import BaseModule
from fastapi import FastAPI
import scripts.api
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from dreambooth.sd_to_diff import extract_checkpoint
from dreambooth.train_dreambooth import main
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig, from_file
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.train_dreambooth import main
from module_src.gradio_parser import parse_gr_code
logger = logging.getLogger(__name__)
@ -56,13 +56,13 @@ class DreamboothModule(BaseModule):
async def _get_db_vars(request):
from dreambooth.utils.utils import (
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (
list_attention,
list_precisions,
list_optimizer,
list_schedulers,
)
from dreambooth.utils.image_utils import get_scheduler_names
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
attentions = list_attention()
precisions = list_precisions()

View File

@ -14,7 +14,17 @@ from importlib import metadata
from packaging.version import Version
from dreambooth import shared as db_shared
from extensions.sd_dreambooth_extension.dreambooth import shared as db_shared
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
# Detect CUDA version
try:
import torch
cuda_major, cuda_minor = torch.version.cuda.split('.')
torch_index_url = f"https://download.pytorch.org/whl/cu{cuda_major}{cuda_minor}"
except:
pass
if sys.version_info < (3, 8):
import importlib_metadata
@ -96,7 +106,7 @@ def is_installed(pkg: str, version: Optional[str] = None, check_strict: bool = T
try:
# Retrieve the package version from the installed package metadata
installed_version = metadata.version(pkg)
print(f"Installed version of {pkg}: {installed_version}")
print(f"[Installed version of {pkg}: {installed_version}")
# If version is not specified, just return True as the package is installed
if version is None:
return True
@ -137,7 +147,7 @@ def install_requirements():
if os.name == "darwin":
reqs.append("tensorboard==2.11.2")
else:
reqs.append("tensorboard==2.13.0")
reqs.append("tensorboard>=2.18.0")
for line in reqs:
try:
@ -173,7 +183,27 @@ def install_requirements():
error_msg = grepexc.stdout.decode()
print_requirement_installation_error(error_msg)
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
# Try importing scipy and numpy, and if they fail, re-install torch
try:
import numpy
import scipy
import pytorch_lightning
except ImportError:
print("Re-installing torch to ensure pytorch_lightning matches torch are installed.")
tc = torch_command + " --force-reinstall"
# Remove "pip install " from the command
tc = tc.replace("pip install ", "pytorch_lightning ")
pip_install(tc)
try:
import numpy
import scipy
except ImportError:
print("Failed to install numpy and scipy after re-installing torch.")
print("Please install numpy and scipy manually.")
print("pip install numpy scipy")
pass
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.48.3"):
print()
print("Does your project take forever to startup?")
print("Repetitive dependency installation may be the reason.")
@ -213,15 +243,15 @@ def check_bitsandbytes():
bitsandbytes_version = None
print("Checking bitsandbytes (ALL!)")
if bitsandbytes_version is None or "0.43.0" not in bitsandbytes_version:
if bitsandbytes_version is None or "0.45.2" not in bitsandbytes_version:
try:
print("Installing bitsandbytes")
pip_install("bitsandbytes==0.43.0", "--prefer-binary")
pip_install("bitsandbytes>=0.45.2", "--prefer-binary")
except:
print("Bitsandbytes 0.43.0 installation failed")
print("Bitsandbytes 0.45.2 installation failed")
print("Some features such as 8bit optimizers will be unavailable")
print("Install manually with")
print("'python -m pip install bitsandbytes==0.43.0 --prefer-binary --force-install'")
print("'python -m pip install bitsandbytes>=0.45.2 --prefer-binary --force-install'")
pass
@ -239,14 +269,12 @@ def check_versions():
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
dependencies = [
Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"),
Dependency(module="torchvision", version="0.14.1" if is_mac else "0.15.2+cu118"),
Dependency(module="accelerate", version="0.21.0"),
Dependency(module="diffusers", version="0.23.1")
Dependency(module="diffusers", version="0.32.2")
]
if device == "cuda":
dependencies.append(Dependency(module="bitsandbytes", version="0.43.0", required=False))
dependencies.append(Dependency(module="bitsandbytes", version="0.45.2", required=False))
if device != "mps":
dependencies.append(Dependency(module="xformers", version="0.0.21", required=False))
@ -257,11 +285,17 @@ def check_versions():
module = dependency.module
has_module = importlib.util.find_spec(module) is not None
installed_ver = importlib_metadata.version(module) if has_module else None
installed_ver = None
if has_module:
try:
installed_ver = importlib_metadata.version(module)
except:
pass
if not installed_ver:
module_msg = ""
if module != "xformers":
cmd_args = sys.argv
if module != "xformers" and "--xformers" in cmd_args:
launch_errors.append(f"{module} not installed.")
module_msg = "(Be sure to use the --xformers flag.)"

View File

@ -3,7 +3,7 @@ from typing import Tuple, List, Dict
import gradio as gr
from dreambooth.utils.image_utils import FilenameTextGetter
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter
image_data = []
@ -14,7 +14,7 @@ def load_image_data(input_path: str, recurse: bool = False) -> List[Dict[str, st
return []
global image_data
results = []
from dreambooth.utils.image_utils import list_features, is_image
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import list_features, is_image
pil_features = list_features()
# Get a list from PIL of all the image formats it supports

View File

@ -1,11 +1,13 @@
bitsandbytes>=0.43.0
bitsandbytes>=0.45.2
accelerate>=0.21.0
dadaptation>=3.2
diffusers>=0.25.0
discord-webhook==1.3.0
diffusers>=0.32.2
discord-webhook==1.3.1
fastapi
gitpython>=3.1.40
pytorch_optimizer==2.12.0
gitpython>=3.1.32
pytorch_optimizer==3.4.0
Pillow
transformers>=4.48.3
tqdm
tomesd>=0.1.2
tomesd>=0.1.3
xformers

View File

@ -22,18 +22,18 @@ from starlette import status
from starlette.requests import Request
try:
from dreambooth import shared
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
from dreambooth.secret import get_secret
from dreambooth.shared import DreamState
from dreambooth.ui_functions import create_model, generate_samples, \
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
from extensions.sd_dreambooth_extension.dreambooth.shared import DreamState
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import create_model, generate_samples, \
start_training
from dreambooth.utils.gen_utils import generate_classifiers
from dreambooth.utils.image_utils import get_images
from dreambooth.utils.model_utils import get_db_models, get_lora_models
from extensions.sd_dreambooth_extension.dreambooth.utils.gen_utils import generate_classifiers
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import get_db_models, get_lora_models
except:
print("Exception importing api")
traceback.print_exc()

View File

@ -5,19 +5,19 @@ from typing import List
import gradio as gr
from dreambooth.dataclasses.db_config import from_file, save_config
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
from dreambooth.secret import (
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, save_config
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
from extensions.sd_dreambooth_extension.dreambooth.secret import (
get_secret,
create_secret,
clear_secret,
)
from dreambooth.shared import (
from extensions.sd_dreambooth_extension.dreambooth.shared import (
status,
get_launch_errors,
)
from dreambooth.ui_functions import (
from extensions.sd_dreambooth_extension.dreambooth.ui_functions import (
performance_wizard,
load_model_params,
ui_classifiers,
@ -29,16 +29,16 @@ from dreambooth.ui_functions import (
update_extension,
start_crop,
)
from dreambooth.utils.image_utils import (
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import (
get_scheduler_names,
)
from dreambooth.utils.model_utils import (
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import (
get_db_models,
get_sorted_lora_models,
get_model_snapshots,
get_shared_models,
)
from dreambooth.utils.utils import (
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import (
list_attention,
list_precisions,
wrap_gpu_call,
@ -46,7 +46,7 @@ from dreambooth.utils.utils import (
list_optimizer,
list_schedulers, select_precision, select_attention,
)
from dreambooth.webhook import save_and_test_webhook
from extensions.sd_dreambooth_extension.dreambooth.webhook import save_and_test_webhook
from helpers.log_parser import LogParser
from helpers.version_helper import check_updates
from modules import script_callbacks, sd_models
@ -1069,7 +1069,7 @@ def on_ui_tabs():
with gr.Column() as db_hook_view:
gr.HTML(value="Webhooks")
# In the future change this to something more generic and list the supported types
# from DreamboothWebhookTarget enum; for now, Discord is what I use ;)
# from extensions.sd_dreambooth_extension.dreamboothWebhookTarget enum; for now, Discord is what I use ;)
# Add options to include notifications on training complete and exceptions that halt training
db_notification_webhook_url = gr.Textbox(
label="Discord Webhook",
@ -1660,7 +1660,7 @@ def on_ui_tabs():
ui_keys.append("db_status")
params_to_load.append(db_status)
from dreambooth.dataclasses import db_config
from extensions.sd_dreambooth_extension.dreambooth.dataclasses import db_config
db_config.save_keys = save_keys
db_config.ui_keys = ui_keys