# Mirror of https://github.com/bmaltais/kohya_ss
# (original listing: 196 lines, 7.5 KiB, Python)
import os
|
|
import sys
|
|
import shutil
|
|
import argparse
|
|
import setup_common
|
|
|
|
# Resolve the kohya_ss project directory. When this script lives in a
# sub-directory named "setup", the project root is that directory's parent;
# otherwise the script's own directory is used as the root.
_script_directory = os.path.dirname(os.path.abspath(__file__))
# Compare the directory *name* instead of searching the whole absolute path for
# the substring "setup": a substring test also matches unrelated ancestors such
# as "C:\setup_tools\kohya_ss" or "/home/me/setups/kohya_ss" and would then pick
# the wrong project root.
project_directory = (
    os.path.dirname(_script_directory)
    if os.path.basename(_script_directory) == "setup"
    else _script_directory
)

# Add the project directory to the beginning of the Python search path so the
# ``kohya_gui`` package imported below resolves from the project root, which is
# derived from this file's location rather than from the working directory.
sys.path.insert(0, project_directory)

from kohya_gui.custom_logging import setup_logging  # noqa: E402 (needs sys.path above)

# Set up logging
log = setup_logging()

log.debug(f"Project directory set to: {project_directory}")
|
|
|
|
def check_path_with_space():
    """Refuse to continue when the current working directory contains a space.

    The working directory is always logged at debug level. If it contains a
    space, three explanatory error messages are logged and an exception is
    raised. Otherwise the function returns ``None``.

    Raises:
        RuntimeError: if ``os.getcwd()`` contains a ``" "`` character.
    """
    working_dir = os.getcwd()
    log.debug(f"Current working directory: {working_dir}")

    if " " not in working_dir:
        return

    # Explain the problem (and the remedy) before aborting.
    error_lines = (
        "The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI.",
        "Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again.",
        f"The current working directory is: {working_dir}",
    )
    for line in error_lines:
        log.error(line)
    raise RuntimeError("Invalid path: contains spaces.")
|
|
|
|
def detect_toolkit():
    r"""Guess which GPU toolkit is installed and return its label.

    Probes run in a fixed priority order. The first one that succeeds wins, and
    later probes are not evaluated.

    Returns:
        str: ``"nVidia"`` if ``nvidia-smi`` is on ``PATH`` or exists under
        ``%SystemRoot%\System32`` (``C:\Windows`` when ``SystemRoot`` is unset);
        ``"AMD"`` if ``rocminfo`` is on ``PATH`` or at ``/opt/rocm/bin/rocminfo``;
        ``"Intel"`` if ``sycl-ls`` is on ``PATH``, ``ONEAPI_ROOT`` is set to a
        non-empty value, or ``/opt/intel/oneapi`` exists; otherwise ``"CPU"``.
    """
    log.debug("Detecting available toolkit...")

    def nvidia_present():
        if shutil.which("nvidia-smi"):
            return True
        # Fall back to the standard Windows install location.
        system_root = os.environ.get("SystemRoot", r"C:\Windows")
        return os.path.exists(os.path.join(system_root, "System32", "nvidia-smi.exe"))

    def amd_present():
        return shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo")

    def intel_present():
        # Any SYCL / oneAPI indicator counts.
        return (
            shutil.which("sycl-ls")
            or os.environ.get("ONEAPI_ROOT")
            or os.path.exists("/opt/intel/oneapi")
        )

    probes = (
        ("nVidia", "nVidia toolkit detected", nvidia_present),
        ("AMD", "AMD toolkit detected", amd_present),
        ("Intel", "Intel toolkit detected", intel_present),
    )
    for label, message, present in probes:
        if present():
            log.debug(message)
            return label

    log.debug("No specific GPU toolkit detected, defaulting to CPU")
    return "CPU"
|
|
|
|
def parse_major_version(version):
    """Return the integer major component of a dotted version string.

    ``"2.1.0+cu121"`` -> ``2`` and ``"10.0.0"`` -> ``10``. Taking only the first
    character of the string breaks as soon as the major version has more than
    one digit, so the text before the first ``"."`` is parsed instead.

    Raises:
        ValueError: if the text before the first ``"."`` is not an integer.
    """
    return int(str(version).split(".", 1)[0])


def check_torch():
    """Import PyTorch, log its version and accelerator details, and return its major version.

    The toolkit reported by ``detect_toolkit()`` is logged first. When that
    toolkit is ``"Intel"``, Intel Extension for PyTorch (IPEX) is also imported
    if available; a missing IPEX only produces a warning. Devices visible
    through ``torch.cuda`` are reported by ``log_cuda_info``. Otherwise, Intel
    XPU devices are reported by ``log_xpu_info``. If neither is available, a
    warning is logged.

    Returns:
        int: the major version of the installed PyTorch (e.g. ``2`` for ``2.1.0``).

    Exits:
        Calls ``sys.exit(1)`` if PyTorch cannot be imported or if any other
        exception is raised while checking it.
    """
    # Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
    toolkit = detect_toolkit()
    log.info(f"{toolkit} toolkit detected")

    try:
        # Import PyTorch
        log.debug("Importing PyTorch...")
        import torch

        ipex = None
        # Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
        if toolkit == "Intel":
            try:
                log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
                import intel_extension_for_pytorch as ipex

                log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
            except ImportError:
                log.warning("Intel Extension for PyTorch (IPEX) not found.")

        # Log the PyTorch version
        log.info(f"Torch {torch.__version__}")

        if torch.cuda.is_available():
            # log_cuda_info distinguishes CUDA and ROCm HIP backends itself.
            log.debug("CUDA is available, logging CUDA info...")
            log_cuda_info(torch)
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # Check if XPU (Intel GPU) is available
            log.debug("XPU is available, logging XPU info...")
            log_xpu_info(torch, ipex)
        else:
            # Log a warning if no GPU is available
            log.warning("Torch reports GPU not available")

        # Parse the full major component, not just the first character.
        return parse_major_version(torch.__version__)
    except ImportError as e:
        # Log an error if PyTorch cannot be loaded
        log.error(f"Could not load torch: {e}")
        sys.exit(1)
    except Exception as e:
        # Log an unexpected error
        log.error(f"Unexpected error while checking torch: {e}")
        sys.exit(1)
|
|
|
|
def log_cuda_info(torch):
    """Log the torch GPU backend, then one line per device visible to ``torch.cuda``.

    The backend line depends on which version attribute is set:

    * ``torch.version.cuda`` set: CUDA, plus the cuDNN version (``"N/A"`` when
      cuDNN is unavailable).
    * otherwise ``torch.version.hip`` set: ROCm HIP.
    * neither set: a warning is logged.

    Args:
        torch: the imported ``torch`` module.
    """
    if torch.version.cuda:
        cudnn = torch.backends.cudnn
        cudnn_version = cudnn.version() if cudnn.is_available() else "N/A"
        log.info(f"Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn_version}")
    elif torch.version.hip:
        log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
    else:
        log.warning("Unknown Torch backend")

    for index in range(torch.cuda.device_count()):
        device = torch.cuda.get_device_properties(index)
        # total_memory is converted to MiB and rounded for display.
        vram_mb = round(device.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {device.name} VRAM {vram_mb}MB "
            f"Arch {device.major}.{device.minor} Cores {device.multi_processor_count}"
        )
|
|
|
|
def log_xpu_info(torch, ipex):
    """Log the IPEX version (when provided), then one line per device visible to ``torch.xpu``.

    Args:
        torch: the imported ``torch`` module; ``torch.xpu`` must be present.
        ipex: the imported ``intel_extension_for_pytorch`` module, or ``None``
            when it was not imported (as ``check_torch`` passes it).
    """
    if ipex:
        log.info(f"Torch backend: Intel IPEX {ipex.__version__}")

    xpu = torch.xpu
    for index in range(xpu.device_count()):
        device = xpu.get_device_properties(index)
        # total_memory is converted to MiB and rounded for display.
        vram_mb = round(device.total_memory / 1024 / 1024)
        log.info(
            f"Torch detected GPU: {device.name} VRAM {vram_mb}MB "
            f"Compute Units {device.max_compute_units}"
        )
|
|
|
|
def main():
    """Validate the environment, then install the requirements for kohya_ss.

    Steps, in order:

    1. Check the repository version.
    2. Reject working directories containing spaces (``RuntimeError``
       propagates).
    3. Parse the command line.
    4. Update git submodules.
    5. Verify the Python version (``sys.exit(1)`` on failure).
    6. Check and log the PyTorch installation (``check_torch`` exits on failure).
    7. Install the packages from the requirements file given by
       ``-r``/``--requirements``, defaulting to ``requirements_pytorch_windows.txt``.
    """
    # Check the repository version to ensure compatibility
    log.debug("Checking repository version...")
    setup_common.check_repo_version()

    # Check if the current path contains spaces, which are not supported
    log.debug("Checking if the current path contains spaces...")
    check_path_with_space()

    # Parse command line arguments
    log.debug("Parsing command line arguments...")
    parser = argparse.ArgumentParser(
        description="Validate that requirements are satisfied."
    )
    parser.add_argument(
        "-r", "--requirements", type=str, help="Path to the requirements file."
    )
    # NOTE(review): --debug is accepted but never read in this function.
    parser.add_argument("--debug", action="store_true", help="Debug on")
    args = parser.parse_args()

    # Update git submodules if necessary
    log.debug("Updating git submodules...")
    setup_common.update_submodule()

    # Verify the interpreter *before* importing torch. check_torch() calls
    # sys.exit(1) on any ImportError, so checking torch first could hide an
    # unsupported Python version behind a generic "Could not load torch" error.
    log.debug("Checking Python version...")
    if not setup_common.check_python_version():
        sys.exit(1)

    # Check if PyTorch is installed and log relevant information. The returned
    # major version is not needed here.
    log.debug("Checking if PyTorch is installed...")
    check_torch()

    # Install required packages from the specified requirements file
    requirements_file = args.requirements or "requirements_pytorch_windows.txt"
    log.debug(f"Installing requirements from: {requirements_file}")
    setup_common.install_requirements_inbulk(
        requirements_file,
        show_stdout=True,
        # optional_parm="--index-url https://download.pytorch.org/whl/cu124"
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point. The closing debug message is only reached when main()
    # returns normally: main() may call sys.exit(1) or let a RuntimeError from
    # check_path_with_space() propagate.
    log.debug("Starting main function...")
    main()
    log.debug("Main function finished.")
|