Update setup code

pull/2893/head
bmaltais 2024-10-05 08:21:14 -04:00
parent 7ab6efc5ca
commit 0d27feaf01
2 changed files with 452 additions and 442 deletions

View File

@ -1,363 +1,300 @@
import subprocess
import os
import re
import sys
import logging
import shutil
import datetime
import pkg_resources

# Running error counter; incremented by the git()/pip() helpers when a
# subprocess exits non-zero (must exist before those helpers run).
errors = 0

# Module logger shared by every helper in this file.
log = logging.getLogger("sd")

# Constants
MIN_PYTHON_VERSION = (3, 10, 9)
MAX_PYTHON_VERSION = (3, 11, 0)
LOG_DIR = "../logs/setup/"
LOG_LEVEL = "INFO"  # Set to "INFO" or "WARNING" for less verbose logging
def check_python_version():
    """
    Check if the current Python version is within the acceptable range.

    Returns:
        bool: True if the current Python version is valid, False otherwise.
    """
    log.debug("Checking Python version...")
    try:
        current_version = sys.version_info
        log.info(f"Python version is {sys.version}")
        # sys.version_info compares tuple-wise against the (major, minor, micro)
        # constants defined at module level.
        if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION):
            log.error(
                f"The current version of python ({sys.version}) is not supported."
            )
            log.error("The Python version must be >= 3.10.9 and < 3.11.0.")
            return False
        return True
    except Exception as e:
        log.error(f"Failed to verify Python version. Error: {e}")
        return False
def update_submodule(quiet=True):
    """
    Ensure the Git submodule is initialized and updated.

    Runs ``git submodule update --init --recursive`` via the Git CLI.
    Failures of the Git process and a missing ``git`` executable are
    caught and logged rather than raised.

    Parameters:
    - quiet: If True, suppresses the output of the Git command.
    """
    log.debug("Updating submodule...")
    cmd = ["git", "submodule", "update", "--init", "--recursive"]
    if quiet:
        cmd.append("--quiet")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as git_error:
        # The submodule update itself failed.
        log.error(f"Error during Git operation: {git_error}")
    except FileNotFoundError as missing_git:
        # The git executable was not found on PATH.
        log.error(missing_git)
    else:
        log.info("Submodule initialized and updated.")
# def read_tag_version_from_file(file_path):
# """
# Read the tag version from a given file.
# Parameters:
# - file_path: The path to the file containing the tag version.
# Returns:
# The tag version as a string.
# """
# with open(file_path, 'r') as file:
# # Read the first line and strip whitespace
# tag_version = file.readline().strip()
# return tag_version
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
    """
    Clone a repo or checkout a specific branch or tag if the repo already exists.

    For existing checkouts it fetches all refs, suppresses detached-HEAD advice
    (relevant for tags/commits), and only checks out when the current HEAD does
    not already match the requested ref. The original working directory is
    restored after the operation.

    Parameters:
    - repo_url: The URL of the Git repository.
    - branch_or_tag: The name of the branch or tag to clone or checkout.
    - directory_name: The name of the directory to clone into or where the repo already exists.
    """
    log.debug(
        f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}"
    )
    original_dir = os.getcwd()
    try:
        if not os.path.exists(directory_name):
            clone_cmd = [
                "git",
                "clone",
                "--branch",
                branch_or_tag,
                "--single-branch",
                "--quiet",
                repo_url,
                directory_name,
            ]
            log.debug(f"Cloning repository: {clone_cmd}")
            subprocess.run(clone_cmd, check=True)
            log.info(f"Successfully cloned {repo_url} ({branch_or_tag})")
        else:
            os.chdir(directory_name)
            log.debug("Fetching all branches and tags...")
            subprocess.run(["git", "fetch", "--all", "--quiet"], check=True)
            # Silence "you are in 'detached HEAD' state" advice when checking
            # out a tag or commit hash.
            subprocess.run(
                ["git", "config", "advice.detachedHead", "false"], check=True
            )
            current_branch_hash = (
                subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
            )
            target_branch_hash = (
                subprocess.check_output(["git", "rev-parse", branch_or_tag])
                .strip()
                .decode()
            )
            if current_branch_hash != target_branch_hash:
                log.debug(f"Checking out branch/tag: {branch_or_tag}")
                subprocess.run(
                    ["git", "checkout", branch_or_tag, "--quiet"], check=True
                )
                log.info(f"Checked out {branch_or_tag} successfully.")
            else:
                log.info(f"Already at required branch/tag: {branch_or_tag}")
    except subprocess.CalledProcessError as e:
        log.error(f"Error during Git operation: {e}")
    finally:
        # Always restore the caller's working directory.
        os.chdir(original_dir)
# setup console and file logging
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
#
def setup_logging(clean=False):
    """
    Set up logging to file and console.

    The file handler (via basicConfig) always appends to a timestamped log
    file under LOG_DIR; the 'sd' logger's console level comes from the
    LOG_LEVEL environment variable, falling back to the module constant.

    Parameters:
    - clean: accepted for backward compatibility with older callers; unused.
    """
    log.debug("Setting up logging...")

    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.pretty import install as pretty_install
    from rich.traceback import install as traceback_install

    console = Console(
        log_time=True,
        log_time_format="%H:%M:%S-%f",
        theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}),
    )

    current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_file = os.path.join(
        os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log"
    )
    # Create the log directory if it doesn't exist.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    logging.basicConfig(
        level=logging.ERROR,
        format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s",
        filename=log_file,
        filemode="a",
        encoding="utf-8",
        force=True,
    )

    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )

    # Console verbosity is driven by the LOG_LEVEL env var; an unknown value
    # falls back to DEBUG via getattr's default.
    log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper()
    log.setLevel(getattr(logging, log_level, logging.DEBUG))
    rich_handler = RichHandler(console=console)
    # Replace any existing handlers with the rich handler.
    log.handlers.clear()
    log.addHandler(rich_handler)
    log.debug("Logging setup complete.")
def install_requirements_inbulk(
    requirements_file, show_stdout=True, optional_parm="", upgrade=False
):
    """
    Install all requirements from a file with a single pip invocation.

    Parameters:
    - requirements_file: Path to the requirements file.
    - show_stdout: If False, pass --quiet to pip.
    - optional_parm: Extra flags appended to the pip command line.
    - upgrade: If True, append -U to upgrade already-installed packages.
    """
    log.debug(f"Installing requirements in bulk from: {requirements_file}")

    if not os.path.exists(requirements_file):
        log.error(f"Could not find the requirements file in {requirements_file}.")
        return

    log.info(f"Installing requirements from {requirements_file}...")

    if upgrade:
        optional_parm += " -U"

    if show_stdout:
        run_cmd(f"pip install -r {requirements_file} {optional_parm}")
    else:
        run_cmd(f"pip install -r {requirements_file} {optional_parm} --quiet")
    log.info(f"Requirements from {requirements_file} installed.")
def configure_accelerate(run_accelerate=False):
    """
    Copy the bundled default accelerate config into the user's huggingface
    config location, or fall back to interactive `accelerate config`.

    This function was taken and adapted from code written by jstayco.

    Parameters:
    - run_accelerate: If True, run interactive configuration when the config
      cannot be copied automatically.
    """
    log.debug("Configuring accelerate...")
    from pathlib import Path

    def env_var_exists(var_name):
        # True only when the variable is set AND non-empty.
        return var_name in os.environ and os.environ[var_name] != ""

    log.info("Configuring accelerate...")

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "config_files",
        "accelerate",
        "default_config.yaml",
    )

    if not os.path.exists(source_accelerate_config_file):
        log.warning(
            f"Could not find the accelerate configuration file in {source_accelerate_config_file}."
        )
        if run_accelerate:
            log.debug("Running accelerate configuration command...")
            run_cmd([sys.executable, "-m", "accelerate", "config"])
        else:
            log.warning(
                "Please configure accelerate manually by running the option in the menu."
            )
        return

    log.debug(f"Source accelerate config location: {source_accelerate_config_file}")

    target_config_location = None
    log.debug(
        f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
        f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
        f"USERPROFILE: {os.environ.get('USERPROFILE')}"
    )
    # Candidate target locations in priority order. NOTE: HF_HOME must also
    # point at accelerate/default_config.yaml inside that directory — the
    # bare Path(HF_HOME) would copy the file over the HF_HOME dir itself.
    env_vars = {
        "HF_HOME": Path(
            os.environ.get("HF_HOME", ""), "accelerate", "default_config.yaml"
        ),
        "LOCALAPPDATA": Path(
            os.environ.get("LOCALAPPDATA", ""),
            "huggingface",
            "accelerate",
            "default_config.yaml",
        ),
        "USERPROFILE": Path(
            os.environ.get("USERPROFILE", ""),
            ".cache",
            "huggingface",
            "accelerate",
            "default_config.yaml",
        ),
    }
    for var, path in env_vars.items():
        if env_var_exists(var):
            target_config_location = path
            break
    log.debug(f"Target config location: {target_config_location}")

    if target_config_location:
        if not target_config_location.is_file():
            log.debug(
                f"Creating target config directory: {target_config_location.parent}"
            )
            target_config_location.parent.mkdir(parents=True, exist_ok=True)
            log.debug(
                f"Copying config file to target location: {target_config_location}"
            )
            shutil.copyfile(source_accelerate_config_file, target_config_location)
            log.info(f"Copied accelerate config file to: {target_config_location}")
        elif run_accelerate:
            log.debug("Running accelerate configuration command...")
            run_cmd([sys.executable, "-m", "accelerate", "config"])
        else:
            log.warning(
                "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
            )
    elif run_accelerate:
        log.debug("Running accelerate configuration command...")
        run_cmd([sys.executable, "-m", "accelerate", "config"])
    else:
        log.warning(
            "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
        )
def check_torch():
log.debug("Checking Torch installation...")
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
# This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master
#
# Check for toolkit
if shutil.which('nvidia-smi') is not None or os.path.exists(
if shutil.which("nvidia-smi") is not None or os.path.exists(
os.path.join(
os.environ.get('SystemRoot') or r'C:\Windows',
'System32',
'nvidia-smi.exe',
os.environ.get("SystemRoot") or r"C:\Windows",
"System32",
"nvidia-smi.exe",
)
):
log.info('nVidia toolkit detected')
elif shutil.which('rocminfo') is not None or os.path.exists(
'/opt/rocm/bin/rocminfo'
log.info("nVidia toolkit detected")
elif shutil.which("rocminfo") is not None or os.path.exists(
"/opt/rocm/bin/rocminfo"
):
log.info('AMD toolkit detected')
elif (shutil.which('sycl-ls') is not None
or os.environ.get('ONEAPI_ROOT') is not None
or os.path.exists('/opt/intel/oneapi')):
log.info('Intel OneAPI toolkit detected')
log.info("AMD toolkit detected")
elif (
shutil.which("sycl-ls") is not None
or os.environ.get("ONEAPI_ROOT") is not None
or os.path.exists("/opt/intel/oneapi")
):
log.info("Intel OneAPI toolkit detected")
else:
log.info('Using CPU-only Torch')
log.info("Using CPU-only Torch")
try:
import torch
log.debug("Torch module imported successfully.")
try:
# Import IPEX / XPU support
import intel_extension_for_pytorch as ipex
except Exception:
pass
log.info(f'Torch {torch.__version__}')
log.debug("Intel extension for PyTorch imported successfully.")
except Exception as e:
log.warning(f"Failed to import intel_extension_for_pytorch: {e}")
log.info(f"Torch {torch.__version__}")
if torch.cuda.is_available():
if torch.version.cuda:
@ -367,33 +304,33 @@ def check_torch():
)
elif torch.version.hip:
# Log AMD ROCm HIP version
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
else:
log.warning('Unknown Torch backend')
log.warning("Unknown Torch backend")
# Log information about detected GPUs
for device in [
torch.cuda.device(i) for i in range(torch.cuda.device_count())
]:
log.info(
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}"
)
# Check if XPU is available
elif hasattr(torch, "xpu") and torch.xpu.is_available():
# Log Intel IPEX version
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
for device in [
torch.xpu.device(i) for i in range(torch.xpu.device_count())
]:
log.info(
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}"
)
else:
log.warning('Torch reports GPU not available')
log.warning("Torch reports GPU not available")
return int(torch.__version__[0])
except Exception as e:
# log.warning(f'Could not load torch: {e}')
log.error(f"Could not load torch: {e}")
return 0
def check_repo_version():
    """
    Check and log the repository version from the '.release' file
    in the current directory. If the file exists, it reads the release version from the file and logs it.
    If the file does not exist, it logs a debug message indicating that the release could not be read.
    """
    log.debug("Checking repository version...")
    if os.path.exists(".release"):
        try:
            with open(os.path.join("./.release"), "r", encoding="utf8") as file:
                release = file.read()
            log.info(f"Kohya_ss GUI version: {release}")
        except Exception as e:
            log.error(f"Could not read release: {e}")
    else:
        log.debug("Could not read release...")
# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
    """
    Run a git command in the given folder, logging failures.

    Parameters:
    - arg: Argument string appended after ``git`` on the command line.
    - folder: Working directory for the command (default: current directory).
    - ignore: If set to True, errors will not be logged.

    Note:
        This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    """
    log.debug(f"Running git command: git {arg} in folder: {folder or '.'}")
    # Use a single command string with shell=True: passing a list together
    # with shell=True drops `arg` on POSIX (the list tail becomes $0..$n of
    # the shell, not arguments to git).
    result = subprocess.run(
        f"git {arg}",
        check=False,
        shell=True,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or ".",
    )
    txt = result.stdout.decode(encoding="utf8", errors="ignore")
    if len(result.stderr) > 0:
        txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
            encoding="utf8", errors="ignore"
        )
    txt = txt.strip()
    if result.returncode != 0 and not ignore:
        global errors
        errors += 1
        log.error(f"Error running git: {folder} / {arg}")
        if "or stash them" in txt:
            log.error("Local changes detected: check log for details...")
    log.debug(f"Git output: {txt}")
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """
    Run a pip command via the current interpreter and return its output.

    Parameters:
    - arg: Argument string passed to ``python -m pip`` (split on spaces).
    - ignore: If True, a non-zero exit status is not counted or logged.
    - quiet: If True, suppress the "Installing package" info line.
    - show_stdout: If True, stream pip's output directly and return None.

    Returns:
    - The output of the pip command as a string, or None if 'show_stdout' is set.
    """
    log.debug(f"Running pip command: {arg}")
    if not quiet:
        # Strip pip flags for a friendlier log line; collapse double spaces
        # left behind by the removals.
        log.info(
            f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()}'
        )
    pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ")
    log.debug(f"Running pip: {pip_cmd}")
    if show_stdout:
        subprocess.run(pip_cmd, shell=False, check=False, env=os.environ)
    else:
        result = subprocess.run(
            pip_cmd,
            shell=False,
            check=False,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        txt = result.stdout.decode(encoding="utf8", errors="ignore")
        if len(result.stderr) > 0:
            txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
                encoding="utf8", errors="ignore"
            )
        txt = txt.strip()
        if result.returncode != 0 and not ignore:
            global errors  # pylint: disable=global-statement
            errors += 1
            log.error(f"Error running pip: {arg}")
            log.error(f"Pip output: {txt}")
        return txt
def installed(package, friendly: str = None):
    """
    Checks if the specified package(s) are installed with the correct version.
    This function can handle package specifications with or without version constraints,
    and can also filter out command-line options and URLs when a 'friendly' string is provided.

    Parameters:
    - package: A string that specifies one or more packages with optional version constraints.
    - friendly: An optional string used to provide a cleaner version of the package string.

    Returns:
    - True if all specified packages are installed with the correct versions, False otherwise.

    Note:
        This function was adapted from code written by vladimandic.
    """
    log.debug(f"Checking if package is installed: {package}")
    # Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
    package = re.sub(r"\[.*?\]", "", package)
    try:
        if friendly:
            # NOTE(review): this initial assignment is immediately overwritten
            # below, as in the upstream code — kept for behavioral parity.
            pkgs = friendly.split()
            # Filter out command-line options and URLs from the package specification
            pkgs = [
                p for p in package.split() if not p.startswith("--") and "://" not in p
            ]
        else:
            # Split the package string, excluding '-' and '=' prefixed items
            pkgs = [
                p
                for p in package.split()
                if not p.startswith("-") and not p.startswith("=")
            ]
        # For each package component, extract the package name, excluding any URLs
        pkgs = [p.split("/")[-1] for p in pkgs]

        for pkg in pkgs:
            # Parse the package name and version based on the version specifier used
            if ">=" in pkg:
                pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")]
            elif "==" in pkg:
                pkg_name, pkg_version = [x.strip() for x in pkg.split("==")]
            else:
                pkg_name, pkg_version = pkg.strip(), None
            # Look up the installed distribution, trying case and separator variants
            spec = pkg_resources.working_set.by_key.get(pkg_name, None)
            if spec is None:
                spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
            if spec is None:
                # Try replacing underscores with dashes
                spec = pkg_resources.working_set.by_key.get(
                    pkg_name.replace("_", "-"), None
                )
            if spec is not None:
                # Package is found, check version
                version = pkg_resources.get_distribution(pkg_name).version
                log.debug(f"Package version found: {pkg_name} {version}")
                if pkg_version is not None:
                    # Verify if the installed version meets the specified constraints
                    if ">=" in pkg:
                        ok = version >= pkg_version
                    else:
                        ok = version == pkg_version
                    if not ok:
                        # Version mismatch, log warning and return False
                        log.warning(
                            f"Package wrong version: {pkg_name} {version} required {pkg_version}"
                        )
                        return False
            else:
                # Package not found, log debug message and return False
                log.debug(f"Package version not found: {pkg_name}")
                return False
        # All specified packages are installed with the correct versions
        return True
    except ModuleNotFoundError:
        # One or more packages are not installed, log debug message and return False
        log.debug(f"Package not installed: {pkgs}")
        return False
# install package using pip if not already installed
def install(
package,
@ -596,7 +555,7 @@ def install(
"""
Installs or upgrades a Python package using pip, with options to ignode errors,
reinstall packages, and display outputs.
Parameters:
- package (str): The name of the package to be installed or upgraded. Can include
version specifiers. Anything after a '#' in the package name will be ignored.
@ -612,103 +571,98 @@ def install(
Returns:
None. The function performs operations that affect the environment but does not return
any value.
Note:
If `reinstall` is True, it disables any mechanism that allows for skipping installations
when the package is already present, forcing a fresh install.
"""
log.debug(f"Installing package: {package}")
# Remove anything after '#' in the package variable
package = package.split('#')[0].strip()
package = package.split("#")[0].strip()
if reinstall:
global quick_allowed # pylint: disable=global-statement
global quick_allowed # pylint: disable=global-statement
quick_allowed = False
if reinstall or not installed(package, friendly):
pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout)
pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout)
def process_requirements_line(line, show_stdout: bool = False):
    """Install a single requirements-file line.

    The extras brackets are stripped to form the friendly name passed to
    install() (e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2).
    """
    log.debug(f"Processing requirements line: {line}")
    friendly_name = re.sub(r"\[.*?\]", "", line)
    install(line, friendly_name, show_stdout=show_stdout)
def install_requirements(
    requirements_file, check_no_verify_flag=False, show_stdout: bool = False
):
    """
    Install or verify modules from a requirements file.

    Parameters:
    - requirements_file (str): Path to the requirements file.
    - check_no_verify_flag (bool): If True, verify modules installation status without installing.
    - show_stdout (bool): If True, show the standard output of the installation process.
    """
    log.debug(f"Installing requirements from file: {requirements_file}")
    action = "Verifying" if check_no_verify_flag else "Installing"
    log.info(f"{action} modules from {requirements_file}...")
    with open(requirements_file, "r", encoding="utf8") as f:
        # Strip blanks and comments. Lines tagged 'no_verify' are skipped only
        # in verify mode — in install mode they must still be installed.
        lines = [
            line.strip()
            for line in f
            if line.strip()
            and not line.startswith("#")
            and (not check_no_verify_flag or "no_verify" not in line)
        ]
    for line in lines:
        if line.startswith("-r"):
            # '-r other.txt' includes another requirements file; expand recursively.
            included_file = line[2:].strip()
            log.debug(f"Processing included requirements file: {included_file}")
            install_requirements(
                included_file,
                check_no_verify_flag=check_no_verify_flag,
                show_stdout=show_stdout,
            )
        else:
            process_requirements_line(line, show_stdout=show_stdout)
def ensure_base_requirements():
    """Ensure 'rich' and 'packaging' are importable, installing them if missing.

    These are needed before any other setup step (logging and version checks
    depend on them).
    """
    try:
        import rich  # pylint: disable=unused-import
    except ImportError:
        install("--upgrade rich", "rich")
    try:
        import packaging  # pylint: disable=unused-import
    except ImportError:
        install("packaging")
def run_cmd(run_cmd):
    """
    Execute a shell command using subprocess, logging success or failure.

    Parameters:
    - run_cmd: Command string executed through the shell. (The parameter
      shadows the function name; kept for interface compatibility.)
    """
    log.debug(f"Running command: {run_cmd}")
    try:
        # check=True raises CalledProcessError on a non-zero exit status.
        subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
        log.info(f"Command executed successfully: {run_cmd}")
    except subprocess.CalledProcessError as e:
        log.error(f"Error occurred while running command: {run_cmd}")
        log.error(f"Error: {e}")
def delete_file(file_path):
    """Remove the file at ``file_path`` when it exists; do nothing otherwise."""
    if not os.path.exists(file_path):
        return
    os.remove(file_path)
def write_to_file(file_path, content):
    """
    Write ``content`` to ``file_path``, logging (not raising) on I/O errors.

    Parameters:
    - file_path: Destination file path (truncated/created).
    - content: Text written to the file.
    """
    try:
        with open(file_path, "w", encoding="utf8") as file:
            file.write(content)
    except IOError as e:
        # Use the module logger rather than print, consistent with the rest
        # of this file's error reporting.
        log.error(f"Error occurred while writing to file: {file_path}")
        log.error(f"Error: {e}")
def clear_screen():
    """
    Clear the terminal screen (``cls`` on Windows, ``clear`` elsewhere).
    """
    log.debug("Attempting to clear the terminal screen")
    try:
        os.system("cls" if os.name == "nt" else "clear")
        log.info("Terminal screen cleared successfully")
    except Exception as e:
        log.error("Error occurred while clearing the terminal screen")
        log.error(f"Error: {e}")

View File

@ -5,12 +5,11 @@ import argparse
import setup_common
# Get the absolute path of the current file's directory (Kohya_SS project directory).
# If this file lives inside a "setup" subdirectory, move one level up to the
# repository root so project imports resolve.
project_directory = (
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if "setup" in os.path.dirname(os.path.abspath(__file__))
    else os.path.dirname(os.path.abspath(__file__))
)

# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)
@ -19,115 +18,172 @@ from kohya_gui.custom_logging import setup_logging
# Set up logging
log = setup_logging()
# Record the resolved project directory early to aid setup troubleshooting.
log.debug(f"Project directory set to: {project_directory}")
def check_path_with_space():
    """Abort if the current working directory contains a space.

    Paths with spaces are not supported by the kohya_ss GUI tooling.

    Raises:
        RuntimeError: if the current working directory contains a space.
    """
    cwd = os.getcwd()
    log.debug(f"Current working directory: {cwd}")
    if " " in cwd:
        # Log an error if the current working directory contains spaces
        log.error(
            "The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI."
        )
        log.error(
            "Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again."
        )
        log.error(f"The current working directory is: {cwd}")
        raise RuntimeError("Invalid path: contains spaces.")
def detect_toolkit():
    """Detect the available GPU toolkit and return its name.

    Returns:
        str: one of "nVidia", "AMD", "Intel", or "CPU" when no GPU toolkit
        is found.
    """
    log.debug("Detecting available toolkit...")
    # Check for NVIDIA toolkit by looking for the nvidia-smi executable
    if shutil.which("nvidia-smi") or os.path.exists(
        os.path.join(
            os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
        )
    ):
        log.debug("nVidia toolkit detected")
        return "nVidia"
    # Check for AMD toolkit by looking for the rocminfo executable
    if shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
        log.debug("AMD toolkit detected")
        return "AMD"
    # Check for Intel toolkit by looking for SYCL or OneAPI indicators
    if (
        shutil.which("sycl-ls")
        or os.environ.get("ONEAPI_ROOT")
        or os.path.exists("/opt/intel/oneapi")
    ):
        log.debug("Intel toolkit detected")
        return "Intel"
    # Default to CPU if no toolkit is detected
    log.debug("No specific GPU toolkit detected, defaulting to CPU")
    return "CPU"
def check_torch():
    """Check that PyTorch loads, log backend/GPU details, and return its major version.

    Returns:
        int: the PyTorch major version (e.g. 2 for "2.1.0").

    Exits the process with status 1 if torch cannot be imported or an
    unexpected error occurs while probing it.
    """
    # Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
    toolkit = detect_toolkit()
    log.info(f"{toolkit} toolkit detected")

    try:
        # Import PyTorch
        log.debug("Importing PyTorch...")
        import torch

        ipex = None
        # Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
        if toolkit == "Intel":
            try:
                log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
                import intel_extension_for_pytorch as ipex

                log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
            except ImportError:
                log.warning("Intel Extension for PyTorch (IPEX) not found.")

        # Log the PyTorch version
        log.info(f"Torch {torch.__version__}")

        # Check if CUDA (NVIDIA GPU) is available
        if torch.cuda.is_available():
            log.debug("CUDA is available, logging CUDA info...")
            log_cuda_info(torch)
        # Check if XPU (Intel GPU) is available
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            log.debug("XPU is available, logging XPU info...")
            log_xpu_info(torch, ipex)
        # Log a warning if no GPU is available
        else:
            log.warning("Torch reports GPU not available")

        # Return the major version of PyTorch; split on "." so versions with
        # a multi-digit major (e.g. "10.0") are handled correctly.
        return int(torch.__version__.split(".")[0])
    except ImportError as e:
        # Log an error if PyTorch cannot be loaded
        log.error(f"Could not load torch: {e}")
        sys.exit(1)
    except Exception as e:
        # Log an unexpected error
        log.error(f"Unexpected error while checking torch: {e}")
        sys.exit(1)
def log_cuda_info(torch):
    """Log the torch backend (CUDA/ROCm) and details of every visible GPU."""
    cuda_version = torch.version.cuda
    hip_version = torch.version.hip
    # Report which backend this torch build uses.
    if cuda_version:
        cudnn_version = (
            torch.backends.cudnn.version()
            if torch.backends.cudnn.is_available()
            else "N/A"
        )
        log.info(f"Torch backend: nVidia CUDA {cuda_version} cuDNN {cudnn_version}")
    elif hip_version:
        log.info(f"Torch backend: AMD ROCm HIP {hip_version}")
    else:
        log.warning("Unknown Torch backend")
    # One log line per CUDA-visible device with its key properties.
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        vram_mb = round(props.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {props.name} VRAM {vram_mb}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
        )
def log_xpu_info(torch, ipex):
    """Log the IPEX backend version and details of every visible Intel XPU."""
    # Report the IPEX version when the extension was importable.
    if ipex:
        log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
    # One log line per XPU-visible device with its key properties.
    device_total = torch.xpu.device_count()
    for index in range(device_total):
        props = torch.xpu.get_device_properties(index)
        vram_mb = round(props.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {props.name} VRAM {vram_mb}MB Compute Units {props.max_compute_units}"
        )
def main():
    """Validate the environment and install the GUI's Python requirements."""
    # Check the repository version to ensure compatibility
    log.debug("Checking repository version...")
    setup_common.check_repo_version()

    # Check if the current path contains spaces, which are not supported
    log.debug("Checking if the current path contains spaces...")
    check_path_with_space()

    # Parse command line arguments
    log.debug("Parsing command line arguments...")
    parser = argparse.ArgumentParser(
        description="Validate that requirements are satisfied."
    )
    parser.add_argument(
        "-r", "--requirements", type=str, help="Path to the requirements file."
    )
    parser.add_argument("--debug", action="store_true", help="Debug on")
    args = parser.parse_args()

    # Update git submodules if necessary
    log.debug("Updating git submodules...")
    setup_common.update_submodule()

    # Check if PyTorch is installed and log relevant information
    log.debug("Checking if PyTorch is installed...")
    check_torch()

    # Check if the Python version is compatible
    log.debug("Checking Python version...")
    if not setup_common.check_python_version():
        sys.exit(1)

    # Install required packages from the specified requirements file
    requirements_file = args.requirements or "requirements_pytorch_windows.txt"
    log.debug(f"Installing requirements from: {requirements_file}")
    setup_common.install_requirements(requirements_file, check_no_verify_flag=True)
    log.debug("Installing additional requirements from: requirements_windows.txt")
    setup_common.install_requirements(
        "requirements_windows.txt", check_no_verify_flag=True
    )
# Entry point when this file is executed as a script (not imported).
if __name__ == "__main__":
    log.debug("Starting main function...")
    main()
    log.debug("Main function finished.")