fix zluda-python.py

closes #3857

Co-authored-by: ackhack <christoph.fuerbacher@gmail.com>
pull/3866/head
Seunghoon Lee 2025-04-05 22:26:07 +09:00
parent 68f3c203b1
commit 69ec66f58a
No known key found for this signature in database
GPG Key ID: 436E38F4E70BD152
4 changed files with 17 additions and 34 deletions

View File

@ -28,7 +28,6 @@ if __name__ == '__main__':
from modules import zluda_installer
zluda_installer.install()
zluda_installer.make_copy()
zluda_installer.load()
import torch

View File

@ -649,7 +649,6 @@ def install_rocm_zluda():
if device is not None and zluda_installer.get_blaslt_enabled():
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
zluda_installer.set_blaslt_enabled(device.blaslt_supported)
zluda_installer.make_copy()
zluda_installer.load()
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
except Exception as e:

View File

@ -33,8 +33,8 @@ def initialize_zluda():
from modules.zluda_hijacks import do_hijack
do_hijack()
torch.backends.cudnn.enabled = zluda_installer.MIOpen_available
if not zluda_installer.MIOpen_available:
torch.backends.cudnn.enabled = zluda_installer.MIOpen_enabled
if not zluda_installer.MIOpen_enabled:
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_cudnn_sdp = do_nothing
torch.backends.cuda.enable_flash_sdp(False)

View File

@ -19,8 +19,7 @@ DLL_MAPPING = {
}
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',]
hipBLASLt_available = False
MIOpen_available = False
MIOpen_enabled = False
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
default_agent: Union[rocm.Agent, None] = None
@ -65,36 +64,16 @@ core = None
ml = None
def load_core_modules():
global core, ml # pylint: disable=global-statement
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
def set_default_agent(agent: rocm.Agent):
global default_agent # pylint: disable=global-statement
default_agent = agent
is_nightly = False
try:
load_core_modules()
is_nightly = core.get_nightly_flag() == 1
except Exception:
pass
global hipBLASLt_available, hipBLASLt_enabled # pylint: disable=global-statement
hipBLASLt_available = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath)
hipBLASLt_enabled = hipBLASLt_available and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
global MIOpen_available # pylint: disable=global-statement
MIOpen_available = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
def is_reinstall_needed() -> bool: # ZLUDA<3.8.7
return not os.path.exists(os.path.join(path, 'cufftw.dll'))
def install() -> None:
def install():
if os.path.exists(path):
return
@ -115,7 +94,7 @@ def install() -> None:
os.remove('_zluda')
def uninstall() -> None:
def uninstall():
if os.path.exists(path):
shutil.rmtree(path)
@ -139,7 +118,16 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
shutil.copyfile(src, dst)
def make_copy() -> None:
def load():
global core, ml # pylint: disable=global-statement
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
is_nightly = core.get_nightly_flag() == 1
hipBLASLt_enabled = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
global MIOpen_enabled # pylint: disable=global-statement
MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
for k, v in DLL_MAPPING.items():
if not os.path.exists(os.path.join(path, v)):
link_or_copy(os.path.join(path, k), os.path.join(path, v))
@ -147,17 +135,14 @@ def make_copy() -> None:
if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')):
link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll'))
if MIOpen_available and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
if MIOpen_enabled and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
link_or_copy(os.path.join(path, 'cudnn.dll'), os.path.join(path, 'cudnn64_9.dll'))
def load() -> None:
log.info(f"ZLUDA load: path='{path}' nightly={bool(core.get_nightly_flag())}")
os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"
os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll")
load_core_modules()
for v in HIPSDK_TARGETS:
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v))
for v in DLL_MAPPING.values():
@ -170,7 +155,7 @@ def load() -> None:
else:
os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"
if MIOpen_available:
if MIOpen_enabled:
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'MIOpen.dll'))
ctypes.windll.LoadLibrary(os.path.join(path, 'cudnn64_9.dll'))