diff --git a/mikazuki/launch_utils.py b/mikazuki/launch_utils.py index 42b3f01..282db7a 100644 --- a/mikazuki/launch_utils.py +++ b/mikazuki/launch_utils.py @@ -202,10 +202,10 @@ def setup_windows_bitsandbytes(): return # bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple") - bnb_package = "bitsandbytes==0.44.0" + bnb_package = "bitsandbytes==0.45.3" bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes") - installed_bnb = is_installed(bnb_package) + installed_bnb = is_installed("bitsandbytes") # don't check version here bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0 if not installed_bnb or not bnb_cuda_setup: @@ -259,6 +259,22 @@ def check_run(file: str) -> bool: return result.returncode == 0 +def network_gfw_test(timeout=3): + try: + import requests + # requests will auto detect system proxies + response = requests.get("https://www.google.com", timeout=timeout) + if response.status_code == 200: + log.info("Network test passed") + return True + else: + log.error(f"Network test failed: {response.status_code}") + return False + except requests.exceptions.RequestException as e: + log.error(f"Network test failed: {e}") + return False + + def prepare_environment(disable_auto_mirror: bool = True): if sys.platform == "win32": # disable triton on windows @@ -269,8 +285,8 @@ def prepare_environment(disable_auto_mirror: bool = True): os.environ["PYTHONWARNINGS"] = "ignore::UserWarning" os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" - if not disable_auto_mirror and locale.getdefaultlocale()[0] == "zh_CN": - log.info("detected locale zh_CN, use pip & huggingface mirrors") + if not disable_auto_mirror and not network_gfw_test(): + log.info("use pip & huggingface mirrors") os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html") os.environ.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple") os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")