sd_dreambooth_extension/postinstall.py

426 lines
16 KiB
Python

import importlib.util
import json
import os
import subprocess
import sys
from dataclasses import dataclass
from typing import Optional
import git
import torch
from packaging import version as pv
from importlib import metadata
from packaging.version import Version
from extensions.sd_dreambooth_extension.dreambooth import shared as db_shared
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
# Detect CUDA version
try:
import torch
cuda_major, cuda_minor = torch.version.cuda.split('.')
torch_index_url = f"https://download.pytorch.org/whl/cu{cuda_major}{cuda_minor}"
except:
pass
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
try:
if torch.backends.mps.is_built():
torch.zeros(1).to(torch.device("mps"))
device = "mps"
except Exception:
pass
def actual_install():
if os.environ.get("PUBLIC_KEY", None):
print("Docker, returning.")
db_shared.launch_error = None
return
base_dir = os.path.dirname(os.path.realpath(__file__))
try:
repo = git.Repo(base_dir)
revision = repo.rev_parse("HEAD")
except:
revision = ""
print("If submitting an issue on github, please provide the full startup log for debugging purposes.")
print("")
print("Initializing Dreambooth")
print(f"Dreambooth revision: {revision}")
check_xformers()
# Only do this if we're using CUDA and not AMD
if device == "cuda":
check_bitsandbytes()
install_requirements()
check_versions()
check_torch_unsafe_load()
def pip_install(*args):
try:
output = subprocess.check_output(
[sys.executable, "-m", "pip", "install"] + list(args),
stderr=subprocess.STDOUT,
)
success = False
for line in output.decode().split("\n"):
if "Successfully installed" in line:
print(line)
success = True
return success
except subprocess.CalledProcessError as e:
print("Error occurred:", e.output.decode())
return False
def pip_uninstall(*args):
output = subprocess.check_output(
[sys.executable, "-m", "pip", "uninstall", "-y"] + list(args),
stderr=subprocess.STDOUT,
)
for line in output.decode().split("\n"):
if "Successfully uninstalled" in line:
print(line)
def is_installed(pkg: str, version: Optional[str] = None, check_strict: bool = True) -> bool:
try:
# Retrieve the package version from the installed package metadata
installed_version = metadata.version(pkg)
print(f"[Installed version of {pkg}: {installed_version}")
# If version is not specified, just return True as the package is installed
if version is None:
return True
# Compare the installed version with the required version
if check_strict:
# Strict comparison (must be an exact match)
return pv.parse(installed_version) == pv.parse(version)
else:
# Non-strict comparison (installed version must be greater than or equal to the required version)
return pv.parse(installed_version) >= pv.parse(version)
except metadata.PackageNotFoundError:
# The package is not installed
return False
except Exception as e:
# Any other exceptions encountered
print(f"Error: {e}")
return False
def install_requirements():
dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False)
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
req_file_startup_arg = os.environ.get("REQS_FILE", "requirements_versions.txt")
if dreambooth_skip_install or req_file == req_file_startup_arg:
return
print("Checking Dreambooth requirements...")
has_diffusers = importlib.util.find_spec("diffusers") is not None
has_tqdm = importlib.util.find_spec("tqdm") is not None
transformers_version = importlib_metadata.version("transformers")
non_strict_separators = ["==", ">=", "<=", ">", "<", "~="]
# Load the requirements file
with open(req_file, "r") as f:
reqs = f.readlines()
if os.name == "darwin":
reqs.append("tensorboard==2.11.2")
else:
reqs.append("tensorboard>=2.18.0")
for line in reqs:
try:
package = line.strip()
if package and not package.startswith("#"):
package_version = None
strict = "==" in package
for separator in non_strict_separators:
if separator in package:
strict = separator == "=="
parts = line.split(separator)
if len(parts) < 2:
print(f"Invalid requirement: {line}")
continue
package = parts[0].strip()
package_version = parts[1].strip()
if "#" in package_version:
package_version = package_version.split("#")[0]
package = package.strip()
package_version = package_version.strip()
break
if "#" in package:
package = package.split("#")[0]
package = package.strip()
v_string = "" if not package_version else f" v{package_version}"
if not is_installed(package, package_version, strict):
print(f"[Dreambooth] {package}{v_string} is not installed.")
pip_install(line)
else:
print(f"[Dreambooth] {package}{v_string} is already installed.")
except subprocess.CalledProcessError as grepexc:
error_msg = grepexc.stdout.decode()
print_requirement_installation_error(error_msg)
# Try importing scipy and numpy, and if they fail, re-install torch
try:
import numpy
import scipy
import pytorch_lightning
except ImportError:
print("Re-installing torch to ensure pytorch_lightning matches torch are installed.")
tc = torch_command + " --force-reinstall"
# Remove "pip install " from the command
tc = tc.replace("pip install ", "pytorch_lightning ")
pip_install(tc)
try:
import numpy
import scipy
except ImportError:
print("Failed to install numpy and scipy after re-installing torch.")
print("Please install numpy and scipy manually.")
print("pip install numpy scipy")
pass
# Removed outdated transformers version warning that doesn't match requirements.txt
def check_xformers():
"""
Install xformers if necessary
"""
print("Checking xformers...")
try:
xformers_version = importlib_metadata.version("xformers")
xformers_outdated = Version(xformers_version) < Version("0.0.27")
# Parse arguments, see if --xformers is passed
from modules import shared
cmd_opts = shared.cmd_opts
if cmd_opts.xformers or cmd_opts.reinstall_xformers:
if xformers_outdated:
print("Xformers is outdated, and automatic installation has been removed. Please install manually if desired.")
except:
pass
def check_bitsandbytes():
"""
Check for "different" B&B Files and copy only if necessary
"""
print("Checking bitsandbytes...")
try:
bitsandbytes_version = importlib_metadata.version("bitsandbytes")
except:
bitsandbytes_version = None
print("Checking bitsandbytes (ALL!)")
if bitsandbytes_version is None or (bitsandbytes_version and Version(bitsandbytes_version) < Version("0.45.2")):
try:
print("Installing bitsandbytes")
pip_install("bitsandbytes>=0.45.2", "--prefer-binary")
except:
print("Bitsandbytes 0.45.2 installation failed")
print("Some features such as 8bit optimizers will be unavailable")
print("Install manually with")
print("'python -m pip install bitsandbytes>=0.45.2 --prefer-binary --force-install'")
pass
@dataclass
class Dependency:
module: str
version: str
version_comparison: str = "min"
required: bool = True
def check_versions():
import platform
from sys import platform as sys_platform
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
dependencies = [
Dependency(module="accelerate", version="0.21.0"),
Dependency(module="diffusers", version="0.32.2"),
Dependency(module="transformers", version="4.49.0")
]
if device == "cuda":
dependencies.append(Dependency(module="bitsandbytes", version="0.45.2", required=False))
if device != "mps":
dependencies.append(Dependency(module="xformers", version="0.0.27", required=False))
launch_errors = []
for dependency in dependencies:
module = dependency.module
has_module = importlib.util.find_spec(module) is not None
installed_ver = None
if has_module:
try:
installed_ver = importlib_metadata.version(module)
except:
pass
if not installed_ver:
module_msg = ""
cmd_args = sys.argv
if module != "xformers" and "--xformers" in cmd_args:
launch_errors.append(f"{module} not installed.")
module_msg = "(Be sure to use the --xformers flag.)"
print(f"[!] {module} NOT installed.{module_msg}")
continue
required_version = dependency.version
required_comparison = dependency.version_comparison
if required_comparison == "min" and Version(installed_ver) < Version(required_version):
if dependency.required:
launch_errors.append(f"{module} is below the required {required_version} version.")
print(f"[!] {module} version {installed_ver} installed.")
elif required_comparison == "exact" and Version(installed_ver) != Version(required_version):
launch_errors.append(f"{module} is not the required {required_version} version.")
print(f"[!] {module} version {installed_ver} installed.")
else:
print(f"[+] {module} version {installed_ver} installed.")
try:
if len(launch_errors):
print_launch_errors(launch_errors)
os.environ["ERRORS"] = json.dumps(launch_errors)
else:
os.environ["ERRORS"] = ""
except Exception as e:
print(e)
def print_requirement_installation_error(err):
print("# Requirement installation exception:")
for line in err.split('\n'):
line = line.strip()
if line:
print(line)
def print_bitsandbytes_installation_error(err):
print()
print("#######################################################################################################")
print("# BITSANDBYTES ISSUE DETECTED #")
print("#######################################################################################################")
print("#")
print("# Dreambooth could not find a compatible version of bitsandbytes.")
print("# bitsandbytes will not be available for Dreambooth.")
print("#")
print("# Bitsandbytes installation exception:")
for line in err.split('\n'):
line = line.strip()
if line:
print(line)
print("#")
print("# TO FIX THIS ISSUE, DO THE FOLLOWING:")
print("# 1. Fully restart your project (not just the webpage)")
print("# 2. Running the following commands from the A1111 project root:")
print("cd venv/Scripts")
print("activate")
print("cd ../..")
print("# WINDOWS ONLY: ")
print(
"pip install bitsandbytes>=0.45.2 --prefer-binary --force-reinstall")
print("#######################################################################################################")
def print_xformers_installation_error(err):
torch_ver = importlib_metadata.version("torch")
print()
print("#######################################################################################################")
print("# XFORMERS ISSUE DETECTED #")
print("#######################################################################################################")
print("#")
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.27 built with torch {torch_ver})")
print("# xformers will not be available for Dreambooth. Consider upgrading to Torch 2.")
print("#")
print("# Xformers installation exception:")
for line in err.split('\n'):
line = line.strip()
if line:
print(line)
print("#")
print("#######################################################################################################")
def print_launch_errors(launch_errors):
print()
print("#######################################################################################################")
print("# LIBRARY ISSUE DETECTED #")
print("#######################################################################################################")
print("#")
print("# " + "\n# ".join(launch_errors))
print("#")
print("# Dreambooth may not work properly.")
print("#")
print("# TROUBLESHOOTING")
print("# 1. Fully restart your project (not just the webpage)")
print("# 2. Update your A1111 project and extensions")
print("# 3. Dreambooth requirements should have installed automatically, but you can manually install them")
print("# by running the following commands from the A1111 project root:")
print("cd venv/Scripts")
print("activate")
print("cd ../..")
print("pip install -r ./extensions/sd_dreambooth_extension/requirements.txt")
print(
"pip install bitsandbytes>=0.45.2 --prefer-binary --force-reinstall")
print("#######################################################################################################")
def check_torch_unsafe_load():
try:
from modules import safe
safe.load = safe.unsafe_torch_load
import torch
torch.load = safe.unsafe_torch_load
except:
pass
def print_xformers_torch1_instructions(xformers_version):
print(f"# Your version of xformers is {xformers_version}.")
print("# xformers >= 0.0.27 is required to be available on the Dreambooth tab.")
print("# Torch 1 wheels of xformers >= 0.0.27 are no longer available on PyPI,")
print("# but you can manually download them by going to:")
print("https://github.com/facebookresearch/xformers/actions")
print("# Click on the most recent action tagged with a release (middle column).")
print("# Select a download based on your environment.")
print("# Unzip your download")
print("# Activate your venv and install the wheel: (from A1111 project root)")
print("cd venv/Scripts")
print("activate")
print("pip install {REPLACE WITH PATH TO YOUR UNZIPPED .whl file}")
print("# Then restart your project.")