"""Environment bootstrap helpers: git/submodule setup, pip requirement checks,
optional GPU package setup and local port probing."""

import functools
import locale
import os
import platform
import re
import shutil
import subprocess
import sys
import socket
import sysconfig
from typing import List
from pathlib import Path
from typing import Optional

import pkg_resources

from mikazuki.log import log

# Interpreter used for every pip / helper-script subprocess.
python_bin = sys.executable


def base_dir_path():
    """Return the absolute path two levels up from this file (``parents[1]``)."""
    return Path(__file__).parents[1].absolute()


def find_windows_git() -> Optional[str]:
    """Return the first existing git executable among a few known Windows locations.

    Relative candidates are resolved against the current working directory.
    Returns ``None`` when none of the candidates exists.
    """
    possible_paths = [
        "git\\bin\\git.exe",
        "git\\cmd\\git.exe",
        "Git\\mingw64\\libexec\\git-core\\git.exe",
        "C:\\Program Files\\Git\\cmd\\git.exe",
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path
    return None


def prepare_git() -> bool:
    """Make sure a ``git`` executable is reachable through ``PATH``.

    Returns ``True`` if git is already on ``PATH`` or (on Windows) a bundled
    copy was found and appended to ``PATH``; ``False`` otherwise.
    """
    if shutil.which("git"):
        return True

    log.info("Finding git...")

    if sys.platform == "win32":
        git_path = find_windows_git()
        if git_path is None:
            return False

        log.info(f"Git not found, but found git in {git_path}, add it to PATH")
        # Tolerate a missing/empty PATH instead of raising KeyError.
        current_path = os.environ.get("PATH", "")
        git_dir = os.path.dirname(git_path)
        os.environ["PATH"] = current_path + os.pathsep + git_dir if current_path else git_dir
        return True

    log.error("git not found, please install git first")
    return False


def prepare_submodules():
    """Initialise git submodules when the frontend or tag-editor directory is missing.

    Exits the process with status 1 when the submodules are needed but git
    cannot be found.
    """
    frontend_path = base_dir_path() / "frontend" / "dist"
    tag_editor_path = base_dir_path() / "mikazuki" / "dataset-tag-editor" / "scripts"

    if not os.path.exists(frontend_path) or not os.path.exists(tag_editor_path):
        log.info("submodule not found, try clone...")
        log.info("checking git installation...")
        if not prepare_git():
            log.error("git not found, please install git first")
            sys.exit(1)
        # Return codes are intentionally not checked here (best effort).
        subprocess.run(["git", "submodule", "init"])
        subprocess.run(["git", "submodule", "update"])


def git_tag(path: str) -> str:
    """Return ``git describe --tags`` for ``path``, or ``""`` on any failure."""
    try:
        return subprocess.check_output(["git", "-C", path, "describe", "--tags"]).strip().decode("utf-8")
    except Exception:
        # Missing git, not a repository, no tags, ... -> no tag available.
        return ""


def check_dirs(dirs: List):
    """Create every directory in ``dirs`` that does not exist yet."""
    for d in dirs:
        if not os.path.exists(d):
            # exist_ok guards against another process creating it in between.
            os.makedirs(d, exist_ok=True)


def run(command,
        desc: Optional[str] = None,
        errdesc: Optional[str] = None,
        custom_env: Optional[dict] = None,
        live: Optional[bool] = True,
        shell: Optional[bool] = None):
    """Run ``command`` in a subprocess and raise ``RuntimeError`` on a non-zero exit.

    Args:
        command: Command passed to ``subprocess.run``.
        desc: Optional text printed before running.
        errdesc: Optional prefix for the error message.
        custom_env: Environment mapping for the child; ``os.environ`` when ``None``.
        live: When true, child output is not captured and ``""`` is returned.
        shell: ``shell`` argument for ``subprocess.run``; when ``None`` it is
            ``False`` on Windows and ``True`` elsewhere.

    Returns:
        ``""`` in live mode, otherwise the decoded stdout of the child.
    """
    if shell is None:
        shell = False if sys.platform == "win32" else True

    if desc is not None:
        print(desc)

    env = os.environ if custom_env is None else custom_env

    if live:
        result = subprocess.run(command, shell=shell, env=env)
        if result.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}""")

        return ""

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, env=env)

    if result.returncode != 0:
        message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else ''}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else ''}
"""
        raise RuntimeError(message)

    return result.stdout.decode(encoding="utf8", errors="ignore")


def _version_satisfies(installed: str, required: str, at_least: bool) -> bool:
    """Compare two version strings as versions rather than as plain text.

    With ``at_least`` true the check is ``installed >= required``, otherwise
    ``installed == required``. Falls back to raw string comparison when either
    value cannot be parsed as a version.
    """
    try:
        have = pkg_resources.parse_version(installed)
        want = pkg_resources.parse_version(required)
    except (ValueError, TypeError):
        have, want = installed, required
    return have >= want if at_least else have == want


def is_installed(package, friendly: str = None):
    """Return whether every requirement named in ``package`` is installed.

    ``package`` is a pip-style argument string (options starting with ``-``
    are ignored, URL requirements are reduced to their last path segment).
    When ``friendly`` is given, its whitespace-separated names are checked
    instead. ``name>=ver`` and ``name==ver`` constraints are honoured.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #

    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package = re.sub(r'\[.*?\]', '', package)

    try:
        if friendly:
            pkgs = friendly.split()
        else:
            pkgs = [p for p in package.split() if not p.startswith('-') and not p.startswith('=')]
            pkgs = [p.split('/')[-1] for p in pkgs]  # get only package name if installing from URL

        installed_by_key = pkg_resources.working_set.by_key

        for pkg in pkgs:
            at_least = '>=' in pkg
            if at_least:
                pkg_name, pkg_version = [x.strip() for x in pkg.split('>=', 1)]
            elif '==' in pkg:
                pkg_name, pkg_version = [x.strip() for x in pkg.split('==', 1)]
            else:
                pkg_name, pkg_version = pkg.strip(), None

            # Try the name as written first, then progressively normalised keys.
            spec = None
            for candidate in (pkg_name,
                              pkg_name.lower(),
                              pkg_name.replace('_', '-'),
                              pkg_name.lower().replace('_', '-')):
                spec = installed_by_key.get(candidate, None)
                if spec is not None:
                    break

            if spec is None:
                log.warning(f'Package version not found: {pkg_name}')
                return False

            # Read the version from the distribution we actually found; looking
            # it up again by the un-normalised name could fail.
            version = spec.version
            # log.debug(f'Package version found: {pkg_name} {version}')
            if pkg_version is not None and not _version_satisfies(version, pkg_version, at_least):
                log.info(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
                return False

        return True
    except ModuleNotFoundError:
        log.warning(f'Package not installed: {pkgs}')
        return False


def validate_requirements(requirements_file: str):
    """Install every requirement from ``requirements_file`` that is not satisfied.

    Blank lines, comment lines, lines containing ``# skip_verify`` and option
    lines other than ``--index-url`` are ignored. An ``--index-url`` line
    applies to all requirements that follow it.
    """
    index_url_prefix = "--index-url "

    with open(requirements_file, 'r', encoding='utf8') as f:
        # Strip before filtering so indented comments/options are recognised too.
        stripped_lines = [raw.strip() for raw in f]

    lines = [
        line for line in stripped_lines
        if line != ''
        and not line.startswith("#")
        and not (line.startswith("-") and not line.startswith(index_url_prefix))
        and "# skip_verify" not in line
    ]

    index_url = ""

    for line in lines:
        if line.startswith(index_url_prefix):
            index_url = line.replace(index_url_prefix, "")
            continue

        if not is_installed(line):
            if index_url != "":
                run_pip(f"install {line} --index-url {index_url}", line, live=True)
            else:
                run_pip(f"install {line}", line, live=True)


def setup_windows_bitsandbytes():
    """On Windows, reinstall the pinned bitsandbytes build when it is missing or lacks CUDA DLLs."""
    if sys.platform != "win32":
        return

    # bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple")
    bnb_package = "bitsandbytes==0.44.0"
    bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")

    installed_bnb = is_installed(bnb_package)
    # The package directory does not exist when bitsandbytes was never
    # installed; treat that as "no CUDA libraries" instead of crashing.
    bnb_cuda_setup = os.path.isdir(bnb_path) and any(
        re.findall(r"libbitsandbytes_cuda.+?\.dll", f) for f in os.listdir(bnb_path)
    )

    if not installed_bnb or not bnb_cuda_setup:
        log.error("detected wrong install of bitsandbytes, reinstall it")
        run_pip(f"uninstall bitsandbytes -y", "bitsandbytes", live=True)
        run_pip(f"install {bnb_package}", bnb_package, live=True)


def setup_onnxruntime():
    """Install the onnxruntime / onnxruntime-gpu version matching torch and glibc.

    The version can be overridden with the ``ONNXRUNTIME_VERSION`` environment
    variable.
    """
    onnx_version = "1.18.1"
    index_url = None

    try:
        import torch
        torch_version = torch.__version__
        if "cu12" in torch_version:
            # for cuda 12
            onnx_version = f"1.18.1"
            index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
    except ImportError:
        log.error("torch not found")

    if sys.platform == "linux":
        libc_ver = platform.libc_ver()
        # Compare as versions: a plain string compare ranks "2.5" above "2.27".
        if libc_ver[0] == "glibc" and _version_satisfies("2.27", libc_ver[1], at_least=True):
            onnx_version = "1.16.3"

    onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)

    if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
        log.info("uninstalling wrong onnxruntime version")
        # run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
        run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True)
        run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True)

        log.info(f"installing onnxruntime")
        run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True)

        if index_url:
            run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", f"onnxruntime-gpu", live=True)
        else:
            run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True)


def run_pip(command, desc=None, live=False):
    """Run ``pip <command>`` with the current interpreter via :func:`run`."""
    return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)


def check_run(file: str) -> bool:
    """Run ``file`` with the current interpreter, log its stdout, return whether it exited with 0."""
    result = subprocess.run([python_bin, file], capture_output=True, shell=False)
    log.info(result.stdout.decode("utf-8").strip())
    return result.returncode == 0


def prepare_environment(disable_auto_mirror: bool = True):
    """Set process environment variables and make sure submodules, directories and requirements are in place.

    Args:
        disable_auto_mirror: When false and the default locale is ``zh_CN``,
            pip and Hugging Face mirror endpoints are configured (without
            overriding values already present in the environment).
    """
    if sys.platform == "win32":
        # disable triton on windows
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
    os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"

    # NOTE(review): locale.getdefaultlocale() is deprecated since Python 3.11.
    if not disable_auto_mirror and locale.getdefaultlocale()[0] == "zh_CN":
        log.info("detected locale zh_CN, use pip & huggingface mirrors")
        os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html")
        os.environ.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple")
        os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

    if not os.environ.get("PATH"):
        os.environ["PATH"] = os.path.dirname(sys.executable)

    prepare_submodules()

    check_dirs(["config/autosave", "logs"])

    # if not check_run("mikazuki/scripts/torch_check.py"):
    #     sys.exit(1)

    validate_requirements("requirements.txt")
    setup_windows_bitsandbytes()
    # if not skip_prepare_onnxruntime:
    #     setup_onnxruntime()


def catch_exception(f):
    """Decorator that logs any ``Exception`` raised by ``f`` and returns ``None`` instead of raising."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            log.error(f"An error occurred: {e}")
    return wrapper


def check_port_avaliable(port: int):
    """Return ``True`` if a socket can be bound to ``127.0.0.1:port``, else ``False``."""
    try:
        # The context manager closes the socket even when bind() fails.
        with socket.socket() as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("127.0.0.1", port))
        return True
    except (OSError, OverflowError):
        # OSError: port in use / not permitted; OverflowError: port outside 0-65535.
        return False


def find_avaliable_ports(port_init: int, port_range: int):
    """Return the first bindable port in ``range(port_init, port_range)``.

    ``port_range`` is used as the exclusive upper bound of the scan, not as a
    count. Returns ``None`` (after logging an error) when no port is free.
    """
    server_ports = range(port_init, port_range)

    for p in server_ports:
        if check_port_avaliable(p):
            return p

    log.error(f"error finding avaliable ports in range: {port_init} -> {port_range}")
    return None