mirror of https://github.com/vladmandic/automatic
84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
import os
|
|
import sys
|
|
import ctypes
|
|
import shutil
|
|
import zipfile
|
|
import urllib.request
|
|
from typing import Optional
|
|
from modules import rocm
|
|
|
|
|
|
# File name inside the ZLUDA directory -> CUDA-style file name it is exposed as.
# make_copy() links/copies each key to its value; load() loads the values.
DLL_MAPPING = {
    'cublas.dll': 'cublas64_11.dll',
    'cusparse.dll': 'cusparse64_11.dll',
    'nvrtc.dll': 'nvrtc64_112_0.dll',
}

# Libraries loaded from <rocm.path>/bin by load(). The hiprtc file name embeds
# every dot-separated component of rocm.version, each zero-padded to two
# characters (a version string of "5.7" would give "hiprtc0507.dll").
HIPSDK_TARGETS = [
    'rocblas.dll',
    'rocsolver.dll',
    'hiprtc' + ''.join(part.zfill(2) for part in rocm.version.split('.')) + '.dll',
]

# Libraries loaded directly from the ZLUDA directory by load().
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll')
|
|
|
|
|
|
def get_path() -> str:
    """Return the absolute path of the ZLUDA directory.

    The location is read from the ``ZLUDA`` environment variable; when the
    variable is not set, ``.zluda`` (relative to the current working
    directory) is used instead.
    """
    configured = os.environ.get('ZLUDA', '.zluda')
    return os.path.abspath(configured)
|
|
|
|
|
|
def install(zluda_path: os.PathLike) -> None:
    """Download and unpack a ZLUDA release unless ``zluda_path`` already exists.

    The release archive is fetched from the lshqqytiger/ZLUDA GitHub releases.
    The release tag comes from the ``ZLUDA_HASH`` environment variable (with a
    pinned default hash) and the asset is selected by the first character of
    ``rocm.version``. Every file entry of the archive is extracted flat, with
    its directory components stripped, into ``.zluda``.

    Args:
        zluda_path: Path checked for an existing installation. When it exists,
            the function returns without downloading anything.

    Raises:
        Any exception raised by the download or by reading/extracting the
        archive is propagated; the temporary download is removed first.
    """
    if os.path.exists(zluda_path):
        return

    archive_path = '_zluda'  # temporary download location in the working directory
    try:
        urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{os.environ.get("ZLUDA_HASH", "c0804ca624963aab420cb418412b1c7fbae3454b")}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', archive_path)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            for info in archive.infolist():
                if info.is_dir():
                    continue
                # Drop directory components so every file lands directly in the target.
                info.filename = os.path.basename(info.filename)
                # NOTE(review): the target is the literal '.zluda', not zluda_path;
                # the two differ when ZLUDA points somewhere else - confirm intent.
                archive.extract(info, '.zluda')
    finally:
        # Remove the download even when fetching or extraction failed, so a
        # partial or corrupt archive is never left behind.
        if os.path.exists(archive_path):
            os.remove(archive_path)
|
|
|
|
|
|
def uninstall() -> None:
    """Delete the ``.zluda`` directory tree in the working directory, if any."""
    target = '.zluda'
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
|
|
|
|
|
|
def make_copy(zluda_path: os.PathLike) -> None:
    """Expose each ZLUDA library under its CUDA-style file name.

    For every ``source -> alias`` pair in ``DLL_MAPPING`` whose alias does not
    exist yet inside ``zluda_path``, a hard link to the source file is created;
    if linking fails for any reason, the file is copied instead.

    Args:
        zluda_path: Directory holding the ZLUDA libraries.
    """
    for source_name, alias_name in DLL_MAPPING.items():
        alias = os.path.join(zluda_path, alias_name)
        if os.path.exists(alias):
            continue  # alias already present, nothing to do

        source = os.path.join(zluda_path, source_name)
        try:
            os.link(source, alias)
        except Exception:
            # Best-effort fallback when a hard link cannot be created.
            shutil.copyfile(source, alias)
|
|
|
|
|
|
def load(zluda_path: os.PathLike) -> None:
    """Load the HIP SDK and ZLUDA libraries and register ``rocm.conceal``.

    Sets ``ZLUDA_COMGR_LOG_LEVEL`` to ``"1"``, then loads, in order, the
    ``HIPSDK_TARGETS`` from ``<rocm.path>/bin``, the ``ZLUDA_TARGETS`` from
    ``zluda_path`` and the ``DLL_MAPPING`` alias files from ``zluda_path``
    through ``ctypes.windll``. Finally a ``conceal`` callback is stored on the
    ``rocm`` module.

    Args:
        zluda_path: Directory holding the ZLUDA libraries.
    """
    os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"

    load_library = ctypes.windll.LoadLibrary
    hip_bin = os.path.join(rocm.path, 'bin')
    for dll in HIPSDK_TARGETS:
        load_library(os.path.join(hip_bin, dll))
    # ZLUDA's own libraries first, then the CUDA-named aliases.
    for dll in (*ZLUDA_TARGETS, *DLL_MAPPING.values()):
        load_library(os.path.join(zluda_path, dll))

    def conceal():
        """Import ``torch.utils.cpp_extension`` with ``sys.platform`` blanked, then patch it."""
        import torch # pylint: disable=unused-import
        real_platform = sys.platform
        # cpp_extension is imported while sys.platform is an empty string
        # (presumably to skip platform-specific import-time logic - confirm),
        # and the real value is restored right after the import.
        sys.platform = ""
        from torch.utils import cpp_extension
        sys.platform = real_platform

        # Re-derive the platform flags from the real platform string.
        cpp_extension.IS_LINUX = real_platform.startswith('linux')
        cpp_extension.IS_MACOS = False
        cpp_extension.IS_WINDOWS = real_platform == "win32"

        def _join_rocm_home(*paths) -> str:
            # Plain join onto ROCM_HOME, replacing the module's own helper.
            return os.path.join(cpp_extension.ROCM_HOME, *paths)

        cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access

    rocm.conceal = conceal
|
|
|
|
|
|
def get_default_torch_version(agent: Optional[rocm.Agent]) -> str:
    """Return the default torch version string for the given ROCm agent.

    Args:
        agent: Detected agent, or ``None`` when there is none.

    Returns:
        ``"2.2.1"`` for a GCN agent; ``"2.3.1"`` for RDNA/CDNA agents, any
        other architecture, and when ``agent`` is ``None``.
    """
    fallback = "2.3.1"
    if agent is None:
        return fallback

    arch = agent.arch
    if arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,):
        return "2.3.1"
    if arch == rocm.MicroArchitecture.GCN:
        return "2.2.1"
    return fallback
|