mirror of https://github.com/vladmandic/automatic
check hipblaslt availability in windows
parent
2f3e7d2713
commit
e43d8e9448
16
installer.py
16
installer.py
|
|
@ -569,6 +569,7 @@ def install_rocm_zluda():
|
|||
msg += f', using agent {device.name}'
|
||||
log.info(msg)
|
||||
torch_command = ''
|
||||
|
||||
if sys.platform == "win32":
|
||||
# TODO install: enable ROCm for windows when available
|
||||
|
||||
|
|
@ -584,17 +585,20 @@ def install_rocm_zluda():
|
|||
try:
|
||||
if args.reinstall:
|
||||
zluda_installer.uninstall()
|
||||
zluda_path = zluda_installer.get_path()
|
||||
zluda_installer.install(zluda_path)
|
||||
zluda_installer.make_copy(zluda_path)
|
||||
zluda_installer.install()
|
||||
except Exception as e:
|
||||
error = e
|
||||
log.warning(f'Failed to install ZLUDA: {e}')
|
||||
|
||||
if error is None:
|
||||
try:
|
||||
zluda_installer.load(zluda_path)
|
||||
if device is not None and zluda_installer.get_blaslt_enabled():
|
||||
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
|
||||
zluda_installer.set_blaslt_enabled(device.blaslt_supported)
|
||||
zluda_installer.make_copy()
|
||||
zluda_installer.load()
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118')
|
||||
log.info(f'Using ZLUDA in {zluda_path}')
|
||||
log.info(f'Using ZLUDA in {zluda_installer.path}')
|
||||
except Exception as e:
|
||||
error = e
|
||||
log.warning(f'Failed to load ZLUDA: {e}')
|
||||
|
|
@ -631,7 +635,7 @@ def install_rocm_zluda():
|
|||
#elif not args.experimental:
|
||||
# uninstall('flash-attn')
|
||||
|
||||
if device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
|
||||
if device is not None and rocm.version != "6.2" and rocm.get_blaslt_enabled():
|
||||
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
|
||||
rocm.set_blaslt_enabled(device.blaslt_supported)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,6 @@ from typing import Union, List
|
|||
from enum import Enum
|
||||
|
||||
|
||||
HIPBLASLT_TENSILE_LIBPATH = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", None if sys.platform == "win32" # not available
|
||||
else "/opt/rocm/lib/hipblaslt/library")
|
||||
|
||||
|
||||
def resolve_link(path_: str) -> str:
|
||||
if not os.path.islink(path_):
|
||||
return path_
|
||||
|
|
@ -55,8 +51,7 @@ class Agent:
|
|||
gfx_version: int
|
||||
arch: MicroArchitecture
|
||||
is_apu: bool
|
||||
if sys.platform != "win32":
|
||||
blaslt_supported: bool
|
||||
blaslt_supported: bool
|
||||
|
||||
@staticmethod
|
||||
def parse_gfx_version(name: str) -> int:
|
||||
|
|
@ -83,8 +78,7 @@ class Agent:
|
|||
else:
|
||||
self.arch = MicroArchitecture.GCN
|
||||
self.is_apu = (self.gfx_version & 0xFFF0 == 0x1150) or self.gfx_version in (0x801, 0x902, 0x90c, 0x1013, 0x1033, 0x1035, 0x1036, 0x1103,)
|
||||
if sys.platform != "win32":
|
||||
self.blaslt_supported = os.path.exists(os.path.join(HIPBLASLT_TENSILE_LIBPATH, f"extop_{name}.co"))
|
||||
self.blaslt_supported = os.path.exists(os.path.join(blaslt_tensile_libpath, f"Kernels.so-000-{name}.hsaco" if sys.platform == "win32" else f"extop_{name}.co"))
|
||||
|
||||
def get_gfx_version(self) -> Union[str, None]:
|
||||
if self.gfx_version >= 0x1200:
|
||||
|
|
@ -163,6 +157,7 @@ if sys.platform == "win32":
|
|||
return [Agent(x.split(' ')[-1].strip()) for x in spawn("hipinfo", cwd=os.path.join(path, 'bin')).split("\n") if x.startswith('gcnArchName:')]
|
||||
|
||||
is_wsl: bool = False
|
||||
version_torch = None
|
||||
else:
|
||||
def find() -> Union[str, None]:
|
||||
rocm_path = shutil.which("hipconfig")
|
||||
|
|
@ -199,12 +194,12 @@ else:
|
|||
def set_blaslt_enabled(enabled: bool) -> None:
|
||||
if enabled:
|
||||
load_library_global("/opt/rocm/lib/libhipblaslt.so") # Preload hipBLASLt.
|
||||
os.environ["HIPBLASLT_TENSILE_LIBPATH"] = HIPBLASLT_TENSILE_LIBPATH
|
||||
os.environ["HIPBLASLT_TENSILE_LIBPATH"] = blaslt_tensile_libpath
|
||||
else:
|
||||
os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0"
|
||||
|
||||
def get_blaslt_enabled() -> bool:
|
||||
return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1")))
|
||||
return version == version_torch and bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1")))
|
||||
|
||||
def get_flash_attention_command(agent: Agent):
|
||||
if os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE":
|
||||
|
|
@ -215,10 +210,11 @@ else:
|
|||
return os.environ.get("FLASH_ATTENTION_PACKAGE", default)
|
||||
|
||||
is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None
|
||||
version_torch = get_version_torch()
|
||||
path = find()
|
||||
blaslt_tensile_libpath = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", os.path.join(path, "bin" if sys.platform == "win32" else "lib", "hipblaslt", "library"))
|
||||
is_installed = False
|
||||
version = None
|
||||
version_torch = get_version_torch()
|
||||
if path is not None:
|
||||
is_installed = True
|
||||
version = get_version()
|
||||
|
|
|
|||
|
|
@ -36,5 +36,5 @@ def do_hijack():
|
|||
torch.fft.ifftn = fft_ifftn
|
||||
torch.fft.rfftn = fft_rfftn
|
||||
|
||||
if not zluda_installer.experimental_hipBLASLt_support:
|
||||
if not zluda_installer.get_blaslt_enabled():
|
||||
torch.jit.script = jit_script
|
||||
|
|
|
|||
|
|
@ -16,12 +16,10 @@ DLL_MAPPING = {
|
|||
}
|
||||
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll']
|
||||
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',)
|
||||
experimental_hipBLASLt_support: bool = False
|
||||
|
||||
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
|
||||
default_agent: Union[rocm.Agent, None] = None
|
||||
|
||||
|
||||
def get_path() -> str:
|
||||
return os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
|
||||
hipBLASLt_enabled = os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll")) and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(path, 'cublasLt.dll'))
|
||||
|
||||
|
||||
def set_default_agent(agent: rocm.Agent):
|
||||
|
|
@ -29,9 +27,8 @@ def set_default_agent(agent: rocm.Agent):
|
|||
default_agent = agent
|
||||
|
||||
|
||||
def install(zluda_path: os.PathLike) -> None:
|
||||
if os.path.exists(zluda_path):
|
||||
__initialize(zluda_path)
|
||||
def install() -> None:
|
||||
if os.path.exists(path):
|
||||
return
|
||||
|
||||
platform = "windows"
|
||||
|
|
@ -46,7 +43,6 @@ def install(zluda_path: os.PathLike) -> None:
|
|||
info.filename = os.path.basename(info.filename)
|
||||
archive.extract(info, '.zluda')
|
||||
os.remove('_zluda')
|
||||
__initialize(zluda_path)
|
||||
|
||||
|
||||
def uninstall() -> None:
|
||||
|
|
@ -54,27 +50,45 @@ def uninstall() -> None:
|
|||
shutil.rmtree('.zluda')
|
||||
|
||||
|
||||
def make_copy(zluda_path: os.PathLike) -> None:
|
||||
__initialize(zluda_path)
|
||||
def set_blaslt_enabled(enabled: bool):
|
||||
global hipBLASLt_enabled # pylint: disable=global-statement
|
||||
hipBLASLt_enabled = enabled
|
||||
|
||||
|
||||
def get_blaslt_enabled() -> bool:
|
||||
return hipBLASLt_enabled
|
||||
|
||||
|
||||
def link_or_copy(src: os.PathLike, dst: os.PathLike):
|
||||
try:
|
||||
os.link(src, dst)
|
||||
except Exception:
|
||||
shutil.copyfile(src, dst)
|
||||
|
||||
|
||||
def make_copy() -> None:
|
||||
for k, v in DLL_MAPPING.items():
|
||||
if not os.path.exists(os.path.join(zluda_path, v)):
|
||||
try:
|
||||
os.link(os.path.join(zluda_path, k), os.path.join(zluda_path, v))
|
||||
except Exception:
|
||||
shutil.copyfile(os.path.join(zluda_path, k), os.path.join(zluda_path, v))
|
||||
if not os.path.exists(os.path.join(path, v)):
|
||||
link_or_copy(os.path.join(path, k), os.path.join(path, v))
|
||||
|
||||
if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')):
|
||||
link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll'))
|
||||
|
||||
|
||||
def load(zluda_path: os.PathLike) -> None:
|
||||
def load() -> None:
|
||||
os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"
|
||||
os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll")
|
||||
|
||||
for v in HIPSDK_TARGETS:
|
||||
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v))
|
||||
for v in ZLUDA_TARGETS:
|
||||
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
|
||||
ctypes.windll.LoadLibrary(os.path.join(path, v))
|
||||
for v in DLL_MAPPING.values():
|
||||
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
|
||||
ctypes.windll.LoadLibrary(os.path.join(path, v))
|
||||
|
||||
if hipBLASLt_enabled:
|
||||
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'hipblaslt.dll'))
|
||||
ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll'))
|
||||
|
||||
def conceal():
|
||||
import torch # pylint: disable=unused-import
|
||||
|
|
@ -94,18 +108,7 @@ def load(zluda_path: os.PathLike) -> None:
|
|||
def get_default_torch_version(agent: Optional[rocm.Agent]) -> str:
|
||||
if agent is not None:
|
||||
if agent.arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,):
|
||||
return "2.4.1" if experimental_hipBLASLt_support else "2.3.1"
|
||||
return "2.4.1" if hipBLASLt_enabled else "2.3.1"
|
||||
elif agent.arch == rocm.MicroArchitecture.GCN:
|
||||
return "2.2.1"
|
||||
return "2.4.1" if experimental_hipBLASLt_support else "2.3.1"
|
||||
|
||||
|
||||
def __initialize(zluda_path: os.PathLike):
|
||||
global experimental_hipBLASLt_support # pylint: disable=global-statement
|
||||
experimental_hipBLASLt_support = os.path.exists(os.path.join(zluda_path, 'cublasLt.dll'))
|
||||
|
||||
if experimental_hipBLASLt_support:
|
||||
HIPSDK_TARGETS.append('hipblaslt.dll')
|
||||
DLL_MAPPING['cublasLt.dll'] = 'cublasLt64_11.dll'
|
||||
else:
|
||||
HIPSDK_TARGETS.append(f'hiprtc{"".join([v.zfill(2) for v in rocm.version.split(".")])}.dll')
|
||||
return "2.4.1" if hipBLASLt_enabled else "2.3.1"
|
||||
|
|
|
|||
Loading…
Reference in New Issue