diff --git a/modules/rocm_triton_windows.py b/modules/rocm_triton_windows.py
new file mode 100644
index 000000000..b55e5f488
--- /dev/null
+++ b/modules/rocm_triton_windows.py
@@ -0,0 +1,136 @@
+import sys
+from functools import wraps
+import torch
+from modules import shared
+
+
+if sys.platform == "win32":
+    MEM_BUS_WIDTH = {
+        "AMD Radeon RX 9070 XT": 256,
+        "AMD Radeon RX 9070": 256,
+        "AMD Radeon RX 9060 XT": 192,
+        "AMD Radeon RX 7900 XTX": 384,
+        "AMD Radeon RX 7900 XT": 320,
+        "AMD Radeon RX 7900 GRE": 256,
+        "AMD Radeon RX 7800 XT": 256,
+        "AMD Radeon RX 7700 XT": 192,
+        "AMD Radeon RX 7700": 192,
+        "AMD Radeon RX 7650 GRE": 128,
+        "AMD Radeon RX 7600 XT": 128,
+        "AMD Radeon RX 7600": 128,
+        "AMD Radeon RX 7500 XT": 96,
+        "AMD Radeon RX 6950 XT": 256,
+        "AMD Radeon RX 6900 XT": 256,
+        "AMD Radeon RX 6800 XT": 256,
+        "AMD Radeon RX 6800": 256,
+        "AMD Radeon RX 6750 XT": 192,
+        "AMD Radeon RX 6700 XT": 192,
+        "AMD Radeon RX 6700": 160,
+        "AMD Radeon RX 6650 XT": 128,
+        "AMD Radeon RX 6600 XT": 128,
+        "AMD Radeon RX 6600": 128,
+        "AMD Radeon RX 6500 XT": 64,
+        "AMD Radeon RX 6400": 64,
+    }
+
+    class DeviceProperties:
+        """Wrap torch device properties and serve PROPERTIES_OVERRIDE entries in place of the wrapped values."""
+        PROPERTIES_OVERRIDE = {
+            # sometimes gcnArchName contains device name ("AMD Radeon RX ..."), not architecture name ("gfx...")
+            "gcnArchName": "UNKNOWN ARCHITECTURE",
+        }
+        internal: torch._C._CudaDeviceProperties
+
+        def __init__(self, props: torch._C._CudaDeviceProperties):
+            self.internal = props
+
+        def __getattr__(self, name):
+            # overrides take precedence; everything else is forwarded to the wrapped properties object
+            if name in DeviceProperties.PROPERTIES_OVERRIDE:
+                return DeviceProperties.PROPERTIES_OVERRIDE[name]
+            return getattr(self.internal, name)
+
+    __get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access
+    def torch_cuda__get_device_properties(device):
+        """Replacement for torch.cuda._get_device_properties that returns a DeviceProperties wrapper."""
+        return DeviceProperties(__get_device_properties(device))
+
+    _cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream # pylint: disable=protected-access
+    def torch__C__cuda_getCurrentRawStream(device):
+        """Return the original raw stream for device passed through zluda.core.to_hip_stream."""
+        from modules import zluda
+        return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device))
+
+    def get_default_agent_name():
+        """Return the name used to override gcnArchName, or None when it is not available."""
+        if shared.devices.backend == "rocm":
+            device = shared.devices.get_optimal_device()
+            return getattr(torch.cuda.get_device_properties(device), "gcnArchName", None)
+        else:
+            from modules import zluda
+            if zluda.default_agent is None:
+                return None
+            return zluda.default_agent.name
+
+    def apply_triton_patches():
+        """Install the torch hooks and, when triton is importable, the Triton hooks for ROCm/ZLUDA on Windows."""
+        arch_name = get_default_agent_name()
+        if arch_name is not None:
+            DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = arch_name
+        torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
+        if shared.devices.backend == "zluda":
+            from torch._dynamo.device_interface import CudaInterface # explicit import; 'import torch._dynamo...' here would rebind torch as a local
+            torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
+            CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream)
+
+        # Triton
+        try:
+            import triton
+            _get_device_properties = triton.runtime.driver.active.utils.get_device_properties
+            def triton_runtime_driver_active_utils_get_device_properties(device):
+                # fill in mem_bus_width when the driver reports 0, using the table above keyed by device name
+                props = _get_device_properties(device)
+                name = torch.cuda.get_device_name(device)
+                if shared.devices.has_zluda():
+                    name = name[:-8] # drop the trailing 8 characters (presumably the ZLUDA tag) to match MEM_BUS_WIDTH keys
+                if props["mem_bus_width"] == 0: # Windows HIP SDK bug
+                    if name in MEM_BUS_WIDTH:
+                        props["mem_bus_width"] = MEM_BUS_WIDTH[name]
+                    else:
+                        props["mem_bus_width"] = 128
+                        shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
+                return props
+            triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
+
+            if 'Flash attention' in shared.opts.sdp_options:
+                from modules.flash_attn_triton_amd import interface_fa
+                sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+                @wraps(sdpa_pre_flash_atten)
+                def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                    # flash path only for 4D, unmasked, non-fp32 inputs with head size <= 128; anything else uses the original SDPA
+                    if query.dim() == 4 and query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
+                        if scale is None:
+                            scale = query.shape[-1] ** (-0.5)
+                        head_size_og = query.size(3)
+                        if head_size_og % 8 != 0:
+                            query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
+                            key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
+                            value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
+                        query = query.transpose(1, 2)
+                        out_padded = torch.zeros_like(query)
+                        interface_fa.fwd(
+                            query,
+                            key.transpose(1, 2),
+                            value.transpose(1, 2),
+                            out_padded,
+                            dropout_p,
+                            scale,
+                            is_causal,
+                        )
+                        return out_padded[..., :head_size_og].transpose(1, 2)
+                    else:
+                        return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
+                shared.log.debug('Torch attention: type="triton flash attention"')
+        except Exception as e:
+            shared.log.debug(f'[TRITON] patches skipped: {e}')
diff --git a/modules/shared.py b/modules/shared.py
index c05c0a84e..58d58c535 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1281,9 +1281,13 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device)
 history = history.History()
 if devices.backend == "directml":
     directml_do_hijack()
-elif devices.backend == "zluda":
-    from modules.zluda import initialize_zluda
-    initialize_zluda()
+elif sys.platform == "win32" and (devices.backend == "zluda" or devices.backend == "rocm"):
+    from modules.rocm_triton_windows import apply_triton_patches
+    apply_triton_patches()
+
+    if devices.backend == "zluda":
+        from modules.zluda import initialize_zluda
+        initialize_zluda()
 try:
     log.info(f'Device: {print_dict(devices.get_gpu_info())}')
 except Exception as ex:
diff --git a/modules/zluda.py b/modules/zluda.py
index 330377ab8..8c7802c38 100644
--- a/modules/zluda.py
+++ b/modules/zluda.py
@@ -26,13 +26,9 @@ def test(device: DeviceLikeType) -> Union[Exception, None]:
 
 def initialize_zluda():
     shared.cmd_opts.device_id = None
-    device = devices.get_optimal_device()
-    if not devices.cuda_ok or not devices.has_zluda() or PLATFORM != "win32":
+    if not devices.cuda_ok or not devices.has_zluda():
         return
 
-    from modules.zluda_hijacks import do_hijack
-    do_hijack()
-
     torch.backends.cudnn.enabled = zluda_installer.MIOpen_enabled
     if hasattr(torch.backends.cuda, "enable_cudnn_sdp"):
         if not zluda_installer.MIOpen_enabled:
@@ -51,6 +47,7 @@ def initialize_zluda():
     if shared.opts.onnx_execution_provider == ExecutionProvider.CUDA:
         shared.opts.onnx_execution_provider = ExecutionProvider.CPU
 
+    device = devices.get_optimal_device()
     result = test(device)
     if result is not None:
         shared.log.warning(f'ZLUDA device failed to pass basic operation test: index={device.index}, device_name={torch.cuda.get_device_name(device)}')
diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py
deleted file mode 100644
index f8d831986..000000000
--- a/modules/zluda_hijacks.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from functools import wraps
-import torch
-import torch._dynamo.device_interface
-from modules import shared, zluda # pylint: disable=unused-import
-
-
-MEM_BUS_WIDTH = {
-    "AMD Radeon RX 9070 XT": 256,
-    "AMD Radeon RX 9070": 256,
-    "AMD Radeon RX 9060 XT": 192,
-    "AMD Radeon RX 7900 XTX": 384,
-    "AMD Radeon RX 7900 XT": 320,
-    "AMD Radeon RX 7900 GRE": 256,
-    "AMD Radeon RX 7800 XT": 256,
-    "AMD Radeon RX 7700 XT": 192,
-    "AMD Radeon RX 7700": 192,
-    "AMD Radeon RX 7650 GRE": 128,
-    "AMD Radeon RX 7600 XT": 128,
-    "AMD Radeon RX 7600": 128,
-    "AMD Radeon RX 7500 XT": 96,
-    "AMD Radeon RX 6950 XT": 256,
-    "AMD Radeon RX 6900 XT": 256,
-    "AMD Radeon RX 6800 XT": 256,
-    "AMD Radeon RX 6800": 256,
-    "AMD Radeon RX 6750 XT": 192,
-    "AMD Radeon RX 6700 XT": 192,
-    "AMD Radeon RX 6700": 160,
-    "AMD Radeon RX 6650 XT": 128,
-    "AMD Radeon RX 6600 XT": 128,
-    "AMD Radeon RX 6600": 128,
-    "AMD Radeon RX 6500 XT": 64,
-    "AMD Radeon RX 6400": 64,
-}
-
-
-class DeviceProperties:
-    PROPERTIES_OVERRIDE = {
-        # sometimes gcnArchName contains device name ("AMD Radeon RX ..."), not architecture name ("gfx...")
-        "gcnArchName": "UNKNOWN ARCHITECTURE",
-    }
-    internal: torch._C._CudaDeviceProperties
-
-    def __init__(self, props: torch._C._CudaDeviceProperties):
-        self.internal = props
-
-    def __getattr__(self, name):
-        if name in DeviceProperties.PROPERTIES_OVERRIDE:
-            return DeviceProperties.PROPERTIES_OVERRIDE[name]
-        return getattr(self.internal, name)
-
-
-__get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access
-def torch_cuda__get_device_properties(device):
-    return DeviceProperties(__get_device_properties(device))
-
-
-_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream # pylint: disable=protected-access
-def torch__C__cuda_getCurrentRawStream(device):
-    return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device))
-
-
-def do_hijack():
-    if zluda.default_agent is not None:
-        DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
-    torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
-    torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
-    torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access
-
-    # Triton
-    try:
-        import triton
-        _get_device_properties = triton.runtime.driver.active.utils.get_device_properties
-        def triton_runtime_driver_active_utils_get_device_properties(device):
-            props = _get_device_properties(device)
-            name = torch.cuda.get_device_name()[:-8]
-            if props["mem_bus_width"] == 0: # Windows HIP SDK bug
-                if name in MEM_BUS_WIDTH:
-                    props["mem_bus_width"] = MEM_BUS_WIDTH[name]
-                else:
-                    props["mem_bus_width"] = 128
-                    shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
-            return props
-        triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
-
-        if 'Flash attention' in shared.opts.sdp_options:
-            from modules.flash_attn_triton_amd import interface_fa
-            sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
-            @wraps(sdpa_pre_flash_atten)
-            def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
-                    if scale is None:
-                        scale = query.shape[-1] ** (-0.5)
-                    head_size_og = query.size(3)
-                    if head_size_og % 8 != 0:
-                        query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
-                        key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
-                        value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
-                    query = query.transpose(1, 2)
-                    out_padded = torch.zeros_like(query)
-                    interface_fa.fwd(
-                        query,
-                        key.transpose(1, 2),
-                        value.transpose(1, 2),
-                        out_padded,
-                        dropout_p,
-                        scale,
-                        is_causal,
-                    )
-                    return out_padded[..., :head_size_og].transpose(1, 2)
-                else:
-                    return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-            torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
-            shared.log.debug('Torch attention: type="triton flash attention"')
-    except Exception:
-        pass