mirror of https://github.com/vladmandic/automatic
fix zluda-python.py
closes #3857 Co-authored-by: ackhack <christoph.fuerbacher@gmail.com>pull/3866/head
parent
68f3c203b1
commit
69ec66f58a
|
|
@ -28,7 +28,6 @@ if __name__ == '__main__':
|
|||
|
||||
from modules import zluda_installer
|
||||
zluda_installer.install()
|
||||
zluda_installer.make_copy()
|
||||
zluda_installer.load()
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -649,7 +649,6 @@ def install_rocm_zluda():
|
|||
if device is not None and zluda_installer.get_blaslt_enabled():
|
||||
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
|
||||
zluda_installer.set_blaslt_enabled(device.blaslt_supported)
|
||||
zluda_installer.make_copy()
|
||||
zluda_installer.load()
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ def initialize_zluda():
|
|||
from modules.zluda_hijacks import do_hijack
|
||||
do_hijack()
|
||||
|
||||
torch.backends.cudnn.enabled = zluda_installer.MIOpen_available
|
||||
if not zluda_installer.MIOpen_available:
|
||||
torch.backends.cudnn.enabled = zluda_installer.MIOpen_enabled
|
||||
if not zluda_installer.MIOpen_enabled:
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
torch.backends.cuda.enable_cudnn_sdp = do_nothing
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ DLL_MAPPING = {
|
|||
}
|
||||
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',]
|
||||
|
||||
hipBLASLt_available = False
|
||||
MIOpen_available = False
|
||||
MIOpen_enabled = False
|
||||
|
||||
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
|
||||
default_agent: Union[rocm.Agent, None] = None
|
||||
|
|
@ -65,36 +64,16 @@ core = None
|
|||
ml = None
|
||||
|
||||
|
||||
def load_core_modules():
|
||||
global core, ml # pylint: disable=global-statement
|
||||
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
|
||||
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
|
||||
|
||||
|
||||
def set_default_agent(agent: rocm.Agent):
|
||||
global default_agent # pylint: disable=global-statement
|
||||
default_agent = agent
|
||||
|
||||
is_nightly = False
|
||||
try:
|
||||
load_core_modules()
|
||||
is_nightly = core.get_nightly_flag() == 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
global hipBLASLt_available, hipBLASLt_enabled # pylint: disable=global-statement
|
||||
hipBLASLt_available = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath)
|
||||
hipBLASLt_enabled = hipBLASLt_available and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
|
||||
|
||||
global MIOpen_available # pylint: disable=global-statement
|
||||
MIOpen_available = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
|
||||
|
||||
|
||||
def is_reinstall_needed() -> bool: # ZLUDA<3.8.7
|
||||
return not os.path.exists(os.path.join(path, 'cufftw.dll'))
|
||||
|
||||
|
||||
def install() -> None:
|
||||
def install():
|
||||
if os.path.exists(path):
|
||||
return
|
||||
|
||||
|
|
@ -115,7 +94,7 @@ def install() -> None:
|
|||
os.remove('_zluda')
|
||||
|
||||
|
||||
def uninstall() -> None:
|
||||
def uninstall():
|
||||
if os.path.exists(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
|
@ -139,7 +118,16 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
|
|||
shutil.copyfile(src, dst)
|
||||
|
||||
|
||||
def make_copy() -> None:
|
||||
def load():
|
||||
global core, ml # pylint: disable=global-statement
|
||||
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
|
||||
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
|
||||
|
||||
is_nightly = core.get_nightly_flag() == 1
|
||||
hipBLASLt_enabled = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
|
||||
global MIOpen_enabled # pylint: disable=global-statement
|
||||
MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
|
||||
|
||||
for k, v in DLL_MAPPING.items():
|
||||
if not os.path.exists(os.path.join(path, v)):
|
||||
link_or_copy(os.path.join(path, k), os.path.join(path, v))
|
||||
|
|
@ -147,17 +135,14 @@ def make_copy() -> None:
|
|||
if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')):
|
||||
link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll'))
|
||||
|
||||
if MIOpen_available and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
|
||||
if MIOpen_enabled and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
|
||||
link_or_copy(os.path.join(path, 'cudnn.dll'), os.path.join(path, 'cudnn64_9.dll'))
|
||||
|
||||
|
||||
def load() -> None:
|
||||
log.info(f"ZLUDA load: path='{path}' nightly={bool(core.get_nightly_flag())}")
|
||||
|
||||
os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"
|
||||
os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll")
|
||||
|
||||
load_core_modules()
|
||||
for v in HIPSDK_TARGETS:
|
||||
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v))
|
||||
for v in DLL_MAPPING.values():
|
||||
|
|
@ -170,7 +155,7 @@ def load() -> None:
|
|||
else:
|
||||
os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"
|
||||
|
||||
if MIOpen_available:
|
||||
if MIOpen_enabled:
|
||||
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'MIOpen.dll'))
|
||||
ctypes.windll.LoadLibrary(os.path.join(path, 'cudnn64_9.dll'))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue