mirror of https://github.com/vladmandic/automatic
101 lines
4.5 KiB
Python
101 lines
4.5 KiB
Python
from functools import wraps
|
|
import torch
|
|
import torch._dynamo.device_interface
|
|
from modules import shared, rocm, zluda # pylint: disable=unused-import
|
|
|
|
|
|
# Lookup table: device name -> value stored under Triton's "mem_bus_width"
# device property by the Triton hook installed in do_hijack(). Devices that
# are not listed here get a fallback of 128 there.
# NOTE(review): values are presumably bus widths in bits — confirm against
# the vendor specifications before adding new entries.
MEM_BUS_WIDTH = dict(
    (
        ("AMD Radeon RX 9070 XT", 256),
        ("AMD Radeon RX 9070", 256),
        ("AMD Radeon RX 9060 XT", 192),
        ("AMD Radeon RX 7900 XTX", 384),
        ("AMD Radeon RX 7900 XT", 320),
        ("AMD Radeon RX 7900 GRE", 256),
        ("AMD Radeon RX 7800 XT", 256),
        ("AMD Radeon RX 7700 XT", 192),
        ("AMD Radeon RX 7600 XT", 128),
        ("AMD Radeon RX 7600", 128),
    )
)
|
|
|
|
|
|
class DeviceProperties:
    """Read-only proxy around a device-properties object with a few attributes overridden.

    Attribute reads that are not satisfied by the proxy itself are answered
    from ``PROPERTIES_OVERRIDE`` when the name is listed there, and otherwise
    forwarded to the wrapped object.
    """

    # Values reported instead of the wrapped object's own attributes. The dict
    # is shared by every instance; do_hijack() rewrites "gcnArchName" in place.
    PROPERTIES_OVERRIDE = {"regs_per_multiprocessor": 65535, "gcnArchName": "UNKNOWN ARCHITECTURE"}
    # The wrapped properties object. Annotations are kept as strings so that
    # defining the class does not require the private torch symbol to exist.
    internal: "torch._C._CudaDeviceProperties"

    def __init__(self, props: "torch._C._CudaDeviceProperties"):
        """Wrap ``props``; it is stored as ``self.internal``."""
        self.internal = props

    def __getattr__(self, name):
        """Resolve ``name`` from the overrides first, then from the wrapped object.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is neither overridden nor present on the
                wrapped object, or if the instance has no wrapped object yet.
        """
        # ``internal`` only reaches this hook when __init__ has not run (e.g.
        # copy/pickle probing a freshly created instance). Reading
        # ``self.internal`` below would then re-enter this method forever.
        if name == "internal":
            raise AttributeError(name)
        overrides = DeviceProperties.PROPERTIES_OVERRIDE
        if name in overrides:
            return overrides[name]
        return getattr(self.internal, name)
|
|
|
|
|
|
# Reference to torch's implementation as it was when this module was imported;
# the wrapper below delegates to it.
__get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access


def torch_cuda__get_device_properties(device):
    """Return the saved torch implementation's properties for ``device``, wrapped in DeviceProperties."""
    original_props = __get_device_properties(device)
    return DeviceProperties(original_props)
|
|
|
|
|
|
# Reference to torch's implementation as it was when this module was imported;
# the wrapper below delegates to it.
_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream # pylint: disable=protected-access


def torch__C__cuda_getCurrentRawStream(device):
    """Return the saved torch raw stream for ``device`` after passing it through ``zluda.core.to_hip_stream``."""
    raw_stream = _cuda_getCurrentRawStream(device)
    return zluda.core.to_hip_stream(raw_stream)
|
|
|
|
|
|
def _patch_triton_device_properties(triton):
    """Wrap Triton's ``get_device_properties`` so its result always carries "mem_bus_width"."""
    _get_device_properties = triton.runtime.driver.active.utils.get_device_properties

    def triton_runtime_driver_active_utils_get_device_properties(device):
        props = _get_device_properties(device)
        # The reported name has its last 8 characters removed before the lookup.
        # NOTE(review): presumably this strips a " [ZLUDA]" suffix — confirm.
        # NOTE(review): the name is read from the current device, not from
        # ``device`` — confirm this is intended for multi-GPU setups.
        name = torch.cuda.get_device_name()[:-8]
        if name in MEM_BUS_WIDTH:
            props["mem_bus_width"] = MEM_BUS_WIDTH[name]
        else:
            props["mem_bus_width"] = 128
            shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
        return props

    triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties


def _patch_flash_attention():
    """Replace ``torch.nn.functional.scaled_dot_product_attention`` with a Triton flash-attention wrapper."""
    from modules.flash_attn_triton_amd import interface_fa
    sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention

    @wraps(sdpa_pre_flash_atten)
    def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, **kwargs):
        # The flash path indexes dimension 3 and transposes dims 1 and 2, so it
        # is only valid for 4-D inputs. Extra keyword arguments (for example
        # ``enable_gqa`` on newer torch) are only understood by the original
        # implementation. Everything else falls back to the original.
        use_flash = (
            not kwargs
            and query.dim() == 4
            and query.shape[-1] <= 128
            and attn_mask is None
            and query.dtype != torch.float32
        )
        if not use_flash:
            return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)

        # The default scale is derived from the unpadded head size.
        if scale is None:
            scale = query.shape[-1] ** (-0.5)
        head_size_og = query.size(3)
        if head_size_og % 8 != 0:
            # Zero-pad the last dimension up to the next multiple of 8.
            pad = [0, 8 - head_size_og % 8]
            query = torch.nn.functional.pad(query, pad)
            key = torch.nn.functional.pad(key, pad)
            value = torch.nn.functional.pad(value, pad)
        # Positional arguments follow interface_fa.fwd's signature, which is
        # not visible from this file; they are passed through unchanged.
        out_padded, _, _, _ = interface_fa.fwd(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            None,
            None,
            dropout_p,
            scale,
            is_causal,
            -1,
            -1,
            0.0,
            False,
            None,
        )
        # Drop the padding again and restore the caller's dimension order.
        return out_padded[..., :head_size_og].transpose(1, 2)

    torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
    shared.log.debug('Torch attention: type="triton flash attention"')


def do_hijack():
    """Install the ZLUDA monkey-patches on torch and, when available, on Triton.

    Always replaces ``torch.cuda._get_device_properties``,
    ``torch._C._cuda_getCurrentRawStream`` and
    ``torch._dynamo.device_interface.CudaInterface.get_raw_stream`` with the
    wrappers defined in this module. The Triton device-properties hook and the
    optional flash-attention hook (enabled when 'Flash attention' is in
    ``shared.opts.sdp_options``) are best-effort: any failure is logged at
    debug level and otherwise ignored.
    """
    if zluda.default_agent is not None:
        DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
    torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
    torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
    torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access

    # Triton
    try:
        import triton
        _patch_triton_device_properties(triton)
        if 'Flash attention' in shared.opts.sdp_options:
            _patch_flash_attention()
    except Exception as e:
        # Stay best-effort (Triton may simply not be installed), but leave a
        # trace so that real failures can be diagnosed.
        shared.log.debug(f'Triton hijack skipped: {e}')
|