lora-scripts/mikazuki/launch_utils.py

340 lines
11 KiB
Python

import locale
import os
import platform
import re
import shutil
import subprocess
import sys
import socket
import sysconfig
from typing import List
from pathlib import Path
from typing import Optional
import pkg_resources
from mikazuki.log import log
# Path of the interpreter running this process; reused below to launch pip
# and helper scripts so they run under the same Python installation.
python_bin = sys.executable
def base_dir_path():
    """Return the absolute path of the directory two levels above this file."""
    this_file = Path(__file__)
    # parents[0] is the directory holding this module; parents[1] is its parent.
    project_root = this_file.parents[1]
    return project_root.absolute()
def find_windows_git():
    """Return the first known Windows git.exe location that exists, else None.

    Relative candidates are resolved against the current working directory.
    """
    candidates = (
        "git\\bin\\git.exe",
        "git\\cmd\\git.exe",
        "Git\\mingw64\\libexec\\git-core\\git.exe",
        "C:\\Program Files\\Git\\cmd\\git.exe",
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
def prepare_git():
    """Make sure a ``git`` executable is reachable through PATH.

    Returns True when git is already on PATH, or when (on Windows) a bundled
    or default installation was found and its directory appended to PATH.
    Returns False otherwise.
    """
    if shutil.which("git"):
        return True

    log.info("Finding git...")

    # Only Windows has a list of fallback locations to probe.
    if sys.platform != "win32":
        log.error("git not found, please install git first")
        return False

    git_path = find_windows_git()
    if git_path is None:
        return False

    log.info(f"Git not found, but found git in {git_path}, add it to PATH")
    os.environ["PATH"] += os.pathsep + os.path.dirname(git_path)
    return True
def prepare_submodules():
    """Initialise and update git submodules when their checkouts are missing.

    Exits the process with status 1 if git cannot be located.
    """
    root = base_dir_path()
    required = (
        root / "frontend" / "dist",
        root / "mikazuki" / "dataset-tag-editor" / "scripts",
    )
    if all(os.path.exists(location) for location in required):
        return

    log.info("submodule not found, try clone...")
    log.info("checking git installation...")
    if not prepare_git():
        log.error("git not found, please install git first")
        sys.exit(1)

    # Return codes are intentionally not checked, as in the original flow.
    for action in ("init", "update"):
        subprocess.run(["git", "submodule", action])
def git_tag(path: str) -> str:
    """Return ``git describe --tags`` for the repository at *path*.

    Any failure (git missing, not a repository, no tags, undecodable output)
    yields the placeholder string ``"<none>"``.
    """
    describe_cmd = ["git", "-C", path, "describe", "--tags"]
    try:
        raw = subprocess.check_output(describe_cmd)
        return raw.strip().decode("utf-8")
    except Exception:
        return "<none>"
def check_dirs(dirs: List):
    """Create every path in *dirs* that does not exist yet (parents included)."""
    for directory in dirs:
        if os.path.exists(directory):
            # Already present (as a directory or anything else): leave it alone.
            continue
        os.makedirs(directory)
def run(command,
        desc: Optional[str] = None,
        errdesc: Optional[str] = None,
        custom_env: Optional[list] = None,
        live: Optional[bool] = True,
        shell: Optional[bool] = None):
    """Run *command* in a subprocess and return its captured stdout.

    Args:
        command: Command passed straight to ``subprocess.run``.
        desc: Optional text printed before the command starts.
        errdesc: Headline for the error message; defaults to a generic one.
        custom_env: Environment for the child; ``os.environ`` when None.
        live: When true, output is not captured (the child writes to the
            parent's streams) and an empty string is returned.
        shell: Whether to run through the shell. When None it is disabled on
            Windows and enabled everywhere else.

    Returns:
        ``""`` in live mode, otherwise the child's stdout decoded as UTF-8
        with undecodable bytes dropped.

    Raises:
        RuntimeError: The command exited with a non-zero status.
    """
    if shell is None:
        shell = sys.platform != "win32"

    if desc is not None:
        print(desc)

    child_env = os.environ if custom_env is None else custom_env
    headline = errdesc or 'Error running command'

    if live:
        proc = subprocess.run(command, shell=shell, env=child_env)
        if proc.returncode != 0:
            raise RuntimeError("\n".join([
                f"{headline}.",
                f"Command: {command}",
                f"Error code: {proc.returncode}",
            ]))
        return ""

    proc = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          shell=shell, env=child_env)
    if proc.returncode == 0:
        return proc.stdout.decode(encoding="utf8", errors="ignore")

    out_text = proc.stdout.decode(encoding="utf8", errors="ignore") if proc.stdout else '<empty>'
    err_text = proc.stderr.decode(encoding="utf8", errors="ignore") if proc.stderr else '<empty>'
    # The trailing empty item keeps the final newline of the message.
    raise RuntimeError("\n".join([
        f"{headline}.",
        f"Command: {command}",
        f"Error code: {proc.returncode}",
        f"stdout: {out_text}",
        f"stderr: {err_text}",
        "",
    ]))
def _version_satisfies(installed: str, required: str, at_least: bool) -> bool:
    """Compare two version strings as versions rather than as plain text."""
    try:
        have = pkg_resources.parse_version(installed)
        want = pkg_resources.parse_version(required)
    except ValueError:
        # Unparseable version (e.g. a compound spec such as "1.0,<2"):
        # fall back to the plain string comparison used historically.
        have, want = installed, required
    return have >= want if at_least else have == want


def is_installed(package, friendly: str = None):
    """Report whether every requirement named in *package* is installed.

    Args:
        package: A pip-style requirement string. Extras in brackets are
            stripped, tokens starting with ``-`` or ``=`` are skipped, and
            for URLs only the last path component is used as the name.
        friendly: Optional whitespace-separated requirement list that is
            checked instead of the tokens derived from *package*.

    Returns:
        True when every requirement is present and, where a ``>=`` or ``==``
        constraint is given, the installed version satisfies it. False
        otherwise.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #

    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package = re.sub(r'\[.*?\]', '', package)

    try:
        if friendly:
            pkgs = friendly.split()
        else:
            pkgs = [
                p
                for p in package.split()
                if not p.startswith('-') and not p.startswith('=')
            ]
            pkgs = [
                p.split('/')[-1] for p in pkgs
            ]   # get only package name if installing from URL

        for pkg in pkgs:
            at_least = '>=' in pkg
            if at_least:
                pkg_name, _, pkg_version = pkg.partition('>=')
            elif '==' in pkg:
                pkg_name, _, pkg_version = pkg.partition('==')
            else:
                pkg_name, pkg_version = pkg, None
            pkg_name = pkg_name.strip()
            if pkg_version is not None:
                pkg_version = pkg_version.strip()

            # Try the name as written, then lower-cased, then with
            # underscores turned into dashes.
            by_key = pkg_resources.working_set.by_key
            spec = by_key.get(pkg_name, None)
            if spec is None:
                spec = by_key.get(pkg_name.lower(), None)
            if spec is None:
                spec = by_key.get(pkg_name.replace('_', '-'), None)

            if spec is None:
                log.warning(f'Package version not found: {pkg_name}')
                return False

            version = spec.version
            # log.debug(f'Package version found: {pkg_name} {version}')

            if pkg_version is not None and not _version_satisfies(version, pkg_version, at_least):
                log.info(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
                return False

        return True
    except ModuleNotFoundError:
        log.warning(f'Package not installed: {pkgs}')
        return False
def validate_requirements(requirements_file: str):
    """Install every requirement from *requirements_file* that is missing.

    Blank lines, lines starting with ``#``, lines tagged ``# skip_verify`` and
    option lines other than ``--index-url`` are ignored. An ``--index-url``
    line sets the index used for all requirements that follow it.
    """
    def should_check(raw_line):
        # Prefix tests deliberately look at the raw (unstripped) line.
        if raw_line.strip() == '':
            return False
        if raw_line.startswith("#"):
            return False
        if raw_line.startswith("-") and not raw_line.startswith("--index-url "):
            return False
        return "# skip_verify" not in raw_line

    with open(requirements_file, 'r', encoding='utf8') as f:
        entries = [raw_line.strip() for raw_line in f if should_check(raw_line)]

        index_url = ""
        for entry in entries:
            if entry.startswith("--index-url "):
                index_url = entry.replace("--index-url ", "")
                continue

            if is_installed(entry):
                continue

            if index_url == "":
                pip_args = f"install {entry}"
            else:
                pip_args = f"install {entry} --index-url {index_url}"
            run_pip(pip_args, entry, live=True)
def setup_windows_bitsandbytes():
    """Reinstall bitsandbytes on Windows when its CUDA DLLs are missing.

    Does nothing on other platforms. The package is considered broken when it
    is not installed, when its directory does not exist, or when that
    directory contains no ``libbitsandbytes_cuda*.dll`` file.
    """
    if sys.platform != "win32":
        return

    # bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple")
    bnb_package = "bitsandbytes==0.45.3"
    bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")

    installed_bnb = is_installed("bitsandbytes")  # don't check version here

    # Guard the listing: os.listdir raises when the directory is missing,
    # which is exactly the situation the reinstall below has to repair.
    cuda_dll = re.compile(r"libbitsandbytes_cuda.+?\.dll")
    bnb_cuda_setup = os.path.isdir(bnb_path) and any(
        cuda_dll.search(entry) for entry in os.listdir(bnb_path)
    )

    if not installed_bnb or not bnb_cuda_setup:
        log.error("detected wrong install of bitsandbytes, reinstall it")
        run_pip("uninstall bitsandbytes -y", "bitsandbytes", live=True)
        run_pip(f"install {bnb_package}", bnb_package, live=True)
def setup_onnxruntime():
    """Install the onnxruntime / onnxruntime-gpu version matching this system.

    The target version defaults to 1.18.1, drops to 1.16.3 on Linux with
    glibc 2.27 or older, and can be overridden through the
    ``ONNXRUNTIME_VERSION`` environment variable. When the installed torch
    build reports CUDA 12, onnxruntime-gpu is fetched from the dedicated
    CUDA 12 package index.
    """
    onnx_version = "1.18.1"
    index_url = None

    try:
        import torch
        torch_version = torch.__version__
        if "cu12" in torch_version:
            # for cuda 12
            onnx_version = "1.18.1"
            index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
    except ImportError:
        log.error("torch not found")

    if sys.platform == "linux":
        libc_name, libc_version = platform.libc_ver()
        # Compare numerically: as plain strings "2.5" would sort after "2.27".
        glibc = tuple(int(part) for part in re.findall(r"\d+", libc_version))
        if libc_name == "glibc" and glibc <= (2, 27):
            onnx_version = "1.16.3"

    onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)

    if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
        log.info("uninstalling wrong onnxruntime version")
        # run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
        run_pip("uninstall onnxruntime -y", "onnxruntime", live=True)
        run_pip("uninstall onnxruntime-gpu -y", "onnxruntime", live=True)

        log.info("installing onnxruntime")
        run_pip(f"install onnxruntime=={onnx_version}", "onnxruntime", live=True)

        if index_url:
            run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", "onnxruntime-gpu", live=True)
        else:
            run_pip(f"install onnxruntime-gpu=={onnx_version}", "onnxruntime-gpu", live=True)
def run_pip(command, desc=None, live=False):
    """Run ``pip <command>`` with the current interpreter through ``run``.

    *desc* is only used to build the progress and error messages.
    """
    pip_command = f'"{python_bin}" -m pip {command}'
    progress_text = f"Installing {desc}"
    failure_text = f"Couldn't install {desc}"
    return run(pip_command, desc=progress_text, errdesc=failure_text, live=live)
def check_run(file: str) -> bool:
    """Run *file* with the current interpreter and log its stdout.

    Returns True when the script exits with status 0.
    """
    proc = subprocess.run(
        [python_bin, file],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=False,
    )
    script_output = proc.stdout.decode("utf-8").strip()
    log.info(script_output)
    return proc.returncode == 0
def network_gfw_test(timeout=3):
    """Check whether https://www.google.com is reachable.

    Args:
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        True when the request succeeds with HTTP 200. False on any other
        status code or on a ``requests`` request error.

    Raises:
        ImportError: The ``requests`` package is not installed.
    """
    # Imported before the ``try`` so that a missing package surfaces as a
    # plain ImportError; inside the ``try`` the ``except`` clause itself would
    # fail while looking up the not-yet-bound name ``requests``.
    import requests

    try:
        # requests will auto detect system proxies
        response = requests.get("https://www.google.com", timeout=timeout)
    except requests.exceptions.RequestException as e:
        log.error(f"Network test failed: {e}")
        return False

    if response.status_code == 200:
        log.info("Network test passed")
        return True

    log.error(f"Network test failed: {response.status_code}")
    return False
def prepare_environment(disable_auto_mirror: bool = True):
    """Set environment variables and make sure runtime prerequisites exist.

    Args:
        disable_auto_mirror: When False, a connectivity test is run and, if
            it fails, pip and Hugging Face mirror endpoints are configured
            (without overriding values that are already set).
    """
    if sys.platform == "win32":
        # disable triton on windows
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

    os.environ.update({
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "BITSANDBYTES_NOWELCOME": "1",
        "PYTHONWARNINGS": "ignore::UserWarning",
        "PIP_DISABLE_PIP_VERSION_CHECK": "1",
    })

    # The network test only runs when automatic mirroring is enabled.
    use_mirrors = not disable_auto_mirror and not network_gfw_test()
    if use_mirrors:
        log.info("use pip & huggingface mirrors")
        mirror_defaults = (
            ("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html"),
            ("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple"),
            ("HF_ENDPOINT", "https://hf-mirror.com"),
        )
        for variable, value in mirror_defaults:
            os.environ.setdefault(variable, value)

    # Ensure PATH is never empty so later PATH lookups and appends work.
    if not os.environ.get("PATH"):
        os.environ["PATH"] = os.path.dirname(sys.executable)

    prepare_submodules()

    check_dirs(["config/autosave", "logs"])

    # if not check_run("mikazuki/scripts/torch_check.py"):
    #     sys.exit(1)

    validate_requirements("requirements.txt")
    setup_windows_bitsandbytes()

    # if not skip_prepare_onnxruntime:
    #     setup_onnxruntime()
def catch_exception(f):
    """Decorator that logs any ``Exception`` raised by *f* instead of raising.

    The wrapper returns whatever *f* returns, or None when *f* raised. The
    wrapped function's name, docstring and other metadata are preserved.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            log.error(f"An error occurred: {e}")
            return None

    return wrapper
def check_port_avaliable(port: int):
    """Return True when *port* can be bound on 127.0.0.1, else False.

    The probe socket is always closed, including when binding fails. A port
    number outside the valid range counts as unavailable.
    """
    try:
        with socket.socket() as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("127.0.0.1", port))
        return True
    except (OSError, OverflowError):
        # OSError: address in use / not permitted; OverflowError: port is
        # outside 0-65535.
        return False
def find_avaliable_ports(port_init: int, port_range: int):
    """Return the first bindable port in ``range(port_init, port_range)``.

    *port_range* is the exclusive upper bound, not a count. Logs an error and
    returns None when no port in the range is available.
    """
    candidates = range(port_init, port_range)
    free_port = next((port for port in candidates if check_port_avaliable(port)), None)
    if free_port is None:
        log.error(f"error finding avaliable ports in range: {port_init} -> {port_range}")
    return free_port