mirror of https://github.com/vladmandic/automatic
49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
# pylint: disable=no-member,no-self-argument,no-method-argument
|
|
import torch
|
|
import torch_directml # pylint: disable=import-error
|
|
import modules.dml.hijack
|
|
import modules.dml.amp as amp
|
|
|
|
from .optimizer.unknown import UnknownOptimizer
|
|
|
|
class DirectML():
    """Namespace of DirectML helpers that plays the role of ``torch.cuda``.

    The class is used as a plain namespace: all state lives in class
    attributes and every method is a ``@staticmethod``, so the helpers work
    identically whether they are reached through the class itself or through
    an instance.
    """

    # Autocast state shared by the whole process (stored on the class).
    _is_autocast_enabled = False
    _autocast_dtype = torch.float16

    @staticmethod
    def get_optimizer(device: torch.device):
        """Return the vendor-specific optimizer class for *device*.

        The vendor is detected from substrings of the name reported by
        ``torch_directml.device_name``. If the name matches no known vendor,
        or if anything goes wrong while querying the name or importing the
        vendor module, ``UnknownOptimizer`` is returned instead.

        Args:
            device: A torch device whose ``type`` is ``'privateuseone'``.

        Raises:
            AssertionError: If ``device.type`` is not ``'privateuseone'``.
        """
        assert device.type == 'privateuseone', \
            f"expected a 'privateuseone' device, got '{device.type}'"
        try:
            device_name = torch_directml.device_name(device.index)
            # Vendor modules are imported lazily so that only the one that is
            # actually needed gets loaded.
            if 'NVIDIA' in device_name or 'GeForce' in device_name:
                from .optimizer.nvidia import nVidiaOptimizer as optimizer
            elif 'AMD' in device_name or 'Radeon' in device_name:
                from .optimizer.amd import AMDOptimizer as optimizer
            elif 'Intel' in device_name:
                from .optimizer.intel import IntelOptimizer as optimizer
            else:
                return UnknownOptimizer
            return optimizer
        except Exception:
            # Deliberate best-effort: any failure falls back to the generic
            # optimizer rather than propagating.
            return UnknownOptimizer

    @staticmethod
    def memory_stats(device: torch.device):
        """Return ``memory_stats(device.index)`` from the optimizer selected for *device*."""
        optimizer = DirectML.get_optimizer(device)
        return optimizer.memory_stats(device.index)

    @staticmethod
    def get_autocast_gpu_dtype():
        """Return the dtype currently stored for autocast."""
        return DirectML._autocast_dtype

    @staticmethod
    def set_autocast_gpu_dtype(dtype):
        """Store *dtype* as the autocast dtype."""
        DirectML._autocast_dtype = dtype

    @staticmethod
    def is_autocast_enabled():
        """Return the stored autocast-enabled flag."""
        return DirectML._is_autocast_enabled

    @staticmethod
    def set_autocast_enabled(enabled: bool):
        """Store *enabled* as the autocast-enabled flag."""
        DirectML._is_autocast_enabled = enabled
|
|
|
|
|
|
# Publish DirectML as `torch.dml` (the DirectML alternative to `torch.cuda`)
# and attach the AMP helpers to it as `torch.dml.amp`.
torch.dml = DirectML
torch.dml.amp = amp
|