33 lines
1.6 KiB
Python
33 lines
1.6 KiB
Python
from mikazuki.log import log
|
||
|
||
# GPU handles detected by check_torch_gpu(); consumed elsewhere in the app
# to select a training device.
available_devices = []

# Human-readable descriptions ("GPU <i>: <name> (<N> GB)"), index-aligned
# with available_devices, for display in the UI/logs.
printable_devices = []
def check_torch_gpu():
    """Probe torch for usable GPUs and populate the module-level device lists.

    Logs the torch version and backend (nVidia CUDA or AMD ROCm/HIP), then
    records every visible CUDA device into ``available_devices`` and a
    display string into ``printable_devices``. Both lists are reset on each
    call so repeated checks do not accumulate duplicate entries.

    Never raises: any failure (torch missing, broken install, driver error)
    is caught and logged instead.
    """
    try:
        # Imported lazily so this module can be loaded without torch installed.
        import torch

        log.info(f'Torch {torch.__version__}')

        if not torch.cuda.is_available():
            log.error("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check")
            log.error("!!!Torch 无法使用 GPU,您无法正常开始训练!!!\n您的显卡可能并不支持,或是 torch 安装有误。请检查您的 torch 安装。")
            return

        # Report which acceleration backend this torch build was compiled for.
        if torch.version.cuda:
            log.info(
                f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}')
        elif torch.version.hip:
            log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')

        # Reset module-level state: without this, a second call would append
        # every GPU again and the lists would contain duplicates.
        available_devices.clear()
        printable_devices.clear()

        devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]

        for pos, device in enumerate(devices):
            name = torch.cuda.get_device_name(device)
            # Query properties once per device (total_memory is in bytes).
            props = torch.cuda.get_device_properties(device)
            memory = props.total_memory
            available_devices.append(device)
            printable_devices.append(f"GPU {pos}: {name} ({round(memory / (1024**3))} GB)")
            log.info(
                f'Torch detected GPU: {name} VRAM {round(memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {props.multi_processor_count}')
    except Exception as e:
        log.error(f'Could not load torch: {e}')