324 lines
11 KiB
Python
324 lines
11 KiB
Python
import locale
|
|
import os
|
|
import platform
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import socket
|
|
import sysconfig
|
|
from typing import List
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import pkg_resources
|
|
|
|
from mikazuki.log import log
|
|
|
|
python_bin = sys.executable
|
|
|
|
|
|
def base_dir_path():
|
|
return Path(__file__).parents[1].absolute()
|
|
|
|
|
|
def find_windows_git():
|
|
possible_paths = ["git\\bin\\git.exe", "git\\cmd\\git.exe", "Git\\mingw64\\libexec\\git-core\\git.exe", "C:\\Program Files\\Git\\cmd\\git.exe"]
|
|
for path in possible_paths:
|
|
if os.path.exists(path):
|
|
return path
|
|
|
|
|
|
def prepare_git():
|
|
if shutil.which("git"):
|
|
return True
|
|
|
|
log.info("Finding git...")
|
|
|
|
if sys.platform == "win32":
|
|
git_path = find_windows_git()
|
|
|
|
if git_path is not None:
|
|
log.info(f"Git not found, but found git in {git_path}, add it to PATH")
|
|
os.environ["PATH"] += os.pathsep + os.path.dirname(git_path)
|
|
return True
|
|
else:
|
|
return False
|
|
else:
|
|
log.error("git not found, please install git first")
|
|
return False
|
|
|
|
|
|
def prepare_submodules():
|
|
frontend_path = base_dir_path() / "frontend" / "dist"
|
|
tag_editor_path = base_dir_path() / "mikazuki" / "dataset-tag-editor" / "scripts"
|
|
|
|
if not os.path.exists(frontend_path) or not os.path.exists(tag_editor_path):
|
|
log.info("submodule not found, try clone...")
|
|
log.info("checking git installation...")
|
|
if not prepare_git():
|
|
log.error("git not found, please install git first")
|
|
sys.exit(1)
|
|
subprocess.run(["git", "submodule", "init"])
|
|
subprocess.run(["git", "submodule", "update"])
|
|
|
|
|
|
def git_tag(path: str) -> str:
|
|
try:
|
|
return subprocess.check_output(["git", "-C", path, "describe", "--tags"]).strip().decode("utf-8")
|
|
except Exception as e:
|
|
return "<none>"
|
|
|
|
|
|
def check_dirs(dirs: List):
|
|
for d in dirs:
|
|
if not os.path.exists(d):
|
|
os.makedirs(d)
|
|
|
|
|
|
def run(command,
|
|
desc: Optional[str] = None,
|
|
errdesc: Optional[str] = None,
|
|
custom_env: Optional[list] = None,
|
|
live: Optional[bool] = True,
|
|
shell: Optional[bool] = None):
|
|
|
|
if shell is None:
|
|
shell = False if sys.platform == "win32" else True
|
|
|
|
if desc is not None:
|
|
print(desc)
|
|
|
|
if live:
|
|
result = subprocess.run(command, shell=shell, env=os.environ if custom_env is None else custom_env)
|
|
if result.returncode != 0:
|
|
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
|
Command: {command}
|
|
Error code: {result.returncode}""")
|
|
|
|
return ""
|
|
|
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
|
shell=shell, env=os.environ if custom_env is None else custom_env)
|
|
|
|
if result.returncode != 0:
|
|
message = f"""{errdesc or 'Error running command'}.
|
|
Command: {command}
|
|
Error code: {result.returncode}
|
|
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else '<empty>'}
|
|
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else '<empty>'}
|
|
"""
|
|
raise RuntimeError(message)
|
|
|
|
return result.stdout.decode(encoding="utf8", errors="ignore")
|
|
|
|
|
|
def is_installed(package, friendly: str = None):
|
|
#
|
|
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
|
#
|
|
|
|
# Remove brackets and their contents from the line using regular expressions
|
|
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
|
|
package = re.sub(r'\[.*?\]', '', package)
|
|
|
|
try:
|
|
if friendly:
|
|
pkgs = friendly.split()
|
|
else:
|
|
pkgs = [
|
|
p
|
|
for p in package.split()
|
|
if not p.startswith('-') and not p.startswith('=')
|
|
]
|
|
pkgs = [
|
|
p.split('/')[-1] for p in pkgs
|
|
] # get only package name if installing from URL
|
|
|
|
for pkg in pkgs:
|
|
if '>=' in pkg:
|
|
pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')]
|
|
elif '==' in pkg:
|
|
pkg_name, pkg_version = [x.strip() for x in pkg.split('==')]
|
|
else:
|
|
pkg_name, pkg_version = pkg.strip(), None
|
|
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name, None)
|
|
if spec is None:
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
|
|
if spec is None:
|
|
spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None)
|
|
|
|
if spec is not None:
|
|
version = pkg_resources.get_distribution(pkg_name).version
|
|
# log.debug(f'Package version found: {pkg_name} {version}')
|
|
|
|
if pkg_version is not None:
|
|
if '>=' in pkg:
|
|
ok = version >= pkg_version
|
|
else:
|
|
ok = version == pkg_version
|
|
|
|
if not ok:
|
|
log.info(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
|
|
return False
|
|
else:
|
|
log.warning(f'Package version not found: {pkg_name}')
|
|
return False
|
|
|
|
return True
|
|
except ModuleNotFoundError:
|
|
log.warning(f'Package not installed: {pkgs}')
|
|
return False
|
|
|
|
|
|
def validate_requirements(requirements_file: str):
|
|
with open(requirements_file, 'r', encoding='utf8') as f:
|
|
lines = [
|
|
line.strip()
|
|
for line in f.readlines()
|
|
if line.strip() != ''
|
|
and not line.startswith("#")
|
|
and not (line.startswith("-") and not line.startswith("--index-url "))
|
|
and line is not None
|
|
and "# skip_verify" not in line
|
|
]
|
|
|
|
index_url = ""
|
|
for line in lines:
|
|
if line.startswith("--index-url "):
|
|
index_url = line.replace("--index-url ", "")
|
|
continue
|
|
|
|
if not is_installed(line):
|
|
if index_url != "":
|
|
run_pip(f"install {line} --index-url {index_url}", line, live=True)
|
|
else:
|
|
run_pip(f"install {line}", line, live=True)
|
|
|
|
|
|
def setup_windows_bitsandbytes():
|
|
if sys.platform != "win32":
|
|
return
|
|
|
|
# bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple")
|
|
bnb_package = "bitsandbytes==0.44.0"
|
|
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
|
|
|
|
installed_bnb = is_installed(bnb_package)
|
|
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0
|
|
|
|
if not installed_bnb or not bnb_cuda_setup:
|
|
log.error("detected wrong install of bitsandbytes, reinstall it")
|
|
run_pip(f"uninstall bitsandbytes -y", "bitsandbytes", live=True)
|
|
run_pip(f"install {bnb_package}", bnb_package, live=True)
|
|
|
|
|
|
def setup_onnxruntime():
|
|
onnx_version = "1.18.1"
|
|
index_url = None
|
|
|
|
try:
|
|
import torch
|
|
torch_version = torch.__version__
|
|
if "cu12" in torch_version:
|
|
# for cuda 12
|
|
onnx_version = f"1.18.1"
|
|
index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
|
|
except ImportError:
|
|
log.error("torch not found")
|
|
|
|
if sys.platform == "linux":
|
|
libc_ver = platform.libc_ver()
|
|
if libc_ver[0] == "glibc" and libc_ver[1] <= "2.27":
|
|
onnx_version = "1.16.3"
|
|
|
|
onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)
|
|
|
|
if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
|
|
log.info("uninstalling wrong onnxruntime version")
|
|
# run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
|
|
run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True)
|
|
run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True)
|
|
|
|
log.info(f"installing onnxruntime")
|
|
run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True)
|
|
if index_url:
|
|
run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", f"onnxruntime-gpu", live=True)
|
|
else:
|
|
run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True)
|
|
|
|
|
|
def run_pip(command, desc=None, live=False):
|
|
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
|
|
|
|
|
def check_run(file: str) -> bool:
|
|
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
|
|
log.info(result.stdout.decode("utf-8").strip())
|
|
return result.returncode == 0
|
|
|
|
|
|
def prepare_environment(disable_auto_mirror: bool = True):
|
|
if sys.platform == "win32":
|
|
# disable triton on windows
|
|
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
|
|
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
|
os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
|
|
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
|
|
|
|
if not disable_auto_mirror and locale.getdefaultlocale()[0] == "zh_CN":
|
|
log.info("detected locale zh_CN, use pip & huggingface mirrors")
|
|
os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html")
|
|
os.environ.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple")
|
|
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
|
|
|
if not os.environ.get("PATH"):
|
|
os.environ["PATH"] = os.path.dirname(sys.executable)
|
|
|
|
prepare_submodules()
|
|
|
|
check_dirs(["config/autosave", "logs"])
|
|
|
|
# if not check_run("mikazuki/scripts/torch_check.py"):
|
|
# sys.exit(1)
|
|
|
|
validate_requirements("requirements.txt")
|
|
setup_windows_bitsandbytes()
|
|
|
|
# if not skip_prepare_onnxruntime:
|
|
# setup_onnxruntime()
|
|
|
|
|
|
def catch_exception(f):
|
|
def wrapper(*args, **kwargs):
|
|
try:
|
|
return f(*args, **kwargs)
|
|
except Exception as e:
|
|
log.error(f"An error occurred: {e}")
|
|
return wrapper
|
|
|
|
|
|
def check_port_avaliable(port: int):
|
|
try:
|
|
s = socket.socket()
|
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
s.bind(("127.0.0.1", port))
|
|
s.close()
|
|
return True
|
|
except:
|
|
return False
|
|
|
|
|
|
def find_avaliable_ports(port_init: int, port_range: int):
|
|
server_ports = range(port_init, port_range)
|
|
|
|
for p in server_ports:
|
|
if check_port_avaliable(p):
|
|
return p
|
|
|
|
log.error(f"error finding avaliable ports in range: {port_init} -> {port_range}")
|
|
return None
|