basic windows native pytorch support

pull/3921/head^2
Seunghoon Lee 2025-05-09 22:23:07 +09:00
parent 51716bbaba
commit 45c0bd6ec6
No known key found for this signature in database
GPG Key ID: 436E38F4E70BD152
4 changed files with 136 additions and 123 deletions

View File

@ -0,0 +1,127 @@
import sys
from functools import wraps
import torch
from modules import shared
if sys.platform == "win32":
MEM_BUS_WIDTH = {
"AMD Radeon RX 9070 XT": 256,
"AMD Radeon RX 9070": 256,
"AMD Radeon RX 9060 XT": 192,
"AMD Radeon RX 7900 XTX": 384,
"AMD Radeon RX 7900 XT": 320,
"AMD Radeon RX 7900 GRE": 256,
"AMD Radeon RX 7800 XT": 256,
"AMD Radeon RX 7700 XT": 192,
"AMD Radeon RX 7700": 192,
"AMD Radeon RX 7650 GRE": 128,
"AMD Radeon RX 7600 XT": 128,
"AMD Radeon RX 7600": 128,
"AMD Radeon RX 7500 XT": 96,
"AMD Radeon RX 6950 XT": 256,
"AMD Radeon RX 6900 XT": 256,
"AMD Radeon RX 6800 XT": 256,
"AMD Radeon RX 6800": 256,
"AMD Radeon RX 6750 XT": 192,
"AMD Radeon RX 6700 XT": 192,
"AMD Radeon RX 6700": 160,
"AMD Radeon RX 6650 XT": 128,
"AMD Radeon RX 6600 XT": 128,
"AMD Radeon RX 6600": 128,
"AMD Radeon RX 6500 XT": 64,
"AMD Radeon RX 6400": 64,
}
class DeviceProperties:
PROPERTIES_OVERRIDE = {
# sometimes gcnArchName contains device name ("AMD Radeon RX ..."), not architecture name ("gfx...")
"gcnArchName": "UNKNOWN ARCHITECTURE",
}
internal: torch._C._CudaDeviceProperties
def __init__(self, props: torch._C._CudaDeviceProperties):
self.internal = props
def __getattr__(self, name):
if name in DeviceProperties.PROPERTIES_OVERRIDE:
return DeviceProperties.PROPERTIES_OVERRIDE[name]
return getattr(self.internal, name)
__get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access
def torch_cuda__get_device_properties(device):
return DeviceProperties(__get_device_properties(device))
_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream # pylint: disable=protected-access
def torch__C__cuda_getCurrentRawStream(device):
from modules import zluda
return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device))
def get_default_agent_name():
if shared.devices.backend == "rocm":
device = shared.devices.get_optimal_device()
return getattr(torch.cuda.get_device_properties(device), "gcnArchName", None)
else:
from modules import zluda
if zluda.default_agent is None:
return None
return zluda.default_agent.name
def apply_triton_patches():
arch_name = get_default_agent_name()
if arch_name is not None:
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = arch_name
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
if shared.devices.backend == "zluda":
torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access
# Triton
try:
import triton
_get_device_properties = triton.runtime.driver.active.utils.get_device_properties
def triton_runtime_driver_active_utils_get_device_properties(device):
props = _get_device_properties(device)
name = torch.cuda.get_device_name()
if shared.devices.has_zluda():
name = name[:-8]
if props["mem_bus_width"] == 0: # Windows HIP SDK bug
if name in MEM_BUS_WIDTH:
props["mem_bus_width"] = MEM_BUS_WIDTH[name]
else:
props["mem_bus_width"] = 128
shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
return props
triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
if 'Flash attention' in shared.opts.sdp_options:
from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
if scale is None:
scale = query.shape[-1] ** (-0.5)
head_size_og = query.size(3)
if head_size_og % 8 != 0:
query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
query = query.transpose(1, 2)
out_padded = torch.zeros_like(query)
interface_fa.fwd(
query,
key.transpose(1, 2),
value.transpose(1, 2),
out_padded,
dropout_p,
scale,
is_causal,
)
return out_padded[..., :head_size_og].transpose(1, 2)
else:
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
shared.log.debug('Torch attention: type="triton flash attention"')
except Exception:
pass

View File

@ -1281,9 +1281,13 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device)
history = history.History()
if devices.backend == "directml":
directml_do_hijack()
elif devices.backend == "zluda":
from modules.zluda import initialize_zluda
initialize_zluda()
elif sys.platform == "win32" and (devices.backend == "zluda" or devices.backend == "rocm"):
from modules.rocm_triton_windows import apply_triton_patches
apply_triton_patches()
if devices.backend == "zluda":
from modules.zluda import initialize_zluda
initialize_zluda()
try:
log.info(f'Device: {print_dict(devices.get_gpu_info())}')
except Exception as ex:

View File

@ -26,13 +26,9 @@ def test(device: DeviceLikeType) -> Union[Exception, None]:
def initialize_zluda():
shared.cmd_opts.device_id = None
device = devices.get_optimal_device()
if not devices.cuda_ok or not devices.has_zluda() or PLATFORM != "win32":
if not devices.cuda_ok or not devices.has_zluda():
return
from modules.zluda_hijacks import do_hijack
do_hijack()
torch.backends.cudnn.enabled = zluda_installer.MIOpen_enabled
if hasattr(torch.backends.cuda, "enable_cudnn_sdp"):
if not zluda_installer.MIOpen_enabled:
@ -51,6 +47,7 @@ def initialize_zluda():
if shared.opts.onnx_execution_provider == ExecutionProvider.CUDA:
shared.opts.onnx_execution_provider = ExecutionProvider.CPU
device = devices.get_optimal_device()
result = test(device)
if result is not None:
shared.log.warning(f'ZLUDA device failed to pass basic operation test: index={device.index}, device_name={torch.cuda.get_device_name(device)}')

View File

@ -1,115 +0,0 @@
from functools import wraps
import torch
import torch._dynamo.device_interface
from modules import shared, zluda # pylint: disable=unused-import
MEM_BUS_WIDTH = {
"AMD Radeon RX 9070 XT": 256,
"AMD Radeon RX 9070": 256,
"AMD Radeon RX 9060 XT": 192,
"AMD Radeon RX 7900 XTX": 384,
"AMD Radeon RX 7900 XT": 320,
"AMD Radeon RX 7900 GRE": 256,
"AMD Radeon RX 7800 XT": 256,
"AMD Radeon RX 7700 XT": 192,
"AMD Radeon RX 7700": 192,
"AMD Radeon RX 7650 GRE": 128,
"AMD Radeon RX 7600 XT": 128,
"AMD Radeon RX 7600": 128,
"AMD Radeon RX 7500 XT": 96,
"AMD Radeon RX 6950 XT": 256,
"AMD Radeon RX 6900 XT": 256,
"AMD Radeon RX 6800 XT": 256,
"AMD Radeon RX 6800": 256,
"AMD Radeon RX 6750 XT": 192,
"AMD Radeon RX 6700 XT": 192,
"AMD Radeon RX 6700": 160,
"AMD Radeon RX 6650 XT": 128,
"AMD Radeon RX 6600 XT": 128,
"AMD Radeon RX 6600": 128,
"AMD Radeon RX 6500 XT": 64,
"AMD Radeon RX 6400": 64,
}
class DeviceProperties:
PROPERTIES_OVERRIDE = {
# sometimes gcnArchName contains device name ("AMD Radeon RX ..."), not architecture name ("gfx...")
"gcnArchName": "UNKNOWN ARCHITECTURE",
}
internal: torch._C._CudaDeviceProperties
def __init__(self, props: torch._C._CudaDeviceProperties):
self.internal = props
def __getattr__(self, name):
if name in DeviceProperties.PROPERTIES_OVERRIDE:
return DeviceProperties.PROPERTIES_OVERRIDE[name]
return getattr(self.internal, name)
__get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access
def torch_cuda__get_device_properties(device):
return DeviceProperties(__get_device_properties(device))
_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream # pylint: disable=protected-access
def torch__C__cuda_getCurrentRawStream(device):
return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device))
def do_hijack():
if zluda.default_agent is not None:
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
torch._C._cuda_getCurrentRawStream = torch__C__cuda_getCurrentRawStream # pylint: disable=protected-access
torch._dynamo.device_interface.CudaInterface.get_raw_stream = staticmethod(torch__C__cuda_getCurrentRawStream) # pylint: disable=protected-access
# Triton
try:
import triton
_get_device_properties = triton.runtime.driver.active.utils.get_device_properties
def triton_runtime_driver_active_utils_get_device_properties(device):
props = _get_device_properties(device)
name = torch.cuda.get_device_name()[:-8]
if props["mem_bus_width"] == 0: # Windows HIP SDK bug
if name in MEM_BUS_WIDTH:
props["mem_bus_width"] = MEM_BUS_WIDTH[name]
else:
props["mem_bus_width"] = 128
shared.log.warning(f'[TRITON] defaulting mem_bus_width=128 for device "{name}".')
return props
triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
if 'Flash attention' in shared.opts.sdp_options:
from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
if scale is None:
scale = query.shape[-1] ** (-0.5)
head_size_og = query.size(3)
if head_size_og % 8 != 0:
query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
query = query.transpose(1, 2)
out_padded = torch.zeros_like(query)
interface_fa.fwd(
query,
key.transpose(1, 2),
value.transpose(1, 2),
out_padded,
dropout_p,
scale,
is_causal,
)
return out_padded[..., :head_size_og].transpose(1, 2)
else:
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
shared.log.debug('Torch attention: type="triton flash attention"')
except Exception:
pass