fix onnxruntime install
parent
e0f5194815
commit
2c65fb1af1
|
|
@ -205,7 +205,7 @@ def setup_windows_bitsandbytes():
|
|||
bnb_package = "bitsandbytes==0.46.0"
|
||||
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
|
||||
|
||||
installed_bnb = is_installed("bitsandbytes") # don't check version here
|
||||
installed_bnb = is_installed("bitsandbytes") # don't check version here
|
||||
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0
|
||||
|
||||
if not installed_bnb or not bnb_cuda_setup:
|
||||
|
|
@ -214,20 +214,10 @@ def setup_windows_bitsandbytes():
|
|||
run_pip(f"install {bnb_package}", bnb_package, live=True)
|
||||
|
||||
|
||||
def setup_onnxruntime():
|
||||
onnx_version = "1.18.1"
|
||||
index_url = None
|
||||
|
||||
try:
|
||||
import torch
|
||||
torch_version = torch.__version__
|
||||
if "cu12" in torch_version:
|
||||
# for cuda 12
|
||||
onnx_version = f"1.18.1"
|
||||
index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
|
||||
except ImportError:
|
||||
log.error("torch not found")
|
||||
|
||||
def setup_onnxruntime(
|
||||
onnx_version: Optional[str] = None,
|
||||
index_url: Optional[str] = None
|
||||
):
|
||||
if sys.platform == "linux":
|
||||
libc_ver = platform.libc_ver()
|
||||
if libc_ver[0] == "glibc" and libc_ver[1] <= "2.27":
|
||||
|
|
@ -235,24 +225,39 @@ def setup_onnxruntime():
|
|||
|
||||
onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)
|
||||
|
||||
if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
|
||||
if onnx_version and not is_installed(f"onnxruntime-gpu=={onnx_version}"):
|
||||
log.info("uninstalling wrong onnxruntime version")
|
||||
# run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
|
||||
run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True)
|
||||
run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True)
|
||||
|
||||
if not is_installed(f"onnxruntime-gpu"):
|
||||
log.info(f"installing onnxruntime")
|
||||
run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True)
|
||||
if index_url:
|
||||
run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", f"onnxruntime-gpu", live=True)
|
||||
else:
|
||||
run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True)
|
||||
pip_install("onnxruntime", onnx_version, index_url=index_url, live=True)
|
||||
pip_install("onnxruntime-gpu", onnx_version, index_url=index_url, live=True)
|
||||
|
||||
|
||||
def run_pip(command, desc=None, live=False):
|
||||
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||
|
||||
|
||||
def pip_install(package: str, version: Optional[str] = None, index_url: Optional[str] = None, live: bool = True):
|
||||
"""
|
||||
Install a package using pip.
|
||||
:param package: The name of the package to install.
|
||||
:param version: The version of the package to install (optional).
|
||||
:param index_url: The index URL to use for installing the package (optional).
|
||||
"""
|
||||
if version:
|
||||
package = f"{package}=={version}"
|
||||
|
||||
command = f"install {package}"
|
||||
|
||||
if index_url:
|
||||
command = f"{command} -i {index_url}"
|
||||
|
||||
run_pip(command, desc=f"Installing {package}", live=live)
|
||||
|
||||
|
||||
def check_run(file: str) -> bool:
|
||||
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
|
||||
log.info(result.stdout.decode("utf-8").strip())
|
||||
|
|
@ -275,7 +280,7 @@ def network_gfw_test(timeout=3):
|
|||
return False
|
||||
|
||||
|
||||
def prepare_environment(disable_auto_mirror: bool = True):
|
||||
def prepare_environment(disable_auto_mirror: bool = True, prepare_onnxruntime: bool = True):
|
||||
if sys.platform == "win32":
|
||||
# disable triton on windows
|
||||
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
|
||||
|
|
@ -304,8 +309,8 @@ def prepare_environment(disable_auto_mirror: bool = True):
|
|||
validate_requirements("requirements.txt")
|
||||
setup_windows_bitsandbytes()
|
||||
|
||||
# if not skip_prepare_onnxruntime:
|
||||
# setup_onnxruntime()
|
||||
if prepare_onnxruntime:
|
||||
setup_onnxruntime()
|
||||
|
||||
|
||||
def catch_exception(f):
|
||||
|
|
|
|||
Loading…
Reference in New Issue