# automatic/modules/loader.py
from __future__ import annotations
from functools import partial
import os
import re
import sys
import types
import logging
import warnings
import urllib3
from modules import timer, errors
from modules.logger import log
initialized = False
errors.install()
logging.getLogger("DeepSpeed").disabled = True
timer.startup.record("loader")
log.debug('Initializing: libraries')
debug = os.environ.get('SD_LOAD_DEBUG')
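# when SD_LOAD_DEBUG is set, report() below also runs errors.display for a detailed traceback before exiting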
def report(msg: str, e: Exception):
    log.error(f'Loader: {msg} {e}')
    log.error('Please restart the app to fix this issue')
    if debug:
        errors.display(e, msg)
    sys.exit(1)
np = None
try:
    os.environ.setdefault('NEP50_DISABLE_WARNING', '1')
    import numpy as np # pylint: disable=W0611,C0411
    import numpy.random # pylint: disable=W0611,C0411 # this causes failure if numpy version changed
    def obj2sctype(obj):
        return np.dtype(obj).type
    if np.__version__.startswith('2.'): # monkeypatch numpy 2.x to restore aliases removed from the 1.x api
        np.obj2sctype = obj2sctype # noqa: NPY201
        np.bool8 = np.bool
        np.float_ = np.float64 # noqa: NPY201
        def dummy_npwarn_decorator_factory():
            def npwarn_decorator(x):
                return x
            return npwarn_decorator
        np._no_nep50_warning = getattr(np, '_no_nep50_warning', dummy_npwarn_decorator_factory) # pylint: disable=protected-access # keep a no-op fallback where numpy no longer provides it
    else:
        log.warning(f'Loader: numpy=={np.__version__} unsupported')
except Exception as e:
    report(f'numpy=={np.__version__ if np is not None else None}', e)
timer.startup.record("numpy")
scipy = None
try:
    import scipy # pylint: disable=W0611,C0411
except Exception as e:
    report(f'scipy=={scipy.__version__ if scipy is not None else None}', e)
timer.startup.record("scipy")
try:
    import atexit
    import torch._inductor.async_compile as ac
    atexit.unregister(ac.shutdown_compile_workers) # unregister torch inductor's atexit shutdown hook for its compile workers
except Exception:
    pass
import torch # pylint: disable=C0411
if torch.__version__.startswith('2.5.0'):
    log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}')
    torch.backends.cuda.enable_cudnn_sdp(False)
try:
    import intel_extension_for_pytorch as ipex # pylint: disable=import-error,unused-import
    log.debug(f'Load IPEX=={ipex.__version__}')
except Exception:
    pass
try:
    import torch.distributed.distributed_c10d as _c10d # pylint: disable=unused-import,ungrouped-imports
except Exception:
    log.warning('Loader: torch is not built with distributed support')
try:
    import math
    cores = os.cpu_count()
    affinity = len(os.sched_getaffinity(0)) # pylint: disable=no-member # unavailable on windows/macos, hence the broad except
    threads = torch.get_num_threads()
    if threads < (affinity / 2):
        torch.set_num_threads(math.floor(affinity / 2))
        threads = torch.get_num_threads()
    log.debug(f'System: cores={cores} affinity={affinity} threads={threads}')
except Exception:
    pass
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
torchvision = None
try:
    import torchvision # pylint: disable=W0611,C0411
    import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411
except Exception as e:
    report(f'torchvision=={torchvision.__version__ if torchvision is not None else None}', e)
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("pytorch_lightning").disabled = True
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", message="numpy.dtype size changed")
try:
    import torch._logging # pylint: disable=ungrouped-imports
    _compile_debug = os.environ.get('SD_COMPILE_DEBUG', None) is not None # set SD_COMPILE_DEBUG to surface torch.compile diagnostics
    if _compile_debug:
        torch._logging._internal.DEFAULT_LOG_LEVEL = logging.ERROR # pylint: disable=protected-access
        torch._logging.set_logs(dynamo=logging.WARNING, aot=logging.WARNING, inductor=logging.WARNING) # pylint: disable=protected-access
    else:
        torch._logging._internal.DEFAULT_LOG_LEVEL = logging.ERROR # pylint: disable=protected-access
        torch._logging.set_logs(all=logging.ERROR, bytecode=False, aot_graphs=False, aot_joint_graph=False, ddp_graphs=False, graph=False, graph_code=False, graph_breaks=False, graph_sizes=False, guards=False, recompiles=False, recompiles_verbose=False, trace_source=False, trace_call=False, trace_bytecode=False, output_code=False, kernel_code=False, schedule=False, perf_hints=False, post_grad_graphs=False, onnx_diagnostics=False, fusion=False, overlap=False, export=None, modules=None, cudagraphs=False, sym_node=False, compiled_autograd_verbose=False) # pylint: disable=protected-access
    import torch._dynamo
    torch._dynamo.config.verbose = _compile_debug # pylint: disable=protected-access
    torch._dynamo.config.suppress_errors = not _compile_debug # pylint: disable=protected-access
except Exception as e:
    log.warning(f'Torch logging: {e}')
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
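# e.g. a nightly version like '2.6.0.dev20240912' is trimmed to '2.6.0'; the original string is kept in torch.__long_version__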
timer.startup.record("torch")
try:
    import bitsandbytes # pylint: disable=unused-import
    _bnb = True
except Exception:
    _bnb = False
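# _bnb is forwarded to diffusers.utils.import_utils below so diffusers sees the same bitsandbytes availability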
timer.startup.record("bnb")
huggingface_hub = None
try:
    import huggingface_hub # pylint: disable=W0611,C0411
    logging.getLogger("huggingface_hub.file_download").setLevel(logging.ERROR)
    logging.getLogger("huggingface_hub.utils._http").setLevel(logging.ERROR)
    timer.startup.record("hfhub")
except Exception as e:
    report(f'huggingface_hub=={huggingface_hub.__version__ if "huggingface_hub" in sys.modules else None}', e)
timer.startup.record("hub")
accelerate = None
try:
    import accelerate # pylint: disable=W0611,C0411
except Exception as e:
    report(f'accelerate=={accelerate.__version__ if "accelerate" in sys.modules else None}', e)
timer.startup.record("accelerate")
pydantic = None
try:
    import pydantic # pylint: disable=W0611,C0411
except Exception as e:
    report(f'pydantic=={pydantic.__version__ if "pydantic" in sys.modules else None}', e)
timer.startup.record("pydantic")
try:
    # transformers==5.x has a different dependency stack, so switching between v4 and v5 becomes very painful
    # this temporarily disables dependency version checks so we can use either v4 or v5 until we drop support for v4
    fake_version_check = types.ModuleType("transformers.dependency_versions_check")
    fake_version_check.dep_version_check = lambda pkg, hint=None: None # no-op replacement for the real check
    sys.modules["transformers.dependency_versions_check"] = fake_version_check # registering the stub makes transformers skip its version checks on import
except Exception:
    pass
transformers = None
try:
    import transformers # pylint: disable=W0611,C0411
    from transformers import logging as transformers_logging # pylint: disable=W0611,C0411
    transformers_logging.set_verbosity_error()
except Exception as e:
    report(f'transformers=={transformers.__version__ if "transformers" in sys.modules else None}', e)
timer.startup.record("transformers")
try:
    import onnxruntime # pylint: disable=W0611,C0411
    onnxruntime.set_default_logger_severity(4) # 4 == fatal only
    onnxruntime.set_default_logger_verbosity(1)
    onnxruntime.disable_telemetry_events()
except Exception as e:
    log.warning(f'Loader: onnxruntime {e}')
timer.startup.record("onnx")
timer.startup.record("fastapi")
import gradio # pylint: disable=W0611,C0411
timer.startup.record("gradio")
errors.install([gradio])
# patch different progress bars
import tqdm as tqdm_lib # pylint: disable=C0411
from tqdm.rich import tqdm # pylint: disable=W0611,C0411
try:
    logging.getLogger("diffusers.guiders").setLevel(logging.ERROR)
    logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR)
    import diffusers.utils.import_utils # pylint: disable=W0611,C0411
    diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
    diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
    diffusers.utils.import_utils._bitsandbytes_available = _bnb # pylint: disable=protected-access
    import diffusers # pylint: disable=W0611,C0411
    import diffusers.loaders.single_file # pylint: disable=W0611,C0411
    diffusers.loaders.single_file.logging.tqdm = partial(tqdm, unit='C')
    timer.startup.record("diffusers")
except Exception as e:
    log.error(f'Loader: diffusers=={diffusers.__version__ if "diffusers" in sys.modules else None} {e}')
    log.error('Please re-run the installer')
    sys.exit(1)
try:
    import pillow_jxl # pylint: disable=W0611,C0411 # registers jpeg-xl support with pillow if available
except Exception:
    pass
from PIL import Image # pylint: disable=W0611,C0411
timer.startup.record("pillow")
import cv2 # pylint: disable=W0611,C0411
timer.startup.record("cv2")
class _tqdm_cls:
    def __call__(self, *args, **kwargs):
        bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m'
        return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs)

class _tqdm_old(tqdm_lib.tqdm):
    def __init__(self, *args, **kwargs):
        kwargs.pop("name", None)
        kwargs['bar_format'] = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m'
        kwargs['ncols'] = 80
        super().__init__(*args, **kwargs)
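# route transformers and diffusers progress bars through _tqdm_cls, and huggingface_hub snapshot downloads through _tqdm_old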
try:
    transformers.utils.logging.tqdm = _tqdm_cls()
    diffusers.pipelines.pipeline_utils.logging.tqdm = _tqdm_cls()
    huggingface_hub._snapshot_download.hf_tqdm = _tqdm_old # pylint: disable=protected-access
except Exception:
    pass
def get_packages():
    return {
        "torch": getattr(torch, "__long_version__", torch.__version__),
        "diffusers": diffusers.__version__,
        "gradio": gradio.__version__,
        "transformers": transformers.__version__,
        "accelerate": accelerate.__version__,
        "hub": huggingface_hub.__version__,
    }
try:
    import torchvision.transforms.functional_tensor # pylint: disable=unused-import, ungrouped-imports
except ImportError:
    try: # newer torchvision releases removed functional_tensor, so alias the public functional module for older dependencies
        import torchvision.transforms.functional as functional
        sys.modules["torchvision.transforms.functional_tensor"] = functional
    except ImportError:
        pass # shrug...
deprecate_diffusers = diffusers.utils.deprecation_utils.deprecate
def deprecate_warn(*args, **kwargs): # downgrade diffusers deprecation errors to warnings
    try:
        deprecate_diffusers(*args, **kwargs)
    except Exception as e:
        log.warning(f'Deprecation: {e}')
diffusers.utils.deprecation_utils.deprecate = deprecate_warn
diffusers.utils.deprecate = deprecate_warn
class VersionString(str): # support both string and tuple comparisons for version checks
    def __ge__(self, version):
        if isinstance(version, tuple):
            version_tuple = re.findall(r'\d+', torch.__version__.split('+')[0])
            version_tuple = tuple(int(x) for x in version_tuple[:3])
            return version_tuple >= version
        return super().__ge__(version)
torch.__version__ = VersionString(torch.__version__)
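# allows both styles of version check, e.g.:
#   torch.__version__ >= '2.6'   # plain string comparison
#   torch.__version__ >= (2, 6)  # numeric tuple comparison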
log.info(f'Torch: torch=={torch.__version__} torchvision=={torchvision.__version__}')
log.info(f'Packages: diffusers=={diffusers.__version__} transformers=={transformers.__version__} accelerate=={accelerate.__version__} gradio=={gradio.__version__} pydantic=={pydantic.__version__} numpy=={np.__version__} cv2=={cv2.__version__}')