from __future__ import annotations from functools import partial import os import re import sys import types import logging import warnings import urllib3 from modules import timer, errors from modules.logger import log initialized = False errors.install() logging.getLogger("DeepSpeed").disabled = True timer.startup.record("loader") log.debug('Initializing: libraries') def report(msg: str, e: Exception): log.error(f'Loader: {msg} {e}') log.error('Please restart the app to fix this issue') sys.exit(1) np = None try: os.environ.setdefault('NEP50_DISABLE_WARNING', '1') import numpy as np # pylint: disable=W0611,C0411 import numpy.random # pylint: disable=W0611,C0411 # this causes failure if numpy version changed def obj2sctype(obj): return np.dtype(obj).type if np.__version__.startswith('2.'): # monkeypatch for np==1.2 compatibility np.obj2sctype = obj2sctype # noqa: NPY201 np.bool8 = np.bool np.float_ = np.float64 # noqa: NPY201 def dummy_npwarn_decorator_factory(): def npwarn_decorator(x): return x return npwarn_decorator np._no_nep50_warning = getattr(np, '_no_nep50_warning', dummy_npwarn_decorator_factory) # pylint: disable=protected-access else: log.warning(f'Loader: numpy=={np.__version__} unsupported') except Exception as e: report(f'numpy=={np.__version__ if np is not None else None}', e) timer.startup.record("numpy") scipy = None try: import scipy # pylint: disable=W0611,C0411 except Exception as e: report(f'scipy=={scipy.__version__ if scipy is not None else None}', e) timer.startup.record("scipy") try: import atexit import torch._inductor.async_compile as ac atexit.unregister(ac.shutdown_compile_workers) except Exception: pass import torch # pylint: disable=C0411 if torch.__version__.startswith('2.5.0'): log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}') torch.backends.cuda.enable_cudnn_sdp(False) try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error,unused-import log.debug(f'Load IPEX=={ipex.__version__}') except Exception: pass try: import torch.distributed.distributed_c10d as _c10d # pylint: disable=unused-import,ungrouped-imports except Exception: log.warning('Loader: torch is not built with distributed support') try: import math cores = os.cpu_count() affinity = len(os.sched_getaffinity(0)) # pylint: disable=no-member threads = torch.get_num_threads() if threads < (affinity / 2): torch.set_num_threads(math.floor(affinity / 2)) threads = torch.get_num_threads() log.debug(f'System: cores={cores} affinity={affinity} threads={threads}') except Exception: pass urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") torchvision = None try: import torchvision # pylint: disable=W0611,C0411 import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411 except Exception as e: report(f'torchvision=={torchvision.__version__ if torchvision is not None else None}', e) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("pytorch_lightning").disabled = True warnings.filterwarnings(action="ignore", category=DeprecationWarning) warnings.filterwarnings(action="ignore", category=FutureWarning) warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") warnings.filterwarnings(action="ignore", message="numpy.dtype size changed") try: import torch._logging # pylint: disable=ungrouped-imports _compile_debug = os.environ.get('SD_COMPILE_DEBUG', None) is not None if _compile_debug: torch._logging._internal.DEFAULT_LOG_LEVEL = logging.ERROR # pylint: disable=protected-access torch._logging.set_logs(dynamo=logging.WARNING, aot=logging.WARNING, inductor=logging.WARNING) # pylint: disable=protected-access else: torch._logging._internal.DEFAULT_LOG_LEVEL = logging.ERROR # pylint: disable=protected-access torch._logging.set_logs(all=logging.ERROR, bytecode=False, aot_graphs=False, aot_joint_graph=False, ddp_graphs=False, graph=False, graph_code=False, graph_breaks=False, graph_sizes=False, guards=False, recompiles=False, recompiles_verbose=False, trace_source=False, trace_call=False, trace_bytecode=False, output_code=False, kernel_code=False, schedule=False, perf_hints=False, post_grad_graphs=False, onnx_diagnostics=False, fusion=False, overlap=False, export=None, modules=None, cudagraphs=False, sym_node=False, compiled_autograd_verbose=False) # pylint: disable=protected-access import torch._dynamo torch._dynamo.config.verbose = _compile_debug # pylint: disable=protected-access torch._dynamo.config.suppress_errors = not _compile_debug # pylint: disable=protected-access except Exception as e: log.warning(f'Torch logging: {e}') if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) timer.startup.record("torch") try: import bitsandbytes # pylint: disable=unused-import _bnb = True except Exception: _bnb = False timer.startup.record("bnb") huggingface_hub = None try: import huggingface_hub # pylint: disable=W0611,C0411 logging.getLogger("huggingface_hub.file_download").setLevel(logging.ERROR) logging.getLogger("huggingface_hub.utils._http").setLevel(logging.ERROR) timer.startup.record("hfhub") except Exception as e: report(f'huggingface_hub=={huggingface_hub.__version__ if "huggingface_hub" in sys.modules else None}', e) timer.startup.record("hub") accelerate = None try: import accelerate # pylint: disable=W0611,C0411 except Exception as e: report(f'accelerate=={accelerate.__version__ if "accelerate" in sys.modules else None}', e) timer.startup.record("accelerate") pydantic = None try: import pydantic # pylint: disable=W0611,C0411 except Exception as e: report(f'pydantic=={pydantic.__version__ if "pydantic" in sys.modules else None}', e) timer.startup.record("pydantic") try: # transformers==5.x has different dependency stack so switching between v4 and v5 becomes very painful # this temporarily disables dependency version checks so we can use either v4 or v5 until we drop support for v4 fake_version_check = types.ModuleType("transformers.dependency_versions_check") sys.modules["transformers.dependency_versions_check"] = fake_version_check # disable transformers version checks fake_version_check.dep_version_check = lambda pkg, hint=None: None except Exception: pass transformers = None try: import transformers # pylint: disable=W0611,C0411 from transformers import logging as transformers_logging # pylint: disable=W0611,C0411 transformers_logging.set_verbosity_error() except Exception as e: report(f'transformers=={transformers.__version__ if "transformers" in sys.modules else None}', e) timer.startup.record("transformers") try: import onnxruntime # pylint: disable=W0611,C0411 onnxruntime.set_default_logger_severity(4) onnxruntime.set_default_logger_verbosity(1) onnxruntime.disable_telemetry_events() except Exception as e: log.warning(f'Torch onnxruntime: {e}') timer.startup.record("onnx") timer.startup.record("fastapi") import gradio # pylint: disable=W0611,C0411 timer.startup.record("gradio") errors.install([gradio]) # patch different progress bars import tqdm as tqdm_lib # pylint: disable=C0411 from tqdm.rich import tqdm # pylint: disable=W0611,C0411 try: logging.getLogger("diffusers.guiders").setLevel(logging.ERROR) logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR) import diffusers.utils.import_utils # pylint: disable=W0611,C0411 diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access diffusers.utils.import_utils._bitsandbytes_available = _bnb # pylint: disable=protected-access import diffusers # pylint: disable=W0611,C0411 import diffusers.loaders.single_file # pylint: disable=W0611,C0411 diffusers.loaders.single_file.logging.tqdm = partial(tqdm, unit='C') timer.startup.record("diffusers") except Exception as e: log.error(f'Loader: diffusers=={diffusers.__version__ if "diffusers" in sys.modules else None} {e}') log.error('Please restart re-run the installer') sys.exit(1) try: import pillow_jxl # pylint: disable=W0611,C0411 except Exception: pass from PIL import Image # pylint: disable=W0611,C0411 timer.startup.record("pillow") import cv2 # pylint: disable=W0611,C0411 timer.startup.record("cv2") class _tqdm_cls: def __call__(self, *args, **kwargs): bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m' return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs) class _tqdm_old(tqdm_lib.tqdm): def __init__(self, *args, **kwargs): kwargs.pop("name", None) kwargs['bar_format'] = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m' kwargs['ncols'] = 80 super().__init__(*args, **kwargs) transformers.utils.logging.tqdm = _tqdm_cls() diffusers.pipelines.pipeline_utils.logging.tqdm = _tqdm_cls() huggingface_hub._snapshot_download.hf_tqdm = _tqdm_old # pylint: disable=protected-access def get_packages(): return { "torch": getattr(torch, "__long_version__", torch.__version__), "diffusers": diffusers.__version__, "gradio": gradio.__version__, "transformers": transformers.__version__, "accelerate": accelerate.__version__, "hub": huggingface_hub.__version__, } try: import torchvision.transforms.functional_tensor # pylint: disable=unused-import, ungrouped-imports except ImportError: try: import torchvision.transforms.functional as functional sys.modules["torchvision.transforms.functional_tensor"] = functional except ImportError: pass # shrug... deprecate_diffusers = diffusers.utils.deprecation_utils.deprecate def deprecate_warn(*args, **kwargs): try: deprecate_diffusers(*args, **kwargs) except Exception as e: log.warning(f'Deprecation: {e}') diffusers.utils.deprecation_utils.deprecate = deprecate_warn diffusers.utils.deprecate = deprecate_warn class VersionString(str): # support both string and tuple for version check def __ge__(self, version): if isinstance(version, tuple): version_tuple = re.findall(r'\d+', torch.__version__.split('+')[0]) version_tuple = tuple(int(x) for x in version_tuple[:3]) return version_tuple >= version return super().__ge__(version) torch.__version__ = VersionString(torch.__version__) log.info(f'Torch: torch=={torch.__version__} torchvision=={torchvision.__version__}') log.info(f'Packages: diffusers=={diffusers.__version__} transformers=={transformers.__version__} accelerate=={accelerate.__version__} gradio=={gradio.__version__} pydantic=={pydantic.__version__} numpy=={np.__version__} cv2=={cv2.__version__}')