mirror of https://github.com/bmaltais/kohya_ss
Autodetect for ROCm
parent
55a11895ae
commit
082adb9e65
2
gui.sh
2
gui.sh
|
|
@ -74,7 +74,7 @@ else
|
|||
if [ "$RUNPOD" = false ]; then
|
||||
if [[ "$@" == *"--use-ipex"* ]]; then
|
||||
REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt"
|
||||
elif [[ "$@" == *"--use-rocm"* ]]; then
|
||||
elif [[ "$@" == *"--use-rocm"* ]] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
|
||||
REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_rocm.txt"
|
||||
else
|
||||
REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt"
|
||||
|
|
|
|||
2
setup.sh
2
setup.sh
|
|
@ -209,7 +209,7 @@ install_python_dependencies() {
|
|||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
|
||||
elif [ "$USE_IPEX" = true ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
|
||||
elif [ "$USE_ROCM" = true ]; then
|
||||
elif [ "$USE_ROCM" = true ] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt
|
||||
else
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
|
||||
|
|
|
|||
|
|
@ -330,7 +330,7 @@ def check_torch():
|
|||
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
#
|
||||
|
||||
# Check for nVidia toolkit or AMD toolkit
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
|
|
@ -353,29 +353,18 @@ def check_torch():
|
|||
try:
|
||||
import torch
|
||||
try:
|
||||
# Import IPEX / XPU support
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
os.environ.setdefault('NEOReadDebugKeys', '1')
|
||||
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
|
||||
except Exception:
|
||||
pass
|
||||
log.info(f'Torch {torch.__version__}')
|
||||
|
||||
# Check if CUDA is available
|
||||
if not torch.cuda.is_available():
|
||||
log.warning('Torch reports CUDA not available')
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda:
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX OneAPI version
|
||||
log.info(f'Torch backend: Intel IPEX OneAPI {ipex.__version__}')
|
||||
else:
|
||||
# Log nVidia CUDA and cuDNN versions
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
# Log nVidia CUDA and cuDNN versions
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
elif torch.version.hip:
|
||||
# Log AMD ROCm HIP version
|
||||
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
|
||||
|
|
@ -386,15 +375,23 @@ def check_torch():
|
|||
for device in [
|
||||
torch.cuda.device(i) for i in range(torch.cuda.device_count())
|
||||
]:
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
)
|
||||
return int(torch.__version__[0])
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
)
|
||||
# Check if XPU is available
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX version
|
||||
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
|
||||
for device in [
|
||||
torch.xpu.device(i) for i in range(torch.xpu.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
)
|
||||
else:
|
||||
log.warning('Torch reports GPU not available')
|
||||
|
||||
return int(torch.__version__[0])
|
||||
except Exception as e:
|
||||
# log.warning(f'Could not load torch: {e}')
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from kohya_gui.custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
def check_torch():
|
||||
# Check for nVidia toolkit or AMD toolkit
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
|
|
|
|||
Loading…
Reference in New Issue