kohya_ss/setup/validate_requirements.py

196 lines
7.5 KiB
Python

import os
import sys
import shutil
import argparse
import setup_common
# Resolve the Kohya_SS project directory. This script normally lives in the
# project's "setup" sub-directory, in which case the project root is the parent
# of the script's directory; otherwise the script's own directory is used.
# Only the immediate directory name is inspected: a substring test against the
# whole absolute path would also match unrelated ancestors (for example
# "/home/me/setups/...") and wrongly select the parent directory.
_script_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = (
    os.path.dirname(_script_directory)
    if os.path.basename(_script_directory) == "setup"
    else _script_directory
)

# Put the project directory first on the Python search path so that imports
# performed after this point look there before anywhere else.
sys.path.insert(0, project_directory)
# This import is deliberately placed after the sys.path modification above so
# that the kohya_gui package can be resolved from the project directory.
from kohya_gui.custom_logging import setup_logging
# Create the module-level logger used by every function in this script.
log = setup_logging()
# Record the resolved project root to help diagnose path problems.
log.debug(f"Project directory set to: {project_directory}")
def check_path_with_space():
    """Fail when the current working directory contains a space character.

    The working directory is logged at debug level. If it contains a space,
    an explanation, remediation advice and the offending path are logged at
    error level before the function raises.

    Raises:
        RuntimeError: If the current working directory contains a space.
    """
    working_dir = os.getcwd()
    log.debug(f"Current working directory: {working_dir}")

    # Guard clause: a space-free path needs no further handling.
    if " " not in working_dir:
        return

    error_lines = (
        "The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI.",
        "Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again.",
        f"The current working directory is: {working_dir}",
    )
    for line in error_lines:
        log.error(line)

    raise RuntimeError("Invalid path: contains spaces.")
def detect_toolkit():
    """Identify which GPU vendor toolkit appears to be installed.

    Probes run in a fixed order (NVIDIA, then AMD, then Intel) and the first
    one that matches decides the result; later probes are not evaluated.

    Returns:
        str: "nVidia", "AMD" or "Intel" for a detected toolkit, otherwise
        "CPU".
    """
    log.debug("Detecting available toolkit...")

    # NVIDIA: nvidia-smi reachable through PATH, or present in the Windows
    # System32 directory (SystemRoot falls back to C:\Windows when unset).
    windows_nvidia_smi = os.path.join(
        os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
    )
    if shutil.which("nvidia-smi") or os.path.exists(windows_nvidia_smi):
        log.debug("nVidia toolkit detected")
        return "nVidia"

    # AMD: rocminfo reachable through PATH, or at the /opt/rocm location.
    if shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
        log.debug("AMD toolkit detected")
        return "AMD"

    # Intel: the sycl-ls tool, a set ONEAPI_ROOT variable, or the
    # /opt/intel/oneapi directory.
    intel_found = (
        shutil.which("sycl-ls")
        or os.environ.get("ONEAPI_ROOT")
        or os.path.exists("/opt/intel/oneapi")
    )
    if intel_found:
        log.debug("Intel toolkit detected")
        return "Intel"

    # No probe matched.
    log.debug("No specific GPU toolkit detected, defaulting to CPU")
    return "CPU"
def check_torch():
    """Verify that PyTorch can be imported and log details about its backend.

    The toolkit reported by ``detect_toolkit`` is logged first. PyTorch is then
    imported (together with the Intel Extension for PyTorch when the Intel
    toolkit was detected), its version is logged, and GPU details are logged
    through ``log_cuda_info`` or ``log_xpu_info`` depending on which device
    type torch reports as available.

    Returns:
        int: The major version number of the installed PyTorch.

    Raises:
        SystemExit: With status 1 when PyTorch cannot be imported or when any
            other unexpected error occurs during the checks.
    """
    # Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
    toolkit = detect_toolkit()
    log.info(f"{toolkit} toolkit detected")

    try:
        log.debug("Importing PyTorch...")
        import torch

        ipex = None
        # IPEX is optional: a missing package only produces a warning.
        if toolkit == "Intel":
            try:
                log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
                import intel_extension_for_pytorch as ipex

                log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
            except ImportError:
                log.warning("Intel Extension for PyTorch (IPEX) not found.")

        log.info(f"Torch {torch.__version__}")

        if torch.cuda.is_available():
            # torch.cuda covers both CUDA and ROCm builds; log_cuda_info
            # distinguishes them via torch.version.
            log.debug("CUDA is available, logging CUDA info...")
            log_cuda_info(torch)
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            log.debug("XPU is available, logging XPU info...")
            log_xpu_info(torch, ipex)
        else:
            log.warning("Torch reports GPU not available")

        # Parse the whole first dot-separated component rather than only the
        # first character, so a multi-digit major version (e.g. "10.1.0") is
        # not truncated to its leading digit.
        return int(torch.__version__.split(".")[0])
    except ImportError as e:
        # PyTorch itself could not be imported.
        log.error(f"Could not load torch: {e}")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary for this check: report and abort.
        log.error(f"Unexpected error while checking torch: {e}")
        sys.exit(1)
def log_cuda_info(torch):
    """Log the torch backend and every GPU visible through ``torch.cuda``.

    Args:
        torch: The imported ``torch`` module.
    """
    # Backend line: CUDA (with cuDNN version when available), else ROCm HIP,
    # else a warning that the backend is unknown.
    if torch.version.cuda:
        cudnn = torch.backends.cudnn
        cudnn_version = cudnn.version() if cudnn.is_available() else "N/A"
        log.info(
            f"Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn_version}"
        )
    elif torch.version.hip:
        log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
    else:
        log.warning("Unknown Torch backend")

    # One line per device: name, memory in MB, compute capability and
    # multiprocessor count.
    for index in range(torch.cuda.device_count()):
        gpu = torch.cuda.get_device_properties(index)
        vram_mb = round(gpu.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {gpu.name} VRAM {vram_mb}MB Arch {gpu.major}.{gpu.minor} Cores {gpu.multi_processor_count}"
        )
def log_xpu_info(torch, ipex):
    """Log the IPEX version (if any) and every GPU visible through ``torch.xpu``.

    Args:
        torch: The imported ``torch`` module.
        ipex: The imported Intel Extension for PyTorch module, or a falsy
            value (such as ``None``) when it is not available.
    """
    # The backend line is only emitted when IPEX was actually imported.
    if ipex:
        log.info(f"Torch backend: Intel IPEX {ipex.__version__}")

    # One line per device: name, memory in MB and compute-unit count.
    for index in range(torch.xpu.device_count()):
        gpu = torch.xpu.get_device_properties(index)
        vram_mb = round(gpu.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {gpu.name} VRAM {vram_mb}MB Compute Units {gpu.max_compute_units}"
        )
def main():
    """Run the validation steps in order, then install the requirements.

    Sequence: repository version check, working-directory check, command-line
    parsing, git submodule update, PyTorch check, Python version check and
    finally a bulk install of the requirements file. The file defaults to
    "requirements_pytorch_windows.txt" when -r/--requirements is not given.

    The process exits with status 1 when ``setup_common.check_python_version``
    returns a falsy value.
    """
    log.debug("Checking repository version...")
    setup_common.check_repo_version()

    # Paths containing spaces are rejected (check_path_with_space raises).
    log.debug("Checking if the current path contains spaces...")
    check_path_with_space()

    log.debug("Parsing command line arguments...")
    cli = argparse.ArgumentParser(
        description="Validate that requirements are satisfied."
    )
    cli.add_argument(
        "-r", "--requirements", type=str, help="Path to the requirements file."
    )
    # NOTE: --debug is accepted here but its value is not read in this function.
    cli.add_argument("--debug", action="store_true", help="Debug on")
    options = cli.parse_args()

    log.debug("Updating git submodules...")
    setup_common.update_submodule()

    # Called for its logging and exit-on-failure behaviour; the returned major
    # version is not needed here.
    log.debug("Checking if PyTorch is installed...")
    check_torch()

    log.debug("Checking Python version...")
    python_ok = setup_common.check_python_version()
    if not python_ok:
        sys.exit(1)

    if options.requirements:
        requirements_file = options.requirements
    else:
        requirements_file = "requirements_pytorch_windows.txt"
    log.debug(f"Installing requirements from: {requirements_file}")
    setup_common.install_requirements_inbulk(requirements_file, show_stdout=True)
# Script entry point: run the validation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    log.debug("Starting main function...")
    main()
    log.debug("Main function finished.")