fix onnxruntime install

pull/720/head
Akegarasu 2025-08-21 20:18:51 +08:00
parent e0f5194815
commit 2c65fb1af1
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
1 changed files with 30 additions and 25 deletions

View File

@ -205,7 +205,7 @@ def setup_windows_bitsandbytes():
bnb_package = "bitsandbytes==0.46.0"
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
installed_bnb = is_installed("bitsandbytes") # don't check version here
installed_bnb = is_installed("bitsandbytes") # don't check version here
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0
if not installed_bnb or not bnb_cuda_setup:
@ -214,20 +214,10 @@ def setup_windows_bitsandbytes():
run_pip(f"install {bnb_package}", bnb_package, live=True)
def setup_onnxruntime():
onnx_version = "1.18.1"
index_url = None
try:
import torch
torch_version = torch.__version__
if "cu12" in torch_version:
# for cuda 12
onnx_version = f"1.18.1"
index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
except ImportError:
log.error("torch not found")
def setup_onnxruntime(
onnx_version: Optional[str] = None,
index_url: Optional[str] = None
):
if sys.platform == "linux":
libc_ver = platform.libc_ver()
if libc_ver[0] == "glibc" and libc_ver[1] <= "2.27":
@ -235,24 +225,39 @@ def setup_onnxruntime():
onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)
if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
if onnx_version and not is_installed(f"onnxruntime-gpu=={onnx_version}"):
log.info("uninstalling wrong onnxruntime version")
# run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True)
run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True)
if not is_installed(f"onnxruntime-gpu"):
log.info(f"installing onnxruntime")
run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True)
if index_url:
run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", f"onnxruntime-gpu", live=True)
else:
run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True)
pip_install("onnxruntime", onnx_version, index_url=index_url, live=True)
pip_install("onnxruntime-gpu", onnx_version, index_url=index_url, live=True)
def run_pip(command, desc=None, live=False):
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def pip_install(package: str, version: Optional[str] = None, index_url: Optional[str] = None, live: bool = True):
"""
Install a package using pip.
:param package: The name of the package to install.
:param version: The version of the package to install (optional).
:param index_url: The index URL to use for installing the package (optional).
"""
if version:
package = f"{package}=={version}"
command = f"install {package}"
if index_url:
command = f"{command} -i {index_url}"
run_pip(command, desc=f"Installing {package}", live=live)
def check_run(file: str) -> bool:
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
log.info(result.stdout.decode("utf-8").strip())
@ -275,7 +280,7 @@ def network_gfw_test(timeout=3):
return False
def prepare_environment(disable_auto_mirror: bool = True):
def prepare_environment(disable_auto_mirror: bool = True, prepare_onnxruntime: bool = True):
if sys.platform == "win32":
# disable triton on windows
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
@ -304,8 +309,8 @@ def prepare_environment(disable_auto_mirror: bool = True):
validate_requirements("requirements.txt")
setup_windows_bitsandbytes()
# if not skip_prepare_onnxruntime:
# setup_onnxruntime()
if prepare_onnxruntime:
setup_onnxruntime()
def catch_exception(f):