mirror of https://github.com/bmaltais/kohya_ss
Update setup code
parent
7ab6efc5ca
commit
0d27feaf01
|
|
@ -1,363 +1,300 @@
|
|||
import subprocess
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import re
|
||||
import pkg_resources
|
||||
|
||||
errors = 0 # Define the 'errors' variable before using it
|
||||
log = logging.getLogger('sd')
|
||||
log = logging.getLogger("sd")
|
||||
|
||||
# Constants
|
||||
MIN_PYTHON_VERSION = (3, 10, 9)
|
||||
MAX_PYTHON_VERSION = (3, 11, 0)
|
||||
LOG_DIR = "../logs/setup/"
|
||||
LOG_LEVEL = "INFO" # Set to "INFO" or "WARNING" for less verbose logging
|
||||
|
||||
|
||||
def check_python_version():
|
||||
"""
|
||||
Check if the current Python version is within the acceptable range.
|
||||
|
||||
Returns:
|
||||
bool: True if the current Python version is valid, False otherwise.
|
||||
bool: True if the current Python version is valid, False otherwise.
|
||||
"""
|
||||
min_version = (3, 10, 9)
|
||||
max_version = (3, 11, 0)
|
||||
|
||||
from packaging import version
|
||||
|
||||
log.debug("Checking Python version...")
|
||||
try:
|
||||
current_version = sys.version_info
|
||||
log.info(f"Python version is {sys.version}")
|
||||
|
||||
if not (min_version <= current_version < max_version):
|
||||
log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI")
|
||||
log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0")
|
||||
|
||||
if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION):
|
||||
log.error(
|
||||
f"The current version of python ({sys.version}) is not supported."
|
||||
)
|
||||
log.error("The Python version must be >= 3.10.9 and < 3.11.0.")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Failed to verify Python version. Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def update_submodule(quiet=True):
|
||||
"""
|
||||
Ensure the submodule is initialized and updated.
|
||||
|
||||
This function uses the Git command line interface to initialize and update
|
||||
the specified submodule recursively. Errors during the Git operation
|
||||
or if Git is not found are caught and logged.
|
||||
|
||||
Parameters:
|
||||
- quiet: If True, suppresses the output of the Git command.
|
||||
"""
|
||||
log.debug("Updating submodule...")
|
||||
git_command = ["git", "submodule", "update", "--init", "--recursive"]
|
||||
|
||||
if quiet:
|
||||
git_command.append("--quiet")
|
||||
|
||||
|
||||
try:
|
||||
# Initialize and update the submodule
|
||||
subprocess.run(git_command, check=True)
|
||||
log.info("Submodule initialized and updated.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Log the error if the Git operation fails
|
||||
log.error(f"Error during Git operation: {e}")
|
||||
except FileNotFoundError as e:
|
||||
# Log the error if the file is not found
|
||||
log.error(e)
|
||||
|
||||
# def read_tag_version_from_file(file_path):
|
||||
# """
|
||||
# Read the tag version from a given file.
|
||||
|
||||
# Parameters:
|
||||
# - file_path: The path to the file containing the tag version.
|
||||
|
||||
# Returns:
|
||||
# The tag version as a string.
|
||||
# """
|
||||
# with open(file_path, 'r') as file:
|
||||
# # Read the first line and strip whitespace
|
||||
# tag_version = file.readline().strip()
|
||||
# return tag_version
|
||||
|
||||
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
|
||||
"""
|
||||
Clone a repo or checkout a specific branch or tag if the repo already exists.
|
||||
For branches, it updates to the latest version before checking out.
|
||||
Suppresses detached HEAD advice for tags or specific commits.
|
||||
Restores the original working directory after operations.
|
||||
|
||||
Parameters:
|
||||
- repo_url: The URL of the Git repository.
|
||||
- branch_or_tag: The name of the branch or tag to clone or checkout.
|
||||
- directory_name: The name of the directory to clone into or where the repo already exists.
|
||||
"""
|
||||
original_dir = os.getcwd() # Store the original directory
|
||||
log.debug(
|
||||
f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
try:
|
||||
if not os.path.exists(directory_name):
|
||||
# Directory does not exist, clone the repo quietly
|
||||
|
||||
# Construct the command as a string for logging
|
||||
# run_cmd = f"git clone --branch {branch_or_tag} --single-branch --quiet {repo_url} {directory_name}"
|
||||
run_cmd = ["git", "clone", "--branch", branch_or_tag, "--single-branch", "--quiet", repo_url, directory_name]
|
||||
|
||||
|
||||
# Log the command
|
||||
log.debug(run_cmd)
|
||||
|
||||
# Run the command
|
||||
process = subprocess.Popen(
|
||||
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
output, error = process.communicate()
|
||||
|
||||
if error and not error.startswith("Note: switching to"):
|
||||
log.warning(error)
|
||||
else:
|
||||
log.info(f"Successfully cloned sd-scripts {branch_or_tag}")
|
||||
|
||||
run_cmd = [
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch_or_tag,
|
||||
"--single-branch",
|
||||
"--quiet",
|
||||
repo_url,
|
||||
directory_name,
|
||||
]
|
||||
log.debug(f"Cloning repository: {run_cmd}")
|
||||
subprocess.run(run_cmd, check=True)
|
||||
log.info(f"Successfully cloned {repo_url} ({branch_or_tag})")
|
||||
else:
|
||||
os.chdir(directory_name)
|
||||
log.debug("Fetching all branches and tags...")
|
||||
subprocess.run(["git", "fetch", "--all", "--quiet"], check=True)
|
||||
subprocess.run(["git", "config", "advice.detachedHead", "false"], check=True)
|
||||
|
||||
# Get the current branch or commit hash
|
||||
current_branch_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
tag_branch_hash = subprocess.check_output(["git", "rev-parse", branch_or_tag]).strip().decode()
|
||||
|
||||
if current_branch_hash != tag_branch_hash:
|
||||
run_cmd = f"git checkout {branch_or_tag} --quiet"
|
||||
# Log the command
|
||||
log.debug(run_cmd)
|
||||
|
||||
# Execute the checkout command
|
||||
process = subprocess.Popen(run_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
output, error = process.communicate()
|
||||
|
||||
if error:
|
||||
log.warning(error.decode())
|
||||
else:
|
||||
log.info(f"Checked out sd-scripts {branch_or_tag} successfully.")
|
||||
subprocess.run(
|
||||
["git", "config", "advice.detachedHead", "false"], check=True
|
||||
)
|
||||
|
||||
current_branch_hash = (
|
||||
subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
)
|
||||
target_branch_hash = (
|
||||
subprocess.check_output(["git", "rev-parse", branch_or_tag])
|
||||
.strip()
|
||||
.decode()
|
||||
)
|
||||
|
||||
if current_branch_hash != target_branch_hash:
|
||||
log.debug(f"Checking out branch/tag: {branch_or_tag}")
|
||||
subprocess.run(
|
||||
["git", "checkout", branch_or_tag, "--quiet"], check=True
|
||||
)
|
||||
log.info(f"Checked out {branch_or_tag} successfully.")
|
||||
else:
|
||||
log.info(f"Current branch of sd-scripts is already at the required release {branch_or_tag}.")
|
||||
log.info(f"Already at required branch/tag: {branch_or_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f"Error during Git operation: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir) # Restore the original directory
|
||||
os.chdir(original_dir)
|
||||
|
||||
# setup console and file logging
|
||||
def setup_logging(clean=False):
|
||||
#
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
#
|
||||
|
||||
def setup_logging():
|
||||
"""
|
||||
Set up logging to file and console.
|
||||
"""
|
||||
log.debug("Setting up logging...")
|
||||
|
||||
from rich.theme import Theme
|
||||
from rich.logging import RichHandler
|
||||
from rich.console import Console
|
||||
from rich.pretty import install as pretty_install
|
||||
from rich.traceback import install as traceback_install
|
||||
|
||||
console = Console(
|
||||
log_time=True,
|
||||
log_time_format='%H:%M:%S-%f',
|
||||
theme=Theme(
|
||||
{
|
||||
'traceback.border': 'black',
|
||||
'traceback.border.syntax_error': 'black',
|
||||
'inspect.value.border': 'black',
|
||||
}
|
||||
),
|
||||
log_time_format="%H:%M:%S-%f",
|
||||
theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}),
|
||||
)
|
||||
# logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
# logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
|
||||
current_datetime = datetime.datetime.now()
|
||||
current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S')
|
||||
current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
log_file = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log',
|
||||
os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log"
|
||||
)
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
|
||||
# Create directories if they don't exist
|
||||
log_directory = os.path.dirname(log_file)
|
||||
os.makedirs(log_directory, exist_ok=True)
|
||||
|
||||
level = logging.INFO
|
||||
logging.basicConfig(
|
||||
level=logging.ERROR,
|
||||
format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
|
||||
format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s",
|
||||
filename=log_file,
|
||||
filemode='a',
|
||||
encoding='utf-8',
|
||||
filemode="a",
|
||||
encoding="utf-8",
|
||||
force=True,
|
||||
)
|
||||
log.setLevel(
|
||||
logging.DEBUG
|
||||
) # log to file is always at level debug for facility `sd`
|
||||
pretty_install(console=console)
|
||||
traceback_install(
|
||||
console=console,
|
||||
extra_lines=1,
|
||||
width=console.width,
|
||||
word_wrap=False,
|
||||
indent_guides=False,
|
||||
suppress=[],
|
||||
)
|
||||
rh = RichHandler(
|
||||
show_time=True,
|
||||
omit_repeated_times=False,
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
markup=False,
|
||||
rich_tracebacks=True,
|
||||
log_time_format='%H:%M:%S-%f',
|
||||
level=level,
|
||||
console=console,
|
||||
)
|
||||
rh.set_name(level)
|
||||
while log.hasHandlers() and len(log.handlers) > 0:
|
||||
log.removeHandler(log.handlers[0])
|
||||
log.addHandler(rh)
|
||||
log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper()
|
||||
log.setLevel(getattr(logging, log_level, logging.DEBUG))
|
||||
rich_handler = RichHandler(console=console)
|
||||
|
||||
# Replace existing handlers with the rich handler
|
||||
log.handlers.clear()
|
||||
log.addHandler(rich_handler)
|
||||
log.debug("Logging setup complete.")
|
||||
|
||||
|
||||
def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False):
|
||||
def install_requirements_inbulk(
|
||||
requirements_file, show_stdout=True, optional_parm="", upgrade=False
|
||||
):
|
||||
log.debug(f"Installing requirements in bulk from: {requirements_file}")
|
||||
if not os.path.exists(requirements_file):
|
||||
log.error(f'Could not find the requirements file in {requirements_file}.')
|
||||
log.error(f"Could not find the requirements file in {requirements_file}.")
|
||||
return
|
||||
|
||||
log.info(f'Installing requirements from {requirements_file}...')
|
||||
log.info(f"Installing requirements from {requirements_file}...")
|
||||
|
||||
if upgrade:
|
||||
optional_parm += " -U"
|
||||
|
||||
if show_stdout:
|
||||
run_cmd(f'pip install -r {requirements_file} {optional_parm}')
|
||||
run_cmd(f"pip install -r {requirements_file} {optional_parm}")
|
||||
else:
|
||||
run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet')
|
||||
log.info(f'Requirements from {requirements_file} installed.')
|
||||
|
||||
run_cmd(f"pip install -r {requirements_file} {optional_parm} --quiet")
|
||||
log.info(f"Requirements from {requirements_file} installed.")
|
||||
|
||||
|
||||
def configure_accelerate(run_accelerate=False):
|
||||
#
|
||||
# This function was taken and adapted from code written by jstayco
|
||||
#
|
||||
|
||||
log.debug("Configuring accelerate...")
|
||||
from pathlib import Path
|
||||
|
||||
def env_var_exists(var_name):
|
||||
return var_name in os.environ and os.environ[var_name] != ''
|
||||
return var_name in os.environ and os.environ[var_name] != ""
|
||||
|
||||
log.info("Configuring accelerate...")
|
||||
|
||||
log.info('Configuring accelerate...')
|
||||
|
||||
source_accelerate_config_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'..',
|
||||
'config_files',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
"..",
|
||||
"config_files",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
)
|
||||
|
||||
if not os.path.exists(source_accelerate_config_file):
|
||||
log.warning(
|
||||
f"Could not find the accelerate configuration file in {source_accelerate_config_file}."
|
||||
)
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by runningthe option in the menu.'
|
||||
"Please configure accelerate manually by running the option in the menu."
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f'Source accelerate config location: {source_accelerate_config_file}'
|
||||
)
|
||||
return
|
||||
|
||||
log.debug(f"Source accelerate config location: {source_accelerate_config_file}")
|
||||
|
||||
target_config_location = None
|
||||
|
||||
log.debug(
|
||||
f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
|
||||
f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
|
||||
f"USERPROFILE: {os.environ.get('USERPROFILE')}"
|
||||
)
|
||||
if env_var_exists('HF_HOME'):
|
||||
target_config_location = Path(
|
||||
os.environ['HF_HOME'], 'accelerate', 'default_config.yaml'
|
||||
)
|
||||
elif env_var_exists('LOCALAPPDATA'):
|
||||
target_config_location = Path(
|
||||
os.environ['LOCALAPPDATA'],
|
||||
'huggingface',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
)
|
||||
elif env_var_exists('USERPROFILE'):
|
||||
target_config_location = Path(
|
||||
os.environ['USERPROFILE'],
|
||||
'.cache',
|
||||
'huggingface',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
)
|
||||
env_vars = {
|
||||
"HF_HOME": Path(os.environ.get("HF_HOME", "")),
|
||||
"LOCALAPPDATA": Path(
|
||||
os.environ.get("LOCALAPPDATA", ""),
|
||||
"huggingface",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
),
|
||||
"USERPROFILE": Path(
|
||||
os.environ.get("USERPROFILE", ""),
|
||||
".cache",
|
||||
"huggingface",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
),
|
||||
}
|
||||
|
||||
log.debug(f'Target config location: {target_config_location}')
|
||||
for var, path in env_vars.items():
|
||||
if env_var_exists(var):
|
||||
target_config_location = path
|
||||
break
|
||||
|
||||
log.debug(f"Target config location: {target_config_location}")
|
||||
|
||||
if target_config_location:
|
||||
if not target_config_location.is_file():
|
||||
log.debug(
|
||||
f"Creating target config directory: {target_config_location.parent}"
|
||||
)
|
||||
target_config_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
log.debug(
|
||||
f'Target accelerate config location: {target_config_location}'
|
||||
f"Copying config file to target location: {target_config_location}"
|
||||
)
|
||||
shutil.copyfile(
|
||||
source_accelerate_config_file, target_config_location
|
||||
)
|
||||
log.info(
|
||||
f'Copied accelerate config file to: {target_config_location}'
|
||||
)
|
||||
else:
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
else:
|
||||
log.warning(
|
||||
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
|
||||
)
|
||||
else:
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
shutil.copyfile(source_accelerate_config_file, target_config_location)
|
||||
log.info(f"Copied accelerate config file to: {target_config_location}")
|
||||
elif run_accelerate:
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
|
||||
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
|
||||
)
|
||||
elif run_accelerate:
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
|
||||
)
|
||||
|
||||
|
||||
def check_torch():
|
||||
log.debug("Checking Torch installation...")
|
||||
#
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master
|
||||
#
|
||||
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
if shutil.which("nvidia-smi") is not None or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
'System32',
|
||||
'nvidia-smi.exe',
|
||||
os.environ.get("SystemRoot") or r"C:\Windows",
|
||||
"System32",
|
||||
"nvidia-smi.exe",
|
||||
)
|
||||
):
|
||||
log.info('nVidia toolkit detected')
|
||||
elif shutil.which('rocminfo') is not None or os.path.exists(
|
||||
'/opt/rocm/bin/rocminfo'
|
||||
log.info("nVidia toolkit detected")
|
||||
elif shutil.which("rocminfo") is not None or os.path.exists(
|
||||
"/opt/rocm/bin/rocminfo"
|
||||
):
|
||||
log.info('AMD toolkit detected')
|
||||
elif (shutil.which('sycl-ls') is not None
|
||||
or os.environ.get('ONEAPI_ROOT') is not None
|
||||
or os.path.exists('/opt/intel/oneapi')):
|
||||
log.info('Intel OneAPI toolkit detected')
|
||||
log.info("AMD toolkit detected")
|
||||
elif (
|
||||
shutil.which("sycl-ls") is not None
|
||||
or os.environ.get("ONEAPI_ROOT") is not None
|
||||
or os.path.exists("/opt/intel/oneapi")
|
||||
):
|
||||
log.info("Intel OneAPI toolkit detected")
|
||||
else:
|
||||
log.info('Using CPU-only Torch')
|
||||
log.info("Using CPU-only Torch")
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
log.debug("Torch module imported successfully.")
|
||||
try:
|
||||
# Import IPEX / XPU support
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except Exception:
|
||||
pass
|
||||
log.info(f'Torch {torch.__version__}')
|
||||
|
||||
log.debug("Intel extension for PyTorch imported successfully.")
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to import intel_extension_for_pytorch: {e}")
|
||||
log.info(f"Torch {torch.__version__}")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda:
|
||||
|
|
@ -367,33 +304,33 @@ def check_torch():
|
|||
)
|
||||
elif torch.version.hip:
|
||||
# Log AMD ROCm HIP version
|
||||
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
|
||||
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
|
||||
else:
|
||||
log.warning('Unknown Torch backend')
|
||||
log.warning("Unknown Torch backend")
|
||||
|
||||
# Log information about detected GPUs
|
||||
for device in [
|
||||
torch.cuda.device(i) for i in range(torch.cuda.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}"
|
||||
)
|
||||
# Check if XPU is available
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX version
|
||||
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
|
||||
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
|
||||
for device in [
|
||||
torch.xpu.device(i) for i in range(torch.xpu.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}"
|
||||
)
|
||||
else:
|
||||
log.warning('Torch reports GPU not available')
|
||||
|
||||
log.warning("Torch reports GPU not available")
|
||||
|
||||
return int(torch.__version__[0])
|
||||
except Exception as e:
|
||||
# log.warning(f'Could not load torch: {e}')
|
||||
log.error(f"Could not load torch: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -404,17 +341,19 @@ def check_repo_version():
|
|||
in the current directory. If the file exists, it reads the release version from the file and logs it.
|
||||
If the file does not exist, it logs a debug message indicating that the release could not be read.
|
||||
"""
|
||||
if os.path.exists('.release'):
|
||||
log.debug("Checking repository version...")
|
||||
if os.path.exists(".release"):
|
||||
try:
|
||||
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
|
||||
release= file.read()
|
||||
|
||||
log.info(f'Kohya_ss GUI version: {release}')
|
||||
with open(os.path.join("./.release"), "r", encoding="utf8") as file:
|
||||
release = file.read()
|
||||
|
||||
log.info(f"Kohya_ss GUI version: {release}")
|
||||
except Exception as e:
|
||||
log.error(f'Could not read release: {e}')
|
||||
log.error(f"Could not read release: {e}")
|
||||
else:
|
||||
log.debug('Could not read release...')
|
||||
|
||||
log.debug("Could not read release...")
|
||||
|
||||
|
||||
# execute git command
|
||||
def git(arg: str, folder: str = None, ignore: bool = False):
|
||||
"""
|
||||
|
|
@ -433,22 +372,31 @@ def git(arg: str, folder: str = None, ignore: bool = False):
|
|||
If set to True, errors will not be logged.
|
||||
|
||||
Note:
|
||||
This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master
|
||||
"""
|
||||
|
||||
# git_cmd = os.environ.get('GIT', "git")
|
||||
result = subprocess.run(["git", arg], check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
|
||||
log.debug(f"Running git command: git {arg} in folder: {folder or '.'}")
|
||||
result = subprocess.run(
|
||||
["git", arg],
|
||||
check=False,
|
||||
shell=True,
|
||||
env=os.environ,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=folder or ".",
|
||||
)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
|
||||
encoding="utf8", errors="ignore"
|
||||
)
|
||||
txt = txt.strip()
|
||||
if result.returncode != 0 and not ignore:
|
||||
global errors
|
||||
errors += 1
|
||||
log.error(f'Error running git: {folder} / {arg}')
|
||||
if 'or stash them' in txt:
|
||||
log.error(f'Local changes detected: check log for details...')
|
||||
log.debug(f'Git output: {txt}')
|
||||
log.error(f"Error running git: {folder} / {arg}")
|
||||
if "or stash them" in txt:
|
||||
log.error(f"Local changes detected: check log for details...")
|
||||
log.debug(f"Git output: {txt}")
|
||||
|
||||
|
||||
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
|
||||
|
|
@ -473,32 +421,44 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool =
|
|||
Returns:
|
||||
- The output of the pip command as a string, or None if the 'show_stdout' flag is set.
|
||||
"""
|
||||
# arg = arg.replace('>=', '==')
|
||||
log.debug(f"Running pip command: {arg}")
|
||||
if not quiet:
|
||||
log.info(f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}')
|
||||
pip_cmd = [fr"{sys.executable}", "-m", "pip"] + arg.split(" ")
|
||||
log.info(
|
||||
f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}'
|
||||
)
|
||||
pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ")
|
||||
log.debug(f"Running pip: {pip_cmd}")
|
||||
if show_stdout:
|
||||
subprocess.run(pip_cmd, shell=False, check=False, env=os.environ)
|
||||
else:
|
||||
result = subprocess.run(pip_cmd, shell=False, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
result = subprocess.run(
|
||||
pip_cmd,
|
||||
shell=False,
|
||||
check=False,
|
||||
env=os.environ,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
|
||||
encoding="utf8", errors="ignore"
|
||||
)
|
||||
txt = txt.strip()
|
||||
if result.returncode != 0 and not ignore:
|
||||
global errors # pylint: disable=global-statement
|
||||
global errors # pylint: disable=global-statement
|
||||
errors += 1
|
||||
log.error(f'Error running pip: {arg}')
|
||||
log.error(f'Pip output: {txt}')
|
||||
log.error(f"Error running pip: {arg}")
|
||||
log.error(f"Pip output: {txt}")
|
||||
return txt
|
||||
|
||||
|
||||
def installed(package, friendly: str = None):
|
||||
"""
|
||||
Checks if the specified package(s) are installed with the correct version.
|
||||
This function can handle package specifications with or without version constraints,
|
||||
and can also filter out command-line options and URLs when a 'friendly' string is provided.
|
||||
|
||||
|
||||
Parameters:
|
||||
- package: A string that specifies one or more packages with optional version constraints.
|
||||
- friendly: An optional string used to provide a cleaner version of the package string
|
||||
|
|
@ -506,43 +466,39 @@ def installed(package, friendly: str = None):
|
|||
|
||||
Returns:
|
||||
- True if all specified packages are installed with the correct versions, False otherwise.
|
||||
|
||||
|
||||
Note:
|
||||
This function was adapted from code written by vladimandic.
|
||||
"""
|
||||
|
||||
log.debug(f"Checking if package is installed: {package}")
|
||||
# Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
|
||||
package = re.sub(r'\[.*?\]', '', package)
|
||||
package = re.sub(r"\[.*?\]", "", package)
|
||||
|
||||
try:
|
||||
if friendly:
|
||||
# If a 'friendly' version of the package string is provided, split it into components
|
||||
pkgs = friendly.split()
|
||||
|
||||
|
||||
# Filter out command-line options and URLs from the package specification
|
||||
pkgs = [
|
||||
p
|
||||
for p in package.split()
|
||||
if not p.startswith('--') and "://" not in p
|
||||
p for p in package.split() if not p.startswith("--") and "://" not in p
|
||||
]
|
||||
else:
|
||||
# Split the package string into components, excluding '-' and '=' prefixed items
|
||||
pkgs = [
|
||||
p
|
||||
for p in package.split()
|
||||
if not p.startswith('-') and not p.startswith('=')
|
||||
if not p.startswith("-") and not p.startswith("=")
|
||||
]
|
||||
# For each package component, extract the package name, excluding any URLs
|
||||
pkgs = [
|
||||
p.split('/')[-1] for p in pkgs
|
||||
]
|
||||
pkgs = [p.split("/")[-1] for p in pkgs]
|
||||
|
||||
for pkg in pkgs:
|
||||
# Parse the package name and version based on the version specifier used
|
||||
if '>=' in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')]
|
||||
elif '==' in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split('==')]
|
||||
if ">=" in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")]
|
||||
elif "==" in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split("==")]
|
||||
else:
|
||||
pkg_name, pkg_version = pkg.strip(), None
|
||||
|
||||
|
|
@ -553,38 +509,41 @@ def installed(package, friendly: str = None):
|
|||
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
|
||||
if spec is None:
|
||||
# Try replacing underscores with dashes
|
||||
spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None)
|
||||
spec = pkg_resources.working_set.by_key.get(
|
||||
pkg_name.replace("_", "-"), None
|
||||
)
|
||||
|
||||
if spec is not None:
|
||||
# Package is found, check version
|
||||
version = pkg_resources.get_distribution(pkg_name).version
|
||||
log.debug(f'Package version found: {pkg_name} {version}')
|
||||
log.debug(f"Package version found: {pkg_name} {version}")
|
||||
|
||||
if pkg_version is not None:
|
||||
# Verify if the installed version meets the specified constraints
|
||||
if '>=' in pkg:
|
||||
if ">=" in pkg:
|
||||
ok = version >= pkg_version
|
||||
else:
|
||||
ok = version == pkg_version
|
||||
|
||||
if not ok:
|
||||
# Version mismatch, log warning and return False
|
||||
log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
|
||||
log.warning(
|
||||
f"Package wrong version: {pkg_name} {version} required {pkg_version}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# Package not found, log debug message and return False
|
||||
log.debug(f'Package version not found: {pkg_name}')
|
||||
log.debug(f"Package version not found: {pkg_name}")
|
||||
return False
|
||||
|
||||
# All specified packages are installed with the correct versions
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
# One or more packages are not installed, log debug message and return False
|
||||
log.debug(f'Package not installed: {pkgs}')
|
||||
log.debug(f"Package not installed: {pkgs}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# install package using pip if not already installed
|
||||
def install(
|
||||
package,
|
||||
|
|
@ -596,7 +555,7 @@ def install(
|
|||
"""
|
||||
Installs or upgrades a Python package using pip, with options to ignode errors,
|
||||
reinstall packages, and display outputs.
|
||||
|
||||
|
||||
Parameters:
|
||||
- package (str): The name of the package to be installed or upgraded. Can include
|
||||
version specifiers. Anything after a '#' in the package name will be ignored.
|
||||
|
|
@ -612,103 +571,98 @@ def install(
|
|||
Returns:
|
||||
None. The function performs operations that affect the environment but does not return
|
||||
any value.
|
||||
|
||||
|
||||
Note:
|
||||
If `reinstall` is True, it disables any mechanism that allows for skipping installations
|
||||
when the package is already present, forcing a fresh install.
|
||||
"""
|
||||
log.debug(f"Installing package: {package}")
|
||||
# Remove anything after '#' in the package variable
|
||||
package = package.split('#')[0].strip()
|
||||
package = package.split("#")[0].strip()
|
||||
|
||||
if reinstall:
|
||||
global quick_allowed # pylint: disable=global-statement
|
||||
global quick_allowed # pylint: disable=global-statement
|
||||
quick_allowed = False
|
||||
if reinstall or not installed(package, friendly):
|
||||
pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout)
|
||||
pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def process_requirements_line(line, show_stdout: bool = False):
|
||||
log.debug(f"Processing requirements line: {line}")
|
||||
# Remove brackets and their contents from the line using regular expressions
|
||||
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
|
||||
package_name = re.sub(r'\[.*?\]', '', line)
|
||||
package_name = re.sub(r"\[.*?\]", "", line)
|
||||
install(line, package_name, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
|
||||
if check_no_verify_flag:
|
||||
log.info(f'Verifying modules installation status from {requirements_file}...')
|
||||
else:
|
||||
log.info(f'Installing modules from {requirements_file}...')
|
||||
with open(requirements_file, 'r', encoding='utf8') as f:
|
||||
# Read lines from the requirements file, strip whitespace, and filter out empty lines, comments, and lines starting with '.'
|
||||
if check_no_verify_flag:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() != ''
|
||||
and not line.startswith('#')
|
||||
and line is not None
|
||||
and 'no_verify' not in line
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() != ''
|
||||
and not line.startswith('#')
|
||||
and line is not None
|
||||
]
|
||||
def install_requirements(
|
||||
requirements_file, check_no_verify_flag=False, show_stdout: bool = False
|
||||
):
|
||||
"""
|
||||
Install or verify modules from a requirements file.
|
||||
|
||||
# Iterate over each line and install the requirements
|
||||
for line in lines:
|
||||
# Check if the line starts with '-r' to include another requirements file
|
||||
if line.startswith('-r'):
|
||||
# Get the path to the included requirements file
|
||||
included_file = line[2:].strip()
|
||||
# Expand the included requirements file recursively
|
||||
install_requirements(included_file, check_no_verify_flag=check_no_verify_flag, show_stdout=show_stdout)
|
||||
else:
|
||||
process_requirements_line(line, show_stdout=show_stdout)
|
||||
Parameters:
|
||||
- requirements_file (str): Path to the requirements file.
|
||||
- check_no_verify_flag (bool): If True, verify modules installation status without installing.
|
||||
- show_stdout (bool): If True, show the standard output of the installation process.
|
||||
"""
|
||||
log.debug(f"Installing requirements from file: {requirements_file}")
|
||||
action = "Verifying" if check_no_verify_flag else "Installing"
|
||||
log.info(f"{action} modules from {requirements_file}...")
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as f:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() and not line.startswith("#") and "no_verify" not in line
|
||||
]
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("-r"):
|
||||
included_file = line[2:].strip()
|
||||
log.debug(f"Processing included requirements file: {included_file}")
|
||||
install_requirements(
|
||||
included_file,
|
||||
check_no_verify_flag=check_no_verify_flag,
|
||||
show_stdout=show_stdout,
|
||||
)
|
||||
else:
|
||||
process_requirements_line(line, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def ensure_base_requirements():
|
||||
try:
|
||||
import rich # pylint: disable=unused-import
|
||||
import rich # pylint: disable=unused-import
|
||||
except ImportError:
|
||||
install('--upgrade rich', 'rich')
|
||||
|
||||
install("--upgrade rich", "rich")
|
||||
|
||||
try:
|
||||
import packaging
|
||||
except ImportError:
|
||||
install('packaging')
|
||||
install("packaging")
|
||||
|
||||
|
||||
def run_cmd(run_cmd):
|
||||
"""
|
||||
Execute a command using subprocess.
|
||||
"""
|
||||
log.debug(f"Running command: {run_cmd}")
|
||||
try:
|
||||
subprocess.run(run_cmd, shell=True, check=False, env=os.environ)
|
||||
subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
|
||||
log.info(f"Command executed successfully: {run_cmd}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f'Error occurred while running command: {run_cmd}')
|
||||
log.error(f'Error: {e}')
|
||||
|
||||
|
||||
def delete_file(file_path):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
def write_to_file(file_path, content):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
file.write(content)
|
||||
except IOError as e:
|
||||
print(f'Error occurred while writing to file: {file_path}')
|
||||
print(f'Error: {e}')
|
||||
log.error(f"Error occurred while running command: {run_cmd}")
|
||||
log.error(f"Error: {e}")
|
||||
|
||||
|
||||
def clear_screen():
|
||||
# Check the current operating system to execute the correct clear screen command
|
||||
if os.name == 'nt': # If the operating system is Windows
|
||||
os.system('cls')
|
||||
else: # If the operating system is Linux or Mac
|
||||
os.system('clear')
|
||||
|
||||
"""
|
||||
Clear the terminal screen.
|
||||
"""
|
||||
log.debug("Attempting to clear the terminal screen")
|
||||
try:
|
||||
os.system("cls" if os.name == "nt" else "clear")
|
||||
log.info("Terminal screen cleared successfully")
|
||||
except Exception as e:
|
||||
log.error("Error occurred while clearing the terminal screen")
|
||||
log.error(f"Error: {e}")
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@ import argparse
|
|||
import setup_common
|
||||
|
||||
# Get the absolute path of the current file's directory (Kohua_SS project directory)
|
||||
project_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Check if the "setup" directory is present in the project_directory
|
||||
if "setup" in project_directory:
|
||||
# If the "setup" directory is present, move one level up to the parent directory
|
||||
project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
project_directory = (
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if "setup" in os.path.dirname(os.path.abspath(__file__))
|
||||
else os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
|
||||
# Add the project directory to the beginning of the Python search path
|
||||
sys.path.insert(0, project_directory)
|
||||
|
|
@ -19,115 +18,172 @@ from kohya_gui.custom_logging import setup_logging
|
|||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
log.debug(f"Project directory set to: {project_directory}")
|
||||
|
||||
def check_path_with_space():
|
||||
# Get the current working directory
|
||||
"""Check if the current working directory contains a space."""
|
||||
cwd = os.getcwd()
|
||||
|
||||
# Check if the current working directory contains a space
|
||||
log.debug(f"Current working directory: {cwd}")
|
||||
if " " in cwd:
|
||||
log.error("The path in which this python code is executed contain one or many spaces. This is not supported for running kohya_ss GUI.")
|
||||
log.error("Please move the repo to a path without spaces, delete the venv folder and run setup.sh again.")
|
||||
log.error("The current working directory is: " + cwd)
|
||||
exit(1)
|
||||
# Log an error if the current working directory contains spaces
|
||||
log.error(
|
||||
"The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI."
|
||||
)
|
||||
log.error(
|
||||
"Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again."
|
||||
)
|
||||
log.error(f"The current working directory is: {cwd}")
|
||||
raise RuntimeError("Invalid path: contains spaces.")
|
||||
|
||||
def check_torch():
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
def detect_toolkit():
|
||||
"""Detect the available toolkit (NVIDIA, AMD, or Intel) and log the information."""
|
||||
log.debug("Detecting available toolkit...")
|
||||
# Check for NVIDIA toolkit by looking for nvidia-smi executable
|
||||
if shutil.which("nvidia-smi") or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
'System32',
|
||||
'nvidia-smi.exe',
|
||||
os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
|
||||
)
|
||||
):
|
||||
log.info('nVidia toolkit detected')
|
||||
elif shutil.which('rocminfo') is not None or os.path.exists(
|
||||
'/opt/rocm/bin/rocminfo'
|
||||
log.debug("nVidia toolkit detected")
|
||||
return "nVidia"
|
||||
# Check for AMD toolkit by looking for rocminfo executable
|
||||
elif shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
|
||||
log.debug("AMD toolkit detected")
|
||||
return "AMD"
|
||||
# Check for Intel toolkit by looking for SYCL or OneAPI indicators
|
||||
elif (
|
||||
shutil.which("sycl-ls")
|
||||
or os.environ.get("ONEAPI_ROOT")
|
||||
or os.path.exists("/opt/intel/oneapi")
|
||||
):
|
||||
log.info('AMD toolkit detected')
|
||||
elif (shutil.which('sycl-ls') is not None
|
||||
or os.environ.get('ONEAPI_ROOT') is not None
|
||||
or os.path.exists('/opt/intel/oneapi')):
|
||||
log.info('Intel OneAPI toolkit detected')
|
||||
log.debug("Intel toolkit detected")
|
||||
return "Intel"
|
||||
# Default to CPU if no toolkit is detected
|
||||
else:
|
||||
log.info('Using CPU-only Torch')
|
||||
log.debug("No specific GPU toolkit detected, defaulting to CPU")
|
||||
return "CPU"
|
||||
|
||||
def check_torch():
|
||||
"""Check if torch is available and log the relevant information."""
|
||||
# Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
|
||||
toolkit = detect_toolkit()
|
||||
log.info(f"{toolkit} toolkit detected")
|
||||
|
||||
try:
|
||||
# Import PyTorch
|
||||
log.debug("Importing PyTorch...")
|
||||
import torch
|
||||
try:
|
||||
# Import IPEX / XPU support
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except Exception:
|
||||
pass
|
||||
log.info(f'Torch {torch.__version__}')
|
||||
|
||||
ipex = None
|
||||
# Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
|
||||
if toolkit == "Intel":
|
||||
try:
|
||||
log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
|
||||
import intel_extension_for_pytorch as ipex
|
||||
log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
|
||||
except ImportError:
|
||||
log.warning("Intel Extension for PyTorch (IPEX) not found.")
|
||||
|
||||
# Log the PyTorch version
|
||||
log.info(f"Torch {torch.__version__}")
|
||||
|
||||
# Check if CUDA (NVIDIA GPU) is available
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda:
|
||||
# Log nVidia CUDA and cuDNN versions
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
elif torch.version.hip:
|
||||
# Log AMD ROCm HIP version
|
||||
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
|
||||
else:
|
||||
log.warning('Unknown Torch backend')
|
||||
|
||||
# Log information about detected GPUs
|
||||
for device in [
|
||||
torch.cuda.device(i) for i in range(torch.cuda.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
)
|
||||
# Check if XPU is available
|
||||
log.debug("CUDA is available, logging CUDA info...")
|
||||
log_cuda_info(torch)
|
||||
# Check if XPU (Intel GPU) is available
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX version
|
||||
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
|
||||
for device in [
|
||||
torch.xpu.device(i) for i in range(torch.xpu.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
)
|
||||
log.debug("XPU is available, logging XPU info...")
|
||||
log_xpu_info(torch, ipex)
|
||||
# Log a warning if no GPU is available
|
||||
else:
|
||||
log.warning('Torch reports GPU not available')
|
||||
|
||||
log.warning("Torch reports GPU not available")
|
||||
|
||||
# Return the major version of PyTorch
|
||||
return int(torch.__version__[0])
|
||||
except Exception as e:
|
||||
log.error(f'Could not load torch: {e}')
|
||||
except ImportError as e:
|
||||
# Log an error if PyTorch cannot be loaded
|
||||
log.error(f"Could not load torch: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
# Log an unexpected error
|
||||
log.error(f"Unexpected error while checking torch: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def log_cuda_info(torch):
|
||||
"""Log information about CUDA-enabled GPUs."""
|
||||
# Log the CUDA and cuDNN versions if available
|
||||
if torch.version.cuda:
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
# Log the ROCm HIP version if using AMD GPU
|
||||
elif torch.version.hip:
|
||||
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
|
||||
else:
|
||||
log.warning("Unknown Torch backend")
|
||||
|
||||
# Log information about each detected CUDA-enabled GPU
|
||||
for device in range(torch.cuda.device_count()):
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
log.info(
|
||||
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
|
||||
)
|
||||
|
||||
def log_xpu_info(torch, ipex):
|
||||
"""Log information about Intel XPU-enabled GPUs."""
|
||||
# Log the Intel Extension for PyTorch (IPEX) version if available
|
||||
if ipex:
|
||||
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
|
||||
# Log information about each detected XPU-enabled GPU
|
||||
for device in range(torch.xpu.device_count()):
|
||||
props = torch.xpu.get_device_properties(device)
|
||||
log.info(
|
||||
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Compute Units {props.max_compute_units}"
|
||||
)
|
||||
|
||||
def main():
|
||||
# Check the repository version to ensure compatibility
|
||||
log.debug("Checking repository version...")
|
||||
setup_common.check_repo_version()
|
||||
|
||||
# Check if the current path contains spaces, which are not supported
|
||||
log.debug("Checking if the current path contains spaces...")
|
||||
check_path_with_space()
|
||||
|
||||
|
||||
# Parse command line arguments
|
||||
log.debug("Parsing command line arguments...")
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Validate that requirements are satisfied.'
|
||||
description="Validate that requirements are satisfied."
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
'--requirements',
|
||||
type=str,
|
||||
help='Path to the requirements file.',
|
||||
"-r", "--requirements", type=str, help="Path to the requirements file."
|
||||
)
|
||||
parser.add_argument('--debug', action='store_true', help='Debug on')
|
||||
parser.add_argument("--debug", action="store_true", help="Debug on")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Update git submodules if necessary
|
||||
log.debug("Updating git submodules...")
|
||||
setup_common.update_submodule()
|
||||
|
||||
# Check if PyTorch is installed and log relevant information
|
||||
log.debug("Checking if PyTorch is installed...")
|
||||
torch_ver = check_torch()
|
||||
|
||||
if not setup_common.check_python_version():
|
||||
exit(1)
|
||||
|
||||
if args.requirements:
|
||||
setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
|
||||
else:
|
||||
setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True)
|
||||
setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Check if the Python version is compatible
|
||||
log.debug("Checking Python version...")
|
||||
if not setup_common.check_python_version():
|
||||
sys.exit(1)
|
||||
|
||||
# Install required packages from the specified requirements file
|
||||
requirements_file = args.requirements or "requirements_pytorch_windows.txt"
|
||||
log.debug(f"Installing requirements from: {requirements_file}")
|
||||
setup_common.install_requirements(requirements_file, check_no_verify_flag=True)
|
||||
log.debug("Installing additional requirements from: requirements_windows.txt")
|
||||
setup_common.install_requirements(
|
||||
"requirements_windows.txt", check_no_verify_flag=True
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.debug("Starting main function...")
|
||||
main()
|
||||
log.debug("Main function finished.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue