automatic/modules/dml/__init__.py

49 lines
1.6 KiB
Python

# pylint: disable=no-member,no-self-argument,no-method-argument
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.hijack
import modules.dml.amp as amp
from .optimizer.unknown import UnknownOptimizer
class DirectML:
    """``torch.cuda``-style helper namespace for DirectML (``privateuseone``) devices.

    Every member is a static method operating on class-level state, so the
    class works as a namespace: ``DirectML.memory_stats(dev)`` and
    ``DirectML().memory_stats(dev)`` behave the same.
    """

    # Shared, class-level autocast state read/written by the accessors below.
    _is_autocast_enabled = False
    _autocast_dtype = torch.float16

    @staticmethod
    def get_optimizer(device: torch.device):
        """Return the vendor-specific optimizer class for a DirectML device.

        The vendor is guessed by substring-matching the adapter name reported
        by ``torch_directml.device_name``. Vendor modules are imported lazily,
        so only the matching one is loaded.

        Args:
            device: a torch device whose ``type`` is ``'privateuseone'``.

        Returns:
            An optimizer class (not an instance). ``UnknownOptimizer`` is
            returned when the name matches no known vendor, or when the name
            lookup or vendor import raises any exception.

        Raises:
            AssertionError: if ``device`` is not a ``privateuseone`` device
                (only while assertions are enabled, i.e. not under ``-O``).
        """
        assert device.type == 'privateuseone', (
            f"expected a DirectML ('privateuseone') device, got '{device.type}'"
        )
        try:
            device_name = torch_directml.device_name(device.index)
            if 'NVIDIA' in device_name or 'GeForce' in device_name:
                from .optimizer.nvidia import nVidiaOptimizer as optimizer
            elif 'AMD' in device_name or 'Radeon' in device_name:
                from .optimizer.amd import AMDOptimizer as optimizer
            elif 'Intel' in device_name:
                from .optimizer.intel import IntelOptimizer as optimizer
            else:
                return UnknownOptimizer
            return optimizer
        except Exception:  # pylint: disable=broad-except
            # Deliberate best-effort: any failure (name lookup, missing or
            # broken vendor module) degrades to the generic optimizer
            # instead of breaking the caller.
            return UnknownOptimizer

    @staticmethod
    def memory_stats(device: torch.device):
        """Return memory statistics for ``device`` from its vendor optimizer.

        Delegates to ``memory_stats(device.index)`` on the class chosen by
        ``get_optimizer``. The form of the result is defined by those
        optimizer classes, which are not visible in this module.
        """
        optimizer = DirectML.get_optimizer(device)
        # NOTE(review): device.index can be None for a bare 'privateuseone'
        # device; it is passed through unchanged — confirm the optimizers
        # accept None.
        return optimizer.memory_stats(device.index)

    @staticmethod
    def get_autocast_gpu_dtype():
        """Return the dtype currently configured for autocast."""
        return DirectML._autocast_dtype

    @staticmethod
    def set_autocast_gpu_dtype(dtype):
        """Set the dtype used for autocast (initially ``torch.float16``)."""
        DirectML._autocast_dtype = dtype

    @staticmethod
    def is_autocast_enabled():
        """Return whether autocast is currently flagged as enabled."""
        return DirectML._is_autocast_enabled

    @staticmethod
    def set_autocast_enabled(enabled: bool):
        """Flag autocast as enabled or disabled."""
        DirectML._is_autocast_enabled = enabled
# Publish the DirectML helper namespace as ``torch.dml`` (a DirectML
# counterpart to ``torch.cuda``), then hang the ``amp`` module off it so
# ``torch.dml.amp`` mirrors ``torch.cuda.amp``.
torch.dml = DirectML
torch.dml.amp = amp