# mirror of https://github.com/vladmandic/automatic
# 67 lines · 2.5 KiB · Python
import importlib

from typing import Any, Optional

import torch

# Dotted paths of the torch callables that get wrapped so their tensor
# arguments are cast to the active autocast dtype before dispatch.
ops = [
    "torch.Tensor.__matmul__",
    "torch.addbmm",
    "torch.addmm",
    "torch.addmv",
    "torch.addr",
    "torch.baddbmm",
    "torch.bmm",
    "torch.chain_matmul",
    "torch.linalg.multi_dot",
    "torch.nn.functional.conv1d",
    "torch.nn.functional.conv2d",
    "torch.nn.functional.conv3d",
    "torch.nn.functional.conv_transpose1d",
    "torch.nn.functional.conv_transpose2d",
    "torch.nn.functional.conv_transpose3d",
    "torch.nn.GRUCell",
    "torch.nn.functional.linear",
    "torch.nn.LSTMCell",
    "torch.matmul",
    "torch.mm",
    "torch.mv",
    "torch.prelu",
    "torch.nn.RNNCell",
    "torch.embedding",
]

# Source dtype -> dtypes it may be autocast to. Only fp16 <-> fp32 is supported.
supported_cast_pairs = {
    torch.float16: (torch.float32,),
    torch.float32: (torch.float16,),
}
|
def forward(op, args: tuple, kwargs: dict):
    """Invoke *op*, first casting tensor arguments via cast() when DML
    autocast is enabled; pass everything through untouched otherwise."""
    if not torch.dml.is_autocast_enabled:
        return op(*args, **kwargs)
    # Cast positional and keyword arguments alike; cast() leaves
    # non-tensors and unsupported dtypes untouched.
    cast_args = [cast(value) for value in args]
    cast_kwargs = {name: cast(value) for name, value in kwargs.items()}
    return op(*cast_args, **cast_kwargs)
|
def cast(tensor: torch.Tensor):
    """Return *tensor* converted to the active autocast dtype, or the input
    unchanged when it is not a tensor or the dtype pair is unsupported."""
    if not torch.is_tensor(tensor):
        return tensor
    source: torch.dtype = tensor.dtype
    # Only fp16/fp32 sources participate in autocast at all.
    if source not in supported_cast_pairs:
        return tensor
    target = torch.dml.autocast_gpu_dtype
    if target != source and target not in supported_cast_pairs[source]:
        return tensor
    return tensor.type(target)
|
def cond(op: str):
    """Monkey-patch the callable at dotted path *op* so every call routes
    through forward(), which casts tensor arguments when autocast is on.

    Non-string inputs are silently ignored (matching the original guard).
    Raises ImportError when no prefix of the dotted path is importable —
    previously this failure mode crashed with an uncaught ValueError from
    ``importlib.import_module('')`` at ``i == 0`` (or a NameError on the
    unbound ``resolved_obj``).
    """
    if not isinstance(op, str):
        return
    func_path = op.split('.')
    resolved_obj = None
    # Find the longest importable module prefix of the dotted path.
    for i in range(len(func_path) - 1, -1, -1):
        try:
            resolved_obj = importlib.import_module('.'.join(func_path[:i]))
            break
        except (ImportError, ValueError):
            # ImportError: prefix is not a module; ValueError: empty name at i == 0.
            continue
    if resolved_obj is None:
        raise ImportError(f"cannot resolve module for op: {op!r}")
    # Walk the remaining attributes down to the object that owns the callable.
    for attr_name in func_path[i:-1]:
        resolved_obj = getattr(resolved_obj, attr_name)
    target = getattr(resolved_obj, func_path[-1])
    # Close over the original callable so the wrapper can defer to forward().
    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(target, args, kwargs))
|
# Patch every listed torch operation at import time so it participates
# in DML autocast.
for op_path in ops:
    cond(op_path)
|
class autocast:
    """Context manager that enables DML autocast for its duration.

    On entry it flips ``torch.dml.is_autocast_enabled`` on and installs the
    requested target dtype; on exit it restores both flags to whatever they
    were before, so instances nest safely.
    """

    prev: bool                                    # enabled flag saved on entry
    fast_dtype: torch.dtype = torch.float16       # target dtype for casts
    prev_fast_dtype: torch.dtype                  # dtype saved on entry

    def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
        self.fast_dtype = dtype

    def __enter__(self):
        # Snapshot current state first so __exit__ can restore it.
        self.prev_fast_dtype = torch.dml.autocast_gpu_dtype
        self.prev = torch.dml.is_autocast_enabled
        torch.dml.autocast_gpu_dtype = self.fast_dtype
        torch.dml.is_autocast_enabled = True

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        # Restore the snapshot unconditionally; exceptions propagate.
        torch.dml.autocast_gpu_dtype = self.prev_fast_dtype
        torch.dml.is_autocast_enabled = self.prev