lora-scripts/mikazuki/utils/devices.py

33 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from mikazuki.log import log
# Filled by check_torch_gpu(): one torch.cuda.device handle per detected GPU,
# in device-index order.
available_devices = list()
# Filled by check_torch_gpu() in the same loop as available_devices: one
# human-readable "GPU <index>: <name> (<VRAM> GB)" label per detected GPU.
printable_devices = list()
def check_torch_gpu():
    """Probe the installed torch build for usable GPUs and record them.

    Logs the torch version, the compute backend (CUDA/cuDNN or ROCm/HIP) and
    one line per detected GPU. When torch reports CUDA as available, the
    module-level lists ``available_devices`` (``torch.cuda.device`` handles)
    and ``printable_devices`` (display labels) are filled with one entry per
    GPU, in device-index order.

    Both lists are cleared in place at the start of every call. Repeated calls
    therefore do not accumulate duplicate entries, and modules that imported
    the list objects directly keep seeing the current contents.

    Any exception raised while importing or querying torch is caught and
    logged rather than propagated. In that case the lists may be left empty
    or only partially filled.
    """
    # Reset in place (not by rebinding) so references held by importers stay valid.
    available_devices.clear()
    printable_devices.clear()
    try:
        import torch

        log.info(f'Torch {torch.__version__}')
        if not torch.cuda.is_available():
            log.error("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check")
            log.error("Torch 无法使用 GPU您无法正常开始训练\n您的显卡可能并不支持,或是 torch 安装有误。请检查您的 torch 安装。")
            return

        if torch.version.cuda:
            cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"
            log.info(f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn_version}')
        elif torch.version.hip:
            log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')

        for pos in range(torch.cuda.device_count()):
            device = torch.cuda.device(pos)
            name = torch.cuda.get_device_name(device)
            # Fetch the properties once; both the VRAM total and the
            # multiprocessor count logged below come from this object.
            props = torch.cuda.get_device_properties(device)
            memory = props.total_memory
            available_devices.append(device)
            printable_devices.append(f"GPU {pos}: {name} ({round(memory / (1024**3))} GB)")
            # NOTE: the "Cores" figure is `multi_processor_count` (the
            # multiprocessor count), not a per-core count.
            log.info(
                f'Torch detected GPU: {name} VRAM {round(memory / 1024 / 1024)} '
                f'Arch {torch.cuda.get_device_capability(device)} '
                f'Cores {props.multi_processor_count}'
            )
    except Exception as e:
        # Best-effort probe: failures (including ones raised during device
        # enumeration, not only the import) are reported and swallowed.
        log.error(f'Could not load torch: {e}')