lora-scripts/mikazuki/utils/devices.py

42 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from mikazuki.log import log
from packaging.version import Version
# GPUs discovered by check_torch_gpu(): torch.cuda.device handles for each visible GPU.
available_devices = []
# Human-readable labels ("GPU <i>: <name> (<N> GB)"), index-aligned with available_devices.
printable_devices = []
def check_torch_gpu():
    """Probe the installed torch build and enumerate usable GPUs.

    Logs the torch version and backend (CUDA or ROCm), emits errors for
    CPU-only builds and warnings for versions older than 2.3.0, and fills
    the module-level ``available_devices`` / ``printable_devices`` lists
    with one entry per detected GPU.

    Never raises: any failure while importing or probing torch is caught
    and logged so environment checks stay best-effort.
    """
    try:
        import torch
        log.info(f'Torch {torch.__version__}')

        if not torch.cuda.is_available():
            log.error("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check")
            log.error("Torch 无法使用 GPU您无法正常开始训练\n您的显卡可能并不支持,或是 torch 安装有误。请检查您的 torch 安装。")

        # CPU-only wheels embed "cpu" in the version string (e.g. "2.3.0+cpu");
        # training is impossible in that case, so stop the check early.
        if "cpu" in torch.__version__:
            log.error("You are using torch CPU, please install torch GPU version by run install script again.")
            log.error("!!!您正在使用 CPU 版本的 torch无法正常开始训练。请重新运行安装脚本")
            return

        if Version(torch.__version__) < Version("2.3.0"):
            log.warning("Torch version is lower than 2.3.0, which may not be able to train FLUX model properly. Please re-run the installation script (install.ps1 or install.bash) to upgrade Torch.")
            log.warning("Torch 版本低于 2.3.0,将无法正常训练 FLUX 模型。请考虑重新运行安装脚本以升级 Torch")
            log.warning("!!!若您正在使用训练包,请直接下载最新训练包!!!")

        # torch.version.cuda / torch.version.hip are None on builds without
        # the respective backend, so these branches identify the backend.
        if torch.version.cuda:
            log.info(
                f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}')
        elif torch.version.hip:
            log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')

        devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
        for pos, device in enumerate(devices):
            name = torch.cuda.get_device_name(device)
            # Fetch device properties once per GPU (original queried twice).
            props = torch.cuda.get_device_properties(device)
            memory = props.total_memory
            available_devices.append(device)
            printable_devices.append(f"GPU {pos}: {name} ({round(memory / (1024**3))} GB)")
            log.info(
                f'Torch detected GPU: {name} VRAM {round(memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {props.multi_processor_count}')
    except Exception as e:
        log.error(f'Could not load torch: {e}')