164 lines
5.7 KiB
Python
164 lines
5.7 KiB
Python
import locale
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import sysconfig
|
|
from typing import List
|
|
|
|
import pkg_resources
|
|
|
|
from mikazuki.log import log
|
|
from mikazuki.utils import run_pip
|
|
|
|
|
|
def smart_pip_mirror():
|
|
if locale.getdefaultlocale()[0] == "zh_CN":
|
|
log.info("detected locale zh_CN, use pip mirrors")
|
|
os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html")
|
|
os.environ.setdefault("PIP_INDEX_URL", "https://mirror.baidu.com/pypi/simple")
|
|
|
|
|
|
def find_windows_git():
|
|
possible_paths = ["git\\bin\\git.exe", "git\\cmd\\git.exe", "Git\\mingw64\\libexec\\git-core\\git.exe"]
|
|
for path in possible_paths:
|
|
if os.path.exists(path):
|
|
return path
|
|
|
|
|
|
def prepare_frontend():
|
|
if not os.path.exists("./frontend/dist"):
|
|
log.info("frontend not found, try clone...")
|
|
log.info("checking git installation...")
|
|
if not shutil.which("git"):
|
|
if sys.platform == "win32":
|
|
git_path = find_windows_git()
|
|
|
|
if git_path is not None:
|
|
log.info(f"Git not found, but found git in {git_path}, add it to PATH")
|
|
os.environ["PATH"] += os.pathsep + os.path.dirname(git_path)
|
|
return
|
|
else:
|
|
log.error("git not found, please install git first")
|
|
sys.exit(1)
|
|
subprocess.run(["git", "submodule", "init"])
|
|
subprocess.run(["git", "submodule", "update"])
|
|
|
|
|
|
def remove_warnings():
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
if sys.platform == "win32":
|
|
# disable triton on windows
|
|
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
|
|
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
|
os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
|
|
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
|
|
|
|
|
|
def check_dirs(dirs: List):
|
|
for d in dirs:
|
|
if not os.path.exists(d):
|
|
os.makedirs(d)
|
|
|
|
|
|
def is_installed(package, friendly: str = None):
|
|
#
|
|
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
|
#
|
|
|
|
# Remove brackets and their contents from the line using regular expressions
|
|
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
|
|
package = re.sub(r'\[.*?\]', '', package)
|
|
|
|
try:
|
|
if friendly:
|
|
pkgs = friendly.split()
|
|
else:
|
|
pkgs = [
|
|
p
|
|
for p in package.split()
|
|
if not p.startswith('-') and not p.startswith('=')
|
|
]
|
|
pkgs = [
|
|
p.split('/')[-1] for p in pkgs
|
|
] # get only package name if installing from URL
|
|
|
|
for pkg in pkgs:
|
|
if '>=' in pkg:
|
|
pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')]
|
|
elif '==' in pkg:
|
|
pkg_name, pkg_version = [x.strip() for x in pkg.split('==')]
|
|
else:
|
|
pkg_name, pkg_version = pkg.strip(), None
|
|
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name, None)
|
|
if spec is None:
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
|
|
if spec is None:
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None)
|
|
|
|
if spec is not None:
|
|
version = pkg_resources.get_distribution(pkg_name).version
|
|
# log.debug(f'Package version found: {pkg_name} {version}')
|
|
|
|
if pkg_version is not None:
|
|
if '>=' in pkg:
|
|
ok = version >= pkg_version
|
|
else:
|
|
ok = version == pkg_version
|
|
|
|
if not ok:
|
|
log.info(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
|
|
return False
|
|
else:
|
|
log.warning(f'Package version not found: {pkg_name}')
|
|
return False
|
|
|
|
return True
|
|
except ModuleNotFoundError:
|
|
log.warning(f'Package not installed: {pkgs}')
|
|
return False
|
|
|
|
|
|
def validate_requirements(requirements_file: str):
|
|
with open(requirements_file, 'r', encoding='utf8') as f:
|
|
lines = [
|
|
line.strip()
|
|
for line in f.readlines()
|
|
if line.strip() != ''
|
|
and not line.startswith("#")
|
|
and not (line.startswith("-") and not line.startswith("--index-url "))
|
|
and line is not None
|
|
and "# skip_verify" not in line
|
|
]
|
|
|
|
index_url = ""
|
|
for line in lines:
|
|
if line.startswith("--index-url "):
|
|
index_url = line.replace("--index-url ", "")
|
|
continue
|
|
|
|
if not is_installed(line):
|
|
if index_url != "":
|
|
run_pip(f"install {line} --index-url {index_url}", line, live=True)
|
|
else:
|
|
run_pip(f"install {line}", line, live=True)
|
|
|
|
|
|
def setup_windows_bitsandbytes():
|
|
if sys.platform != "win32":
|
|
return
|
|
|
|
bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple")
|
|
bnb_package = "bitsandbytes==0.41.1"
|
|
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
|
|
|
|
installed_bnb = is_installed(bnb_package)
|
|
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0
|
|
|
|
if not installed_bnb or not bnb_cuda_setup:
|
|
log.error("detected wrong install of bitsandbytes, reinstall it")
|
|
run_pip(f"uninstall bitsandbytes -y", "bitsandbytes", live=True)
|
|
run_pip(f"install {bnb_package} --index-url {bnb_windows_index}", bnb_package, live=True)
|