import json
import os
import hashlib
import importlib
import gc
import re
import time
import ctypes
import ctypes.util
import inspect
from typing import Optional


_PIPELINE_CACHE = {}
_PROMPT_EMBED_CACHE = {}
_MAX_PROMPT_CACHE_ITEMS = 32
_TRANSFORMER_MAPPING_DECISION_CACHE = {}
_MAX_TRANSFORMER_MAPPING_DECISIONS = 32
_TRANSFORMER_MAPPING_CACHE_VERSION = "v1"
_PERSISTENT_TRANSFORMER_CACHE_VERSION = "v1"
_ENV_WARNING_ONCE = set()
_LAST_ZIMAGE_SOURCE_SIGNATURE = None
_TOKENIZER_JSON_SHA256 = {
    "Tongyi-MAI/Z-Image-Turbo": "aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4",
}
ZIMAGE_COMPONENT_AUTO = "Auto (use model default)"
_ZIMAGE_GRANULAR_COMPONENT_NAMES = ("text_encoder", "text_encoder_2", "transformer", "vae")


def _project_root() -> str:
    return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


def _pipeline_cache_key(
    source_kind: str,
    source_path: str,
    text_encoder_override: Optional[str] = None,
    vae_override: Optional[str] = None,
) -> str:
    te = os.path.abspath(text_encoder_override) if text_encoder_override else "-"
    vae = os.path.abspath(vae_override) if vae_override else "-"
    return f"{source_kind}:{os.path.abspath(source_path)}:te={te}:vae={vae}"


def _single_file_identity(path: str) -> str:
    abspath = os.path.abspath(path)
    try:
        st = os.stat(abspath)
        return f"{abspath}:{int(st.st_size)}:{int(st.st_mtime_ns)}"
    except OSError:
        return f"{abspath}:missing"


def _keys_signature(keys: set[str]) -> str:
    hasher = hashlib.sha1()
    for key in sorted(keys):
        hasher.update(key.encode("utf-8", errors="ignore"))
        hasher.update(b"\0")
    return f"{len(keys)}:{hasher.hexdigest()}"


def _transformer_mapping_cache_key(single_file_path: str, model_keys: set[str]) -> str:
    return "|".join(
        (
            _TRANSFORMER_MAPPING_CACHE_VERSION,
            _single_file_identity(single_file_path),
            _keys_signature(model_keys),
        )
    )


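# The mapping-decision cache below behaves like a tiny LRU: lookups re-insert the
# entry to keep it warm in insertion order, and puts evict the oldest insertion once
# _MAX_TRANSFORMER_MAPPING_DECISIONS entries are exceeded.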
def _mapping_cache_get(cache_key: str):
    entry = _TRANSFORMER_MAPPING_DECISION_CACHE.pop(cache_key, None)
    if entry is not None:
        # Keep recently used entries warm in insertion order.
        _TRANSFORMER_MAPPING_DECISION_CACHE[cache_key] = entry
    return entry


def _mapping_cache_put(cache_key: str, value: dict) -> None:
    if cache_key in _TRANSFORMER_MAPPING_DECISION_CACHE:
        _TRANSFORMER_MAPPING_DECISION_CACHE.pop(cache_key, None)
    _TRANSFORMER_MAPPING_DECISION_CACHE[cache_key] = value
    while len(_TRANSFORMER_MAPPING_DECISION_CACHE) > _MAX_TRANSFORMER_MAPPING_DECISIONS:
        oldest = next(iter(_TRANSFORMER_MAPPING_DECISION_CACHE))
        _TRANSFORMER_MAPPING_DECISION_CACHE.pop(oldest, None)


def _zimage_persist_converted_cache_enabled() -> bool:
    return _truthy_env("FOOOCUS_ZIMAGE_PERSIST_CONVERTED_CACHE", "1")


def _zimage_persist_converted_cache_dir() -> str:
    raw = os.environ.get("FOOOCUS_ZIMAGE_PERSIST_CONVERTED_CACHE_DIR", "").strip()
    if raw:
        return os.path.abspath(os.path.expanduser(raw))
    return os.path.join(os.path.expanduser("~"), ".cache", "fooocuspocus", "zimage", "transformer_converted")


def _zimage_persist_converted_max_items() -> int:
    raw = os.environ.get("FOOOCUS_ZIMAGE_PERSIST_CONVERTED_MAX_ITEMS", "").strip()
    if raw == "":
        return 2
    try:
        value = int(raw)
        if value < 1:
            raise ValueError()
        return min(value, 64)
    except Exception:
        _warn_once_env(
            "FOOOCUS_ZIMAGE_PERSIST_CONVERTED_MAX_ITEMS",
            f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_PERSIST_CONVERTED_MAX_ITEMS='{raw}'.",
        )
        return 2


def _persistent_transformer_cache_paths(single_file_path: str) -> tuple[str, str]:
    source_id = _single_file_identity(single_file_path)
    key = f"{_PERSISTENT_TRANSFORMER_CACHE_VERSION}|{source_id}"
    digest = hashlib.sha1(key.encode("utf-8", errors="ignore")).hexdigest()
    cache_dir = _zimage_persist_converted_cache_dir()
    return (
        os.path.join(cache_dir, f"{digest}.safetensors"),
        os.path.join(cache_dir, f"{digest}.json"),
    )


def _cleanup_persisted_transformer_cache(cache_dir: str) -> None:
    max_items = _zimage_persist_converted_max_items()
    try:
        names = [n for n in os.listdir(cache_dir) if n.endswith(".safetensors")]
        files = [os.path.join(cache_dir, n) for n in names]
        files = [p for p in files if os.path.isfile(p)]
        if len(files) <= max_items:
            return
        files.sort(key=lambda p: os.path.getmtime(p), reverse=True)
        for stale in files[max_items:]:
            meta = os.path.splitext(stale)[0] + ".json"
            try:
                os.remove(stale)
            except OSError:
                pass
            try:
                if os.path.isfile(meta):
                    os.remove(meta)
            except OSError:
                pass
    except Exception:
        pass


def _load_persisted_converted_transformer(single_file_path: str) -> Optional[dict]:
    if not _zimage_persist_converted_cache_enabled():
        return None
    tensor_path, meta_path = _persistent_transformer_cache_paths(single_file_path)
    if not os.path.isfile(tensor_path):
        return None

    try:
        from safetensors.torch import load_file as safetensors_load_file

        state_dict = safetensors_load_file(tensor_path, device="cpu")
        if not isinstance(state_dict, dict) or len(state_dict) == 0:
            return None
        now = time.time()
        try:
            os.utime(tensor_path, (now, now))
            if os.path.isfile(meta_path):
                os.utime(meta_path, (now, now))
        except OSError:
            pass
        print(
            f"[Z-Image POC] Loaded persisted converted transformer cache "
            f"({len(state_dict)} tensors)."
        )
        return state_dict
    except Exception as e:
        _warn_once_env(
            "FOOOCUS_ZIMAGE_PERSIST_CONVERTED_CACHE",
            f"[Z-Image POC] Failed to load persisted converted transformer cache: {e}",
        )
        try:
            os.remove(tensor_path)
        except OSError:
            pass
        try:
            if os.path.isfile(meta_path):
                os.remove(meta_path)
        except OSError:
            pass
        return None


def _save_persisted_converted_transformer(single_file_path: str, state_dict: dict) -> None:
    if not _zimage_persist_converted_cache_enabled() or not state_dict:
        return
    tensor_path, meta_path = _persistent_transformer_cache_paths(single_file_path)
    if os.path.isfile(tensor_path):
        return

    cache_dir = os.path.dirname(tensor_path)
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except OSError as e:
        _warn_once_env(
            "FOOOCUS_ZIMAGE_PERSIST_CONVERTED_CACHE",
            f"[Z-Image POC] Failed to prepare persistent cache dir '{cache_dir}': {e}",
        )
        return

    tensor_tmp = f"{tensor_path}.tmp.{os.getpid()}.{int(time.time() * 1000)}"
    meta_tmp = f"{meta_path}.tmp.{os.getpid()}.{int(time.time() * 1000)}"
    try:
        from safetensors.torch import save_file as safetensors_save_file

        safetensors_save_file(state_dict, tensor_tmp)
        os.replace(tensor_tmp, tensor_path)

        meta = {
            "version": _PERSISTENT_TRANSFORMER_CACHE_VERSION,
            "source_identity": _single_file_identity(single_file_path),
            "created_at_unix": int(time.time()),
            "tensor_count": len(state_dict),
        }
        with open(meta_tmp, "w", encoding="utf-8") as f:
            json.dump(meta, f, ensure_ascii=True)
        os.replace(meta_tmp, meta_path)
        _cleanup_persisted_transformer_cache(cache_dir)
        print(
            f"[Z-Image POC] Saved persisted converted transformer cache "
            f"({len(state_dict)} tensors)."
        )
    except Exception as e:
        _warn_once_env(
            "FOOOCUS_ZIMAGE_PERSIST_CONVERTED_CACHE",
            f"[Z-Image POC] Failed to save persisted converted transformer cache: {e}",
        )
        try:
            if os.path.isfile(tensor_tmp):
                os.remove(tensor_tmp)
        except OSError:
            pass
        try:
            if os.path.isfile(meta_tmp):
                os.remove(meta_tmp)
        except OSError:
            pass


def _clear_prompt_cache_for_pipeline(cache_key: str) -> None:
    stale = [k for k in _PROMPT_EMBED_CACHE.keys() if isinstance(k, tuple) and k and k[0] == cache_key]
    for k in stale:
        _PROMPT_EMBED_CACHE.pop(k, None)


def _put_prompt_cache(key: tuple, value: tuple) -> None:
    if key in _PROMPT_EMBED_CACHE:
        _PROMPT_EMBED_CACHE.pop(key, None)
    _PROMPT_EMBED_CACHE[key] = value
    while len(_PROMPT_EMBED_CACHE) > _MAX_PROMPT_CACHE_ITEMS:
        first = next(iter(_PROMPT_EMBED_CACHE))
        _PROMPT_EMBED_CACHE.pop(first, None)


def _drop_cache_entry(cache_key: str) -> None:
    _PIPELINE_CACHE.pop(cache_key, None)
    _clear_prompt_cache_for_pipeline(cache_key)


def _cleanup_memory(cuda: bool = True, aggressive: bool = True) -> None:
    if aggressive:
        gc.collect()
        _trim_process_heap()
    if not cuda:
        return
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive and hasattr(torch.cuda, "ipc_collect"):
                torch.cuda.ipc_collect()
    except Exception:
        pass


def _trim_process_heap() -> None:
    # Release glibc free pages back to OS to reduce RSS spikes on model switches.
    if os.name != "posix":
        return
    try:
        libc_name = ctypes.util.find_library("c") or "libc.so.6"
        libc = ctypes.CDLL(libc_name)
        trim = getattr(libc, "malloc_trim", None)
        if callable(trim):
            trim(0)
    except Exception:
        pass


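# Illustrative usage (not invoked here): stats = clear_runtime_caches(flush_cuda=True, aggressive=True)
# returns {"pipelines": <released pipeline count>, "prompt_cache_entries": <prompt cache entries dropped>}.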
def clear_runtime_caches(flush_cuda: bool = True, aggressive: bool = True) -> dict:
    released_pipelines = 0
    prompt_cache_entries = len(_PROMPT_EMBED_CACHE)
    move_to_cpu_on_clear = _truthy_env("FOOOCUS_ZIMAGE_CACHE_CLEAR_MOVE_TO_CPU", "0")

    for _, cached in list(_PIPELINE_CACHE.items()):
        if not isinstance(cached, tuple) or len(cached) < 1:
            continue
        pipeline = cached[0]
        released_pipelines += 1
        try:
            _disable_granular_component_offload(pipeline, target_device="cpu")
        except Exception:
            pass
        try:
            if hasattr(pipeline, "maybe_free_model_hooks"):
                pipeline.maybe_free_model_hooks()
        except Exception:
            pass
        try:
            remove_all_hooks = getattr(pipeline, "remove_all_hooks", None)
            if callable(remove_all_hooks):
                remove_all_hooks()
        except Exception:
            pass
        # Break strong references from pipeline objects to large modules early.
        for attr in (
            "transformer",
            "text_encoder",
            "text_encoder_2",
            "vae",
            "tokenizer",
            "tokenizer_2",
            "scheduler",
        ):
            try:
                if hasattr(pipeline, attr):
                    setattr(pipeline, attr, None)
            except Exception:
                pass
        # Keep this opt-in: forcing `pipeline.to("cpu")` can create transient RAM spikes
        # with very large text encoders during model switches.
        if move_to_cpu_on_clear:
            try:
                if hasattr(pipeline, "to"):
                    pipeline.to("cpu")
            except Exception:
                # Some offload/meta-backed modules cannot be moved directly; hook cleanup above is enough.
                pass
        try:
            del pipeline
        except Exception:
            pass

    _PIPELINE_CACHE.clear()
    _PROMPT_EMBED_CACHE.clear()
    _cleanup_memory(cuda=flush_cuda, aggressive=aggressive)
    return {"pipelines": released_pipelines, "prompt_cache_entries": prompt_cache_entries}


def _zimage_harsh_cleanup_on_model_change_enabled() -> bool:
    return _truthy_env("FOOOCUS_ZIMAGE_HARSH_CLEANUP_ON_MODEL_CHANGE", "1")


def maybe_cleanup_for_model_change(source_kind: Optional[str], source_path: Optional[str]) -> dict:
    global _LAST_ZIMAGE_SOURCE_SIGNATURE

    if not source_kind or not source_path:
        return {"changed": False, "cleaned": False, "pipelines": 0, "prompt_cache_entries": 0}

    signature = (str(source_kind), os.path.abspath(os.path.realpath(source_path)))
    previous_signature = _LAST_ZIMAGE_SOURCE_SIGNATURE
    _LAST_ZIMAGE_SOURCE_SIGNATURE = signature

    changed = previous_signature is not None and previous_signature != signature
    if not changed:
        return {"changed": False, "cleaned": False, "pipelines": 0, "prompt_cache_entries": 0}

    if not _zimage_harsh_cleanup_on_model_change_enabled():
        return {"changed": True, "cleaned": False, "pipelines": 0, "prompt_cache_entries": 0}

    stats = clear_runtime_caches(flush_cuda=True, aggressive=True)
    prev_name = os.path.basename(previous_signature[1]) if previous_signature else "unknown"
    next_name = os.path.basename(signature[1])
    print(
        "[Z-Image POC] Model source changed; applied harsh runtime cleanup "
        f"({prev_name} -> {next_name}): "
        f"pipelines={stats.get('pipelines', 0)}, prompt_cache_entries={stats.get('prompt_cache_entries', 0)}."
    )
    return {
        "changed": True,
        "cleaned": True,
        "pipelines": int(stats.get("pipelines", 0)),
        "prompt_cache_entries": int(stats.get("prompt_cache_entries", 0)),
    }


def _cuda_mem_info_gb() -> tuple[float, float]:
    try:
        import torch

        if not torch.cuda.is_available():
            return 0.0, 0.0
        free_bytes, total_bytes = torch.cuda.mem_get_info(torch.cuda.current_device())
        return float(free_bytes) / float(1024**3), float(total_bytes) / float(1024**3)
    except Exception:
        return 0.0, 0.0


def _truthy_env(name: str, default: str = "0") -> bool:
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes", "on")


def _zimage_alt_path_enabled() -> bool:
    return _truthy_env("FOOOCUS_ZIMAGE_ALT_PATH", "1")


def zimage_active_backend_name() -> str:
    return "alternate" if _zimage_alt_path_enabled() else "legacy"


def _pipeline_accepts_kwarg(fn, kwarg_name: str, include_var_kwargs: bool = True) -> bool:
    if fn is None:
        return False
    try:
        signature = inspect.signature(fn)
    except Exception:
        return False
    if kwarg_name in signature.parameters:
        return True
    if include_var_kwargs:
        return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values())
    return False


def _pipeline_supports_latents(pipeline) -> bool:
    return _pipeline_accepts_kwarg(
        getattr(pipeline, "__call__", None),
        "latents",
        include_var_kwargs=False,
    )


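# Latent-shape helpers for the alternate path: the channel count is read from the
# transformer/unet config (falling back to 4) and the spatial size is derived from
# the pipeline's vae_scale_factor.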
def _resolve_latent_channels(pipeline) -> int:
    candidates = []
    transformer = getattr(pipeline, "transformer", None)
    if transformer is not None:
        tcfg = getattr(transformer, "config", None)
        if tcfg is not None:
            candidates.append(getattr(tcfg, "in_channels", None))
            if isinstance(tcfg, dict):
                candidates.append(tcfg.get("in_channels"))

    unet = getattr(pipeline, "unet", None)
    if unet is not None:
        ucfg = getattr(unet, "config", None)
        if ucfg is not None:
            candidates.append(getattr(ucfg, "in_channels", None))
            if isinstance(ucfg, dict):
                candidates.append(ucfg.get("in_channels"))

    for value in candidates:
        try:
            channels = int(value)
            if channels > 0:
                return channels
        except Exception:
            continue
    return 4


def _resolve_latent_spatial_size(pipeline, width: int, height: int) -> tuple[int, int]:
    scale = getattr(pipeline, "vae_scale_factor", 8)
    try:
        scale = int(scale)
    except Exception:
        raise RuntimeError(f"Invalid vae_scale_factor={scale!r} for alternate Z-Image path.")
    if scale <= 0:
        raise RuntimeError(f"Invalid vae_scale_factor={scale!r} for alternate Z-Image path.")
    if width <= 0 or height <= 0:
        raise RuntimeError(f"Invalid generation size {width}x{height} for alternate Z-Image path.")
    if (width % scale) != 0 or (height % scale) != 0:
        raise RuntimeError(
            f"Alternate Z-Image path requires size divisible by vae_scale_factor={scale}, got {width}x{height}."
        )
    latent_w = width // scale
    latent_h = height // scale
    if latent_w <= 0 or latent_h <= 0:
        raise RuntimeError(
            f"Resolved invalid latent size {latent_w}x{latent_h} from image size {width}x{height}."
        )
    return latent_h, latent_w


def _ensure_alt_path_prerequisites(pipeline, width: int, height: int) -> None:
    if not _pipeline_supports_latents(pipeline):
        raise RuntimeError(
            "Z-Image alternate path requires pipeline latents support (missing __call__(..., latents=...))."
        )
    _ = _resolve_latent_channels(pipeline)
    _ = _resolve_latent_spatial_size(pipeline, width=width, height=height)


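# Seeds are expanded into latents with per-seed CPU generators so results stay
# reproducible across devices; the batch is moved to the target device afterwards.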
def _build_latents_from_seeds(
    pipeline,
    seed_list: list[int],
    width: int,
    height: int,
    generator_device: str,
):
    import torch

    if not seed_list:
        raise RuntimeError("Alternate Z-Image path requires at least one seed.")

    latent_h, latent_w = _resolve_latent_spatial_size(pipeline, width=width, height=height)
    channels = _resolve_latent_channels(pipeline)
    transformer = getattr(pipeline, "transformer", None)
    target_dtype = getattr(transformer, "dtype", torch.float32)
    if not isinstance(target_dtype, torch.dtype):
        target_dtype = torch.float32
    if target_dtype in (torch.float16, torch.bfloat16):
        sample_dtype = target_dtype
    else:
        sample_dtype = torch.float32

    latents = []
    for value in seed_list:
        generator = torch.Generator(device="cpu").manual_seed(int(value))
        latent = torch.randn(
            (1, channels, latent_h, latent_w),
            generator=generator,
            device="cpu",
            dtype=sample_dtype,
        )
        latents.append(latent)
    return torch.cat(latents, dim=0).to(device=generator_device, dtype=sample_dtype)


def _set_generation_random_source(
    *,
    call_kwargs: dict,
    seed_list: list[int],
    pipeline,
    generator_device: str,
    use_alt_path: bool,
    latents_device: Optional[str] = None,
):
    import torch

    if use_alt_path:
        call_kwargs.pop("generator", None)
        call_kwargs["latents"] = _build_latents_from_seeds(
            pipeline=pipeline,
            seed_list=seed_list,
            width=int(call_kwargs.get("width", 0)),
            height=int(call_kwargs.get("height", 0)),
            generator_device=(latents_device or generator_device),
        )
        return None

    call_kwargs.pop("latents", None)
    if len(seed_list) <= 1:
        generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
    else:
        generator = [torch.Generator(device=generator_device).manual_seed(s) for s in seed_list]
    call_kwargs["generator"] = generator
    return generator


def _iter_safetensors_files(root: str):
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".safetensors"):
                yield os.path.join(dirpath, filename)


def _zimage_fp16_safety_from_safetensors(path: str) -> Optional[bool]:
    try:
        import torch
        from safetensors import safe_open
    except Exception:
        return None

    layer_key = re.compile(r"(?:^|\.)layers\.(\d+)\.ffn_norm1\.weight$")

    try:
        with safe_open(path, framework="pt", device="cpu") as f:
            candidates = []
            for key in f.keys():
                m = layer_key.search(key)
                if m is not None:
                    candidates.append((int(m.group(1)), key))

            if not candidates:
                return None

            candidates.sort(key=lambda x: x[0])
            # Mirror Neo's heuristic: inspect layer n_layers - 2 (fallback to max layer).
            target_idx = candidates[-1][0] - 1
            selected_key = None
            for idx, key in reversed(candidates):
                if idx == target_idx:
                    selected_key = key
                    break
            if selected_key is None:
                selected_key = candidates[-1][1]

            weight = f.get_tensor(selected_key)
    except Exception:
        return None

    try:
        std = torch.std(weight, unbiased=False).item()
    except Exception:
        return None

    return bool(std < 0.42)


def _detect_zimage_allow_fp16(source_kind: str, source_path: str) -> Optional[bool]:
    explicit = os.environ.get("FOOOCUS_ZIMAGE_ALLOW_FP16", "").strip().lower()
    if explicit in ("1", "true", "yes", "on"):
        return True
    if explicit in ("0", "false", "no", "off"):
        return False

    candidates = []
    if source_kind == "single_file" and source_path.endswith(".safetensors"):
        candidates = [source_path]
    elif source_kind == "directory":
        search_root = os.path.join(source_path, "transformer")
        if not os.path.isdir(search_root):
            search_root = source_path

        candidates = sorted(
            _iter_safetensors_files(search_root),
            key=lambda p: (
                0 if "transformer" in p.lower() else 1,
                0 if os.path.basename(p).startswith("diffusion_pytorch_model") else 1,
                -os.path.getsize(p) if os.path.isfile(p) else 0,
            ),
        )

    for candidate in candidates:
        safe = _zimage_fp16_safety_from_safetensors(candidate)
        if safe is not None:
            status = "safe" if safe else "unsafe"
            print(f"[Z-Image POC] Detected fp16 as {status} from: {os.path.basename(candidate)}")
            return safe

    return None


def _zimage_perf_profile() -> str:
    profile = os.environ.get("FOOOCUS_ZIMAGE_PERF_PROFILE", "safe").strip().lower()
    if profile not in ("safe", "balanced", "speed"):
        return "safe"
    return profile


def _zimage_compute_dtype_mode() -> str:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_COMPUTE_DTYPE", "").strip().lower()
|
|
if raw == "":
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_DTYPE", "auto").strip().lower()
|
|
|
|
aliases = {
|
|
"auto": "auto",
|
|
"bf16": "bf16",
|
|
"bfloat16": "bf16",
|
|
"fp16": "fp16",
|
|
"float16": "fp16",
|
|
"half": "fp16",
|
|
"fp32": "fp32",
|
|
"float32": "fp32",
|
|
"full": "fp32",
|
|
}
|
|
mode = aliases.get(raw, None)
|
|
if mode is None:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_COMPUTE_DTYPE",
|
|
f"[Z-Image POC] Ignoring invalid compute dtype '{raw}'. Expected: auto|bf16|fp16|fp32.",
|
|
)
|
|
return "auto"
|
|
return mode
|
|
|
|
|
|
def _zimage_strict_fp16_mode() -> bool:
|
|
return _zimage_compute_dtype_mode() == "fp16"
|
|
|
|
|
|
def _zimage_fp16_quant_accum_mode() -> str:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_COMFY_RUNTIME_FP16_ACCUM", "auto").strip().lower()
|
|
aliases = {
|
|
"auto": "auto",
|
|
"fp16": "fp16",
|
|
"float16": "fp16",
|
|
"half": "fp16",
|
|
"bf16": "bf16",
|
|
"bfloat16": "bf16",
|
|
"fp32": "fp32",
|
|
"float32": "fp32",
|
|
"full": "fp32",
|
|
}
|
|
mode = aliases.get(raw, None)
|
|
if mode is None:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_COMFY_RUNTIME_FP16_ACCUM",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_COMFY_RUNTIME_FP16_ACCUM='{raw}'. "
|
|
"Expected: auto|fp16|bf16|fp32.",
|
|
)
|
|
return "auto"
|
|
return mode
|
|
|
|
|
|
def _zimage_prewarm_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_PREWARM", "0")
|
|
|
|
|
|
def _zimage_prewarm_steps() -> int:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_PREWARM_STEPS", "").strip()
|
|
if raw == "":
|
|
return 1
|
|
try:
|
|
return max(1, min(int(raw), 8))
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_PREWARM_STEPS",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_PREWARM_STEPS='{raw}'.",
|
|
)
|
|
return 1
|
|
|
|
|
|
def _zimage_prewarm_size(default_width: int = 832, default_height: int = 1216) -> tuple[int, int]:
|
|
raw_w = os.environ.get("FOOOCUS_ZIMAGE_PREWARM_WIDTH", "").strip()
|
|
raw_h = os.environ.get("FOOOCUS_ZIMAGE_PREWARM_HEIGHT", "").strip()
|
|
width = default_width
|
|
height = default_height
|
|
if raw_w:
|
|
try:
|
|
width = max(256, int(raw_w))
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_PREWARM_WIDTH",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_PREWARM_WIDTH='{raw_w}'.",
|
|
)
|
|
if raw_h:
|
|
try:
|
|
height = max(256, int(raw_h))
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_PREWARM_HEIGHT",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_PREWARM_HEIGHT='{raw_h}'.",
|
|
)
|
|
width = int(width // 64) * 64
|
|
height = int(height // 64) * 64
|
|
width = max(width, 256)
|
|
height = max(height, 256)
|
|
return width, height
|
|
|
|
|
|
def _zimage_black_image_retry_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_BLACK_IMAGE_RETRY", "1")
|
|
|
|
|
|
def _zimage_black_image_max_value() -> int:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_BLACK_IMAGE_MAX_VALUE", "").strip()
|
|
if raw == "":
|
|
return 8
|
|
try:
|
|
value = int(raw)
|
|
if value < 0:
|
|
raise ValueError()
|
|
return min(value, 32)
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_BLACK_IMAGE_MAX_VALUE",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_BLACK_IMAGE_MAX_VALUE='{raw}'.",
|
|
)
|
|
return 8
|
|
|
|
|
|
def _zimage_black_image_mean_threshold() -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_BLACK_IMAGE_MEAN_THRESHOLD", "").strip()
|
|
if raw == "":
|
|
return 2.0
|
|
try:
|
|
value = float(raw)
|
|
if value < 0.0:
|
|
raise ValueError()
|
|
return min(value, 10.0)
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_BLACK_IMAGE_MEAN_THRESHOLD",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_BLACK_IMAGE_MEAN_THRESHOLD='{raw}'.",
|
|
)
|
|
return 2.0
|
|
|
|
|
|
def _analyze_black_image(image) -> Optional[dict]:
|
|
try:
|
|
from PIL import ImageStat
|
|
except Exception:
|
|
return None
|
|
|
|
try:
|
|
rgb = image.convert("RGB")
|
|
extrema = rgb.getextrema()
|
|
stats = ImageStat.Stat(rgb)
|
|
max_value = max(ch_max for _, ch_max in extrema)
|
|
mean_value = float(sum(stats.mean) / max(len(stats.mean), 1))
|
|
std_value = float(sum(stats.stddev) / max(len(stats.stddev), 1))
|
|
return {
|
|
"max": float(max_value),
|
|
"mean": mean_value,
|
|
"std": std_value,
|
|
}
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _is_suspected_black_image(image) -> tuple[bool, Optional[dict]]:
|
|
info = _analyze_black_image(image)
|
|
if info is None:
|
|
return False, None
|
|
max_cap = float(_zimage_black_image_max_value())
|
|
mean_cap = float(_zimage_black_image_mean_threshold())
|
|
is_black = info["max"] <= max_cap and info["mean"] <= mean_cap
|
|
return is_black, info
|
|
|
|
|
|
def _retune_runtime_quant_modules_dtype(root_module, dtype) -> int:
|
|
if root_module is None:
|
|
return 0
|
|
changed = 0
|
|
try:
|
|
modules_iter = root_module.modules()
|
|
except Exception:
|
|
modules_iter = []
|
|
for module in modules_iter:
|
|
if not hasattr(module, "compute_dtype"):
|
|
continue
|
|
if not hasattr(module, "quant_format"):
|
|
continue
|
|
old_dtype = getattr(module, "compute_dtype", None)
|
|
try:
|
|
module.compute_dtype = dtype
|
|
if old_dtype != dtype:
|
|
changed += 1
|
|
except Exception:
|
|
continue
|
|
clear_cache = getattr(module, "_clear_cache", None)
|
|
if callable(clear_cache):
|
|
try:
|
|
clear_cache()
|
|
except Exception:
|
|
pass
|
|
return changed
|
|
|
|
|
|
def _warn_once_env(key: str, message: str) -> None:
|
|
token = f"{key}:{message}"
|
|
if token in _ENV_WARNING_ONCE:
|
|
return
|
|
_ENV_WARNING_ONCE.add(token)
|
|
print(message)
|
|
|
|
|
|
def _zimage_forced_memory_mode() -> Optional[str]:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_FORCE_MEMORY_MODE", "").strip().lower()
|
|
if raw == "":
|
|
return None
|
|
|
|
aliases = {
|
|
"full_gpu": "full_gpu",
|
|
"full": "full_gpu",
|
|
"gpu": "full_gpu",
|
|
"model_offload": "model_offload",
|
|
"model": "model_offload",
|
|
"sequential_offload": "sequential_offload",
|
|
"sequential": "sequential_offload",
|
|
}
|
|
forced = aliases.get(raw, None)
|
|
if forced is None:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_FORCE_MEMORY_MODE",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_FORCE_MEMORY_MODE='{raw}'. "
|
|
"Expected one of: full_gpu, model_offload, sequential_offload.",
|
|
)
|
|
return None
|
|
|
|
return forced
|
|
|
|
|
|
def _zimage_granular_offload_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_GRANULAR_OFFLOAD", "1")
|
|
|
|
|
|
def _zimage_deep_patcher_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_DEEP_PATCHER", "1")
|
|
|
|
|
|
def _zimage_deep_patcher_min_module_mb() -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_DEEP_PATCHER_MIN_MODULE_MB", "").strip()
|
|
if raw == "":
|
|
return 0.0
|
|
try:
|
|
value = float(raw)
|
|
if value < 0.0:
|
|
raise ValueError()
|
|
return min(value, 512.0)
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_DEEP_PATCHER_MIN_MODULE_MB",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_DEEP_PATCHER_MIN_MODULE_MB='{raw}'.",
|
|
)
|
|
return 0.0
|
|
|
|
|
|
def _zimage_stage_timers_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_STAGE_TIMERS", "0")
|
|
|
|
|
|
def _zimage_allow_quality_fallback() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_ALLOW_QUALITY_FALLBACK", "0")
|
|
|
|
|
|
def _zimage_preemptive_cuda_cleanup_enabled() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_PREEMPTIVE_CUDA_CLEANUP", "1")
|
|
|
|
|
|
def _zimage_preemptive_cuda_cleanup_aggressive() -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_PREEMPTIVE_CUDA_CLEANUP_AGGRESSIVE", "0")
|
|
|
|
|
|
def _zimage_reserved_vram_gb(total_vram_gb: float = 0.0) -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_RESERVE_VRAM_GB", "").strip()
|
|
if raw:
|
|
try:
|
|
reserve = max(0.0, float(raw))
|
|
return reserve
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_RESERVE_VRAM_GB",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_RESERVE_VRAM_GB='{raw}'.",
|
|
)
|
|
|
|
# Mirror Comfy defaults.
|
|
if os.name == "nt":
|
|
reserve = 0.6
|
|
if total_vram_gb >= 15.0:
|
|
reserve += 0.1
|
|
return reserve
|
|
return 0.4
|
|
|
|
|
|
def _zimage_model_offload_min_gap_gb() -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_MODEL_OFFLOAD_MIN_GAP_GB", "").strip()
|
|
if raw:
|
|
try:
|
|
return max(0.0, float(raw))
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_MODEL_OFFLOAD_MIN_GAP_GB",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_MODEL_OFFLOAD_MIN_GAP_GB='{raw}'.",
|
|
)
|
|
# Conservative default for turbo on 10-12GB class cards.
|
|
return 1.8
|
|
|
|
|
|
def _zimage_vram_estimate_scale() -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_VRAM_ESTIMATE_SCALE", "").strip()
|
|
if raw == "":
|
|
return 1.0
|
|
try:
|
|
value = float(raw)
|
|
if value <= 0.0:
|
|
raise ValueError()
|
|
return min(value, 4.0)
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_VRAM_ESTIMATE_SCALE",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_VRAM_ESTIMATE_SCALE='{raw}'.",
|
|
)
|
|
return 1.0
|
|
|
|
|
|
def _system_ram_info_gb() -> tuple[float, float]:
|
|
# Returns (available_gb, total_gb). Best effort across environments.
|
|
try:
|
|
import psutil # type: ignore
|
|
|
|
vm = psutil.virtual_memory()
|
|
return float(vm.available) / float(1024**3), float(vm.total) / float(1024**3)
|
|
except Exception:
|
|
pass
|
|
|
|
# Linux fallback without psutil.
|
|
if os.name == "posix":
|
|
try:
|
|
total_kb = None
|
|
avail_kb = None
|
|
with open("/proc/meminfo", "r", encoding="utf-8") as f:
|
|
for line in f:
|
|
if line.startswith("MemTotal:"):
|
|
total_kb = int(line.split()[1])
|
|
elif line.startswith("MemAvailable:"):
|
|
avail_kb = int(line.split()[1])
|
|
if total_kb is not None and avail_kb is not None:
|
|
break
|
|
if total_kb is not None and avail_kb is not None:
|
|
return float(avail_kb) / float(1024**2), float(total_kb) / float(1024**2)
|
|
except Exception:
|
|
pass
|
|
|
|
return 0.0, 0.0
|
|
|
|
|
|
def _zimage_system_ram_reserve_gb() -> float:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_SYSTEM_RAM_RESERVE_GB", "").strip()
|
|
if raw == "":
|
|
# Conservative floor so offload doesn't drive desktop/system into swap thrash.
|
|
return 6.0
|
|
try:
|
|
value = float(raw)
|
|
if value < 0.0:
|
|
raise ValueError()
|
|
return min(value, 64.0)
|
|
except Exception:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_SYSTEM_RAM_RESERVE_GB",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_SYSTEM_RAM_RESERVE_GB='{raw}'.",
|
|
)
|
|
return 6.0
|
|
|
|
|
|
def _format_timing_ms(value: Optional[float]) -> str:
|
|
if value is None:
|
|
return "n/a"
|
|
return f"{value * 1000.0:.1f}ms"
|
|
|
|
|
|
def _zimage_xformers_mode() -> str:
|
|
value = os.environ.get("FOOOCUS_ZIMAGE_XFORMERS", "on").strip().lower()
|
|
if value in ("1", "true", "yes", "on", "force"):
|
|
return "on"
|
|
if value in ("0", "false", "no", "off", "disable"):
|
|
return "off"
|
|
return "auto"
|
|
|
|
|
|
def _zimage_attention_backend_mode() -> str:
|
|
value = os.environ.get("FOOOCUS_ZIMAGE_ATTN_BACKEND", "auto").strip().lower()
|
|
if value in ("", "auto", "default"):
|
|
return "auto"
|
|
if value in ("flash", "flash2", "flash-attn", "flash_attention", "flash_attention_2", "fa2"):
|
|
return "flash"
|
|
if value in ("sdpa", "torch", "torch_sdpa"):
|
|
return "sdpa"
|
|
if value in ("xformers", "xformer", "xf"):
|
|
return "xformers"
|
|
if value in ("native", "none", "off", "disable"):
|
|
return "native"
|
|
return "auto"
|
|
|
|
|
|
def _zimage_attention_backend_candidates(mode: str, allow_xformers: bool = True) -> list[str]:
|
|
if mode == "native":
|
|
return ["native"]
|
|
if mode == "sdpa":
|
|
return ["flash_attention_2", "flash_attention", "sdpa", "native"]
|
|
if mode == "flash":
|
|
return ["flash_attention_2", "flash_attention", "flash", "sdpa", "native"]
|
|
if mode == "xformers":
|
|
return ["xformers", "native"]
|
|
|
|
# auto
|
|
candidates = ["flash_attention_2", "flash_attention", "flash", "sdpa"]
|
|
if allow_xformers:
|
|
candidates.append("xformers")
|
|
candidates.append("native")
|
|
return candidates
|
|
|
|
|
|
def _is_flash_attention_backend(name: str) -> bool:
|
|
lowered = str(name).strip().lower()
|
|
return lowered.startswith("flash")
|
|
|
|
|
|
def _parse_backend_names_from_error(message: str) -> list[str]:
|
|
text = str(message or "")
|
|
m = re.search(r"must be one of the following:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
|
if not m:
|
|
return []
|
|
raw = m.group(1).strip().splitlines()[0]
|
|
items = []
|
|
for part in raw.split(","):
|
|
item = part.strip().strip("`'\"")
|
|
if item:
|
|
items.append(item)
|
|
return items
|
|
|
|
|
|
def _discover_transformer_attention_backends(transformer) -> list[str]:
|
|
attrs = (
|
|
"get_supported_attention_backends",
|
|
"get_attention_backends",
|
|
"list_attention_backends",
|
|
"available_attention_backends",
|
|
"attention_backends",
|
|
)
|
|
for attr in attrs:
|
|
obj = getattr(transformer, attr, None)
|
|
if obj is None:
|
|
continue
|
|
try:
|
|
value = obj() if callable(obj) else obj
|
|
except Exception:
|
|
continue
|
|
if isinstance(value, dict):
|
|
values = list(value.keys())
|
|
elif isinstance(value, (list, tuple, set)):
|
|
values = list(value)
|
|
else:
|
|
continue
|
|
result = []
|
|
for v in values:
|
|
s = str(v).strip()
|
|
if s:
|
|
result.append(s)
|
|
if result:
|
|
return result
|
|
return []
|
|
|
|
|
|
def _expand_attention_backend_alias(name: str) -> list[str]:
|
|
key = str(name).strip().lower()
|
|
alias_map = {
|
|
"flash_attention_2": [
|
|
"flash_attention_2",
|
|
"flash",
|
|
"flash_hub",
|
|
"flash_varlen",
|
|
"flash_varlen_hub",
|
|
"_flash_3",
|
|
"_flash_varlen_3",
|
|
"_flash_3_hub",
|
|
"_native_flash",
|
|
],
|
|
"flash_attention": [
|
|
"flash_attention",
|
|
"flash",
|
|
"flash_hub",
|
|
"flash_varlen",
|
|
"flash_varlen_hub",
|
|
"_native_flash",
|
|
],
|
|
"flash": [
|
|
"flash",
|
|
"flash_hub",
|
|
"flash_varlen",
|
|
"flash_varlen_hub",
|
|
"_flash_3",
|
|
"_flash_varlen_3",
|
|
"_flash_3_hub",
|
|
"_native_flash",
|
|
],
|
|
"sdpa": [
|
|
"sdpa",
|
|
"_native_efficient",
|
|
"_native_math",
|
|
"_native_flash",
|
|
"_native_cudnn",
|
|
"native",
|
|
],
|
|
"native": [
|
|
"native",
|
|
"_native_math",
|
|
"_native_efficient",
|
|
"_native_flash",
|
|
"_native_cudnn",
|
|
],
|
|
"xformers": ["xformers"],
|
|
}
|
|
expanded = alias_map.get(key, [key])
|
|
return list(dict.fromkeys(expanded))
|
|
|
|
|
|
def _remap_attention_backend_candidates(candidates: list[str], available: list[str]) -> list[str]:
|
|
if not available:
|
|
return candidates
|
|
available_map = {str(x).strip().lower(): str(x).strip() for x in available if str(x).strip()}
|
|
remapped = []
|
|
for candidate in candidates:
|
|
picked = None
|
|
for alias in _expand_attention_backend_alias(candidate):
|
|
if alias in available_map:
|
|
picked = available_map[alias]
|
|
break
|
|
if picked is None:
|
|
picked = candidate
|
|
remapped.append(picked)
|
|
# Keep order while deduping.
|
|
return list(dict.fromkeys(remapped))
|
|
|
|
|
|
def _round_up_to_supported_seq(value: int, max_cap: int) -> int:
|
|
buckets = [32, 64, 96, 128, 160, 192, 256, 384, 512]
|
|
cap = max(32, int(max_cap))
|
|
for bucket in buckets:
|
|
if bucket >= value and bucket <= cap:
|
|
return bucket
|
|
return cap
|
|
|
|
|
|
def _normalize_prompt_text_for_count(text) -> str:
|
|
if text is None:
|
|
return ""
|
|
if isinstance(text, str):
|
|
return text
|
|
if isinstance(text, (list, tuple)):
|
|
for item in text:
|
|
if isinstance(item, str) and item.strip():
|
|
return item
|
|
if text:
|
|
return str(text[0])
|
|
return ""
|
|
return str(text)
|
|
|
|
|
|
def _count_tokens_from_input_ids(ids) -> int:
|
|
if ids is None:
|
|
return 0
|
|
|
|
shape = getattr(ids, "shape", None)
|
|
if shape is not None:
|
|
try:
|
|
if len(shape) >= 2:
|
|
return max(0, int(shape[-1]))
|
|
if len(shape) == 1:
|
|
return max(0, int(shape[0]))
|
|
except Exception:
|
|
pass
|
|
|
|
if hasattr(ids, "tolist"):
|
|
try:
|
|
return _count_tokens_from_input_ids(ids.tolist())
|
|
except Exception:
|
|
pass
|
|
|
|
if isinstance(ids, (list, tuple)):
|
|
if not ids:
|
|
return 0
|
|
|
|
first = ids[0]
|
|
if isinstance(first, (list, tuple)):
|
|
return max(0, len(first))
|
|
|
|
nested = _count_tokens_from_input_ids(first)
|
|
if nested > 0:
|
|
return nested
|
|
return max(0, len(ids))
|
|
|
|
try:
|
|
return max(0, len(ids))
|
|
except Exception:
|
|
return 0
|
|
|
|
|
|
def _estimate_prompt_token_count(pipeline, text: str) -> int:
|
|
text = _normalize_prompt_text_for_count(text)
|
|
tokenizer = getattr(pipeline, "tokenizer", None)
|
|
if tokenizer is not None:
|
|
try:
|
|
encoded = tokenizer(
|
|
text or "",
|
|
add_special_tokens=True,
|
|
truncation=False,
|
|
return_attention_mask=False,
|
|
return_tensors=None,
|
|
)
|
|
ids = encoded.get("input_ids", []) if isinstance(encoded, dict) else []
|
|
token_count = _count_tokens_from_input_ids(ids)
|
|
if token_count > 0:
|
|
return token_count
|
|
except Exception:
|
|
pass
|
|
# Fallback heuristic if tokenizer is not available.
|
|
words = len((text or "").strip().split())
|
|
return max(1, int(words * 1.6) + 4)
|
|
|
|
|
|
def _compute_auto_max_sequence_length(
|
|
pipeline,
|
|
prompt: str,
|
|
negative_prompt: str,
|
|
use_cfg: bool,
|
|
hard_cap: int,
|
|
) -> int:
|
|
pos_tokens = _estimate_prompt_token_count(pipeline, prompt)
|
|
pos_need = max(64, pos_tokens + 24)
|
|
|
|
neg_tokens = 0
|
|
neg_need = 32
|
|
if use_cfg:
|
|
neg_tokens = _estimate_prompt_token_count(pipeline, negative_prompt)
|
|
neg_need = max(32, neg_tokens + 8)
|
|
|
|
target = max(pos_need, neg_need if use_cfg else 0)
|
|
chosen = _round_up_to_supported_seq(target, hard_cap)
|
|
print(
|
|
f"[Z-Image POC] Auto max_seq from tokens: pos={pos_tokens}, neg={neg_tokens}, "
|
|
f"use_cfg={use_cfg} -> {chosen} (cap={hard_cap})"
|
|
)
|
|
return chosen
|
|
|
|
|
|
def _universal_zimage_root() -> str:
|
|
return os.path.join(_project_root(), "models", "zimage")
|
|
|
|
|
|
def _is_valid_component_dir(path: str, component_name: str) -> bool:
|
|
if not os.path.isdir(path):
|
|
return False
|
|
if os.path.isfile(os.path.join(path, "config.json")):
|
|
return True
|
|
if component_name == "vae":
|
|
for filename in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
|
|
if os.path.isfile(os.path.join(path, filename)):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _iter_component_dirs(root: str, component_name: str):
|
|
if not root or not os.path.isdir(root):
|
|
return
|
|
direct = os.path.join(root, component_name)
|
|
if os.path.isdir(direct):
|
|
yield direct
|
|
try:
|
|
with os.scandir(root) as entries:
|
|
for entry in entries:
|
|
if not entry.is_dir():
|
|
continue
|
|
nested = os.path.join(entry.path, component_name)
|
|
if os.path.isdir(nested):
|
|
yield nested
|
|
except Exception:
|
|
return
|
|
|
|
|
|
def list_zimage_component_entries(component_name: str, checkpoint_folders: list[str]) -> list[str]:
|
|
if component_name not in ("text_encoder", "vae"):
|
|
return []
|
|
roots = [_universal_zimage_root()] + list(checkpoint_folders or [])
|
|
for folder in list(checkpoint_folders or []):
|
|
try:
|
|
parent = os.path.abspath(os.path.dirname(folder))
|
|
roots.append(os.path.join(parent, "zimage"))
|
|
except Exception:
|
|
pass
|
|
results = []
|
|
seen = set()
|
|
for root in roots:
|
|
for candidate in _iter_component_dirs(root, component_name):
|
|
normalized = os.path.abspath(candidate)
|
|
if normalized in seen:
|
|
continue
|
|
if not _is_valid_component_dir(normalized, component_name):
|
|
continue
|
|
seen.add(normalized)
|
|
results.append(normalized)
|
|
return sorted(results, key=str.casefold)
|
|
|
|
|
|
def _human_component_path(path: str, checkpoint_folders: list[str]) -> str:
|
|
path_abs = os.path.abspath(path)
|
|
roots = [_universal_zimage_root()] + list(checkpoint_folders or [])
|
|
for root in roots:
|
|
if not root:
|
|
continue
|
|
root_abs = os.path.abspath(root)
|
|
prefix = root_abs + os.sep
|
|
if path_abs.startswith(prefix):
|
|
return os.path.relpath(path_abs, root_abs)
|
|
return path_abs
|
|
|
|
|
|
def _component_weight_files(component_dir: str) -> list[str]:
|
|
files = []
|
|
try:
|
|
with os.scandir(component_dir) as entries:
|
|
for entry in entries:
|
|
if not entry.is_file():
|
|
continue
|
|
name = entry.name.lower()
|
|
if name.endswith(".safetensors") or name.endswith(".bin") or name.endswith(".pt") or name.endswith(".pth"):
|
|
files.append(entry.name)
|
|
except Exception:
|
|
return []
|
|
return sorted(files, key=str.casefold)
|
|
|
|
|
|
def _zimage_component_choice_pairs(component_name: str, checkpoint_folders: list[str]) -> list[tuple[str, str]]:
|
|
entries = list_zimage_component_entries(component_name, checkpoint_folders)
|
|
pairs = []
|
|
label_counts = {}
|
|
|
|
def _push_choice(raw_label: str, raw_value: str):
|
|
label = raw_label
|
|
if label in label_counts:
|
|
label_counts[label] += 1
|
|
label = f"{label} ({label_counts[label]})"
|
|
else:
|
|
label_counts[label] = 1
|
|
pairs.append((label, raw_value))
|
|
|
|
for entry in entries:
|
|
entry = os.path.abspath(entry)
|
|
rel = _human_component_path(entry, checkpoint_folders)
|
|
weight_files = _component_weight_files(entry)
|
|
_push_choice(f"{rel} [default]", entry)
|
|
if component_name == "vae":
|
|
# Single-file VAE overrides are often architecture-incompatible; keep UI on folder/default choices.
|
|
continue
|
|
for filename in weight_files:
|
|
_push_choice(f"{rel} :: {filename}", os.path.abspath(os.path.join(entry, filename)))
|
|
return pairs
|
|
|
|
|
|
def list_zimage_component_choices(component_name: str, checkpoint_folders: list[str]) -> list[str]:
|
|
return [label for label, _ in _zimage_component_choice_pairs(component_name, checkpoint_folders)]
|
|
|
|
|
|
def resolve_zimage_component_path(
|
|
selection: Optional[str],
|
|
component_name: str,
|
|
checkpoint_folders: list[str],
|
|
) -> Optional[str]:
|
|
selected = (selection or "").strip()
|
|
if not selected or selected == ZIMAGE_COMPONENT_AUTO:
|
|
return None
|
|
|
|
if os.path.isabs(selected):
|
|
if os.path.isfile(selected):
|
|
return os.path.abspath(selected)
|
|
return os.path.abspath(selected) if _is_valid_component_dir(selected, component_name) else None
|
|
|
|
for label, path in _zimage_component_choice_pairs(component_name, checkpoint_folders):
|
|
if selected == label:
|
|
return path
|
|
|
|
for candidate in list_zimage_component_entries(component_name, checkpoint_folders):
|
|
if candidate == selected:
|
|
return candidate
|
|
if os.path.basename(candidate) == selected:
|
|
return candidate
|
|
parent = os.path.basename(os.path.dirname(candidate))
|
|
if selected == f"{parent}/{component_name}" or selected == f"{parent}\\{component_name}":
|
|
return candidate
|
|
return None
|
|
|
|
|
|
def _forced_zimage_flavor() -> Optional[str]:
|
|
raw = os.environ.get("FOOOCUS_ZIMAGE_FLAVOR", "").strip().lower()
|
|
if raw in ("turbo", "standard"):
|
|
return raw
|
|
if raw:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_FLAVOR",
|
|
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_FLAVOR='{raw}'. Expected 'turbo' or 'standard'.",
|
|
)
|
|
return None
|
|
|
|
|
|
def detect_zimage_flavor(name: str) -> str:
|
|
forced = _forced_zimage_flavor()
|
|
if forced is not None:
|
|
return forced
|
|
|
|
# Forge Neo behavior: ZImage checkpoints resolve to Turbo flavor.
|
|
return "turbo"
|
|
|
|
|
|
def _detect_zimage_flavor_from_source(source_kind: str, source_path: str, fallback: str = "turbo") -> str:
|
|
forced = _forced_zimage_flavor()
|
|
if forced is not None:
|
|
return forced
|
|
|
|
# Keep function signature stable, but align runtime behavior to Forge Neo:
|
|
# one ZImage class -> Tongyi-MAI/Z-Image-Turbo repo.
|
|
_ = (source_kind, source_path, fallback)
|
|
return "turbo"
|
|
|
|
|
|
def _repo_for_flavor(flavor: str) -> str:
|
|
if flavor != "turbo":
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_FLAVOR_NON_TURBO",
|
|
"[Z-Image POC] Non-turbo Z-Image flavor requested; using Turbo repo to match Forge Neo behavior.",
|
|
)
|
|
return "Tongyi-MAI/Z-Image-Turbo"
|
|
|
|
|
|
def _sha256_file(path: str) -> str:
|
|
h = hashlib.sha256()
|
|
with open(path, "rb") as f:
|
|
for chunk in iter(lambda: f.read(1024 * 1024), b""):
|
|
h.update(chunk)
|
|
return h.hexdigest()
|
|
|
|
|
|
def _read_json(path: str) -> dict:
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
return json.load(f)
|
|
|
|
|
|
def is_zimage_model_directory(path: str) -> bool:
|
|
if not os.path.isdir(path):
|
|
return False
|
|
|
|
model_index_path = os.path.join(path, "model_index.json")
|
|
if os.path.isfile(model_index_path):
|
|
try:
|
|
model_index = _read_json(model_index_path)
|
|
class_name = str(model_index.get("_class_name", ""))
|
|
if "zimage" in class_name.lower():
|
|
return True
|
|
model_entries = str(model_index.get("transformer", ""))
|
|
if "zimage" in model_entries.lower():
|
|
return True
|
|
except Exception:
|
|
pass
|
|
|
|
transformer_config_path = os.path.join(path, "transformer", "config.json")
|
|
if os.path.isfile(transformer_config_path):
|
|
try:
|
|
transformer_config = _read_json(transformer_config_path)
|
|
class_name = str(transformer_config.get("_class_name", ""))
|
|
dim = int(transformer_config.get("dim", -1))
|
|
return "zimage" in class_name.lower() or dim == 3840
|
|
except Exception:
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
def list_zimage_model_entries(checkpoint_folders: list[str]) -> list[str]:
|
|
entries = []
|
|
for folder in checkpoint_folders:
|
|
if not os.path.isdir(folder):
|
|
continue
|
|
for root, dirs, _ in os.walk(folder, topdown=True):
|
|
relative_root = os.path.relpath(root, folder)
|
|
relative_root = "" if relative_root == "." else relative_root
|
|
if is_zimage_model_directory(root):
|
|
if relative_root:
|
|
entries.append(relative_root)
|
|
dirs[:] = []
|
|
continue
|
|
return sorted(set(entries), key=str.casefold)
|
|
|
|
|
|
def resolve_zimage_model_path(name: str, checkpoint_folders: list[str]) -> Optional[str]:
|
|
if not isinstance(name, str) or not name.strip():
|
|
return None
|
|
|
|
if os.path.isabs(name) and is_zimage_model_directory(name):
|
|
return name
|
|
|
|
for folder in checkpoint_folders:
|
|
candidate = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
|
|
if is_zimage_model_directory(candidate):
|
|
return candidate
|
|
|
|
return None
|
|
|
|
|
|
def _resolve_named_path(name: str, checkpoint_folders: list[str]) -> Optional[str]:
|
|
if not isinstance(name, str) or not name.strip():
|
|
return None
|
|
|
|
if os.path.isabs(name) and os.path.exists(name):
|
|
return os.path.abspath(os.path.realpath(name))
|
|
|
|
for folder in checkpoint_folders:
|
|
candidate = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
|
|
if os.path.exists(candidate):
|
|
return candidate
|
|
|
|
# Fallback: match by basename anywhere inside configured checkpoint folders.
|
|
target = os.path.basename(name).casefold()
|
|
for folder in checkpoint_folders:
|
|
if not os.path.isdir(folder):
|
|
continue
|
|
for root, _, files in os.walk(folder, topdown=True):
|
|
for file_name in files:
|
|
if file_name.casefold() == target:
|
|
return os.path.abspath(os.path.realpath(os.path.join(root, file_name)))
|
|
for root, dirs, _ in os.walk(folder, topdown=True):
|
|
for dir_name in dirs:
|
|
if dir_name.casefold() == target:
|
|
return os.path.abspath(os.path.realpath(os.path.join(root, dir_name)))
|
|
|
|
return None
|
|
|
|
|
|
def _is_likely_zimage_safetensors(path: str) -> bool:
|
|
if not os.path.isfile(path):
|
|
return False
|
|
if not path.lower().endswith(".safetensors"):
|
|
return False
|
|
|
|
try:
|
|
from safetensors import safe_open
|
|
|
|
with safe_open(path, framework="pt", device="cpu") as f:
|
|
keys = set(f.keys())
|
|
|
|
# Many checkpoints are saved with a transformer prefix
|
|
# (e.g. "model.diffusion_model.") while others are bare keys.
|
|
transformer_prefixes = ("model.diffusion_model.", "diffusion_model.", "transformer.")
|
|
|
|
def _strip_prefix(key: str) -> str:
|
|
for prefix in transformer_prefixes:
|
|
if key.startswith(prefix):
|
|
return key[len(prefix):]
|
|
return key
|
|
|
|
normalized_keys = {_strip_prefix(k) for k in keys}
|
|
|
|
if any(k.startswith("text_encoders.qwen3_4b.") for k in keys) or any(
|
|
k.startswith("text_encoders.qwen3_4b.") for k in normalized_keys
|
|
):
|
|
return True
|
|
|
|
cap_weight_key = None
|
|
if "cap_embedder.1.weight" in keys:
|
|
cap_weight_key = "cap_embedder.1.weight"
|
|
else:
|
|
for prefix in transformer_prefixes:
|
|
candidate = f"{prefix}cap_embedder.1.weight"
|
|
if candidate in keys:
|
|
cap_weight_key = candidate
|
|
break
|
|
if cap_weight_key is not None:
|
|
cap_shape = tuple(f.get_tensor(cap_weight_key).shape)
|
|
# Forge-style detection: Lumina2 backbone with dim=3840 is Z-Image.
|
|
if len(cap_shape) >= 1 and cap_shape[0] == 3840:
|
|
return True
|
|
|
|
has_lumina_backbone = any(k.startswith("layers.0.attention.") for k in normalized_keys)
|
|
has_refiner = any(k.startswith("context_refiner.0.attention.") for k in normalized_keys)
|
|
has_zimage_text = any(k.startswith("text_encoders.") and "qwen3" in k for k in normalized_keys)
|
|
if has_lumina_backbone and (has_refiner or has_zimage_text):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
def _single_file_has_text_encoder_weights(path: str) -> bool:
|
|
if not os.path.isfile(path) or not path.lower().endswith(".safetensors"):
|
|
return False
|
|
try:
|
|
from safetensors import safe_open
|
|
|
|
with safe_open(path, framework="pt", device="cpu") as f:
|
|
for key in f.keys():
|
|
if key.startswith("text_encoders.qwen3_4b.") or key.startswith("text_encoder."):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
return False
|
|
|
|
|
|
def _is_likely_fp8_single_file(path: str) -> bool:
|
|
if not os.path.isfile(path) or not path.lower().endswith(".safetensors"):
|
|
return False
|
|
name = os.path.basename(path).lower()
|
|
if "fp8" in name:
|
|
return True
|
|
try:
|
|
from safetensors import safe_open
|
|
|
|
with safe_open(path, framework="pt", device="cpu") as f:
|
|
keys = set(f.keys())
|
|
if "scaled_fp8" in keys or "transformer.scaled_fp8" in keys:
|
|
return True
|
|
# Common quantized-sidecar naming patterns.
|
|
if any(k.endswith(".scale") or ".fp8_" in k for k in keys):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
return False
|
|
|
|
|
|
def should_use_zimage_checkpoint(name: str, checkpoint_folders: list[str]) -> bool:
|
|
matched, _ = inspect_zimage_checkpoint_detection(name, checkpoint_folders)
|
|
return matched
|
|
|
|
|
|
def inspect_zimage_checkpoint_detection(name: str, checkpoint_folders: list[str]) -> tuple[bool, str]:
|
|
raw_name = str(name or "")
|
|
resolved = _resolve_named_path(raw_name, checkpoint_folders)
|
|
if resolved is None:
|
|
return False, "could not resolve model path from checkpoint folders"
|
|
|
|
if os.path.isdir(resolved):
|
|
matched = is_zimage_model_directory(resolved)
|
|
if matched:
|
|
return True, f"resolved directory detected as Z-Image: {resolved}"
|
|
return False, f"resolved directory is not a Z-Image model directory: {resolved}"
|
|
|
|
matched = _is_likely_zimage_safetensors(resolved)
|
|
if matched:
|
|
return True, f"resolved safetensors matched Z-Image fingerprint: {resolved}"
|
|
return False, f"resolved safetensors did not match Z-Image fingerprint: {resolved}"
|
|
|
|
|
|
def _find_local_repo_components(flavor: str, checkpoint_folders: list[str]) -> Optional[str]:
|
|
repo = _repo_for_flavor(flavor).split("/")[-1]
|
|
universal_root = _universal_zimage_root()
|
|
universal_repo_dir = os.path.join(universal_root, repo)
|
|
|
|
# Preferred universal location:
|
|
# models/zimage/<RepoName>/{text_encoder,tokenizer,vae,scheduler}
|
|
if is_zimage_model_directory(universal_repo_dir):
|
|
return universal_repo_dir
|
|
# Backward-compatible fallback:
|
|
# models/zimage/{text_encoder,tokenizer,vae,scheduler}
|
|
if is_zimage_model_directory(universal_root):
|
|
return universal_root
|
|
|
|
candidates = [repo, repo.lower(), repo.replace("-", "_").lower()]
|
|
|
|
for folder in checkpoint_folders:
|
|
for cand in candidates:
|
|
path = os.path.abspath(os.path.realpath(os.path.join(folder, cand)))
|
|
if is_zimage_model_directory(path):
|
|
return path
|
|
# Also support nested mirrors like Tongyi-MAI/Z-Image-Turbo.
|
|
nested = os.path.abspath(os.path.realpath(os.path.join(folder, "Tongyi-MAI", repo)))
|
|
if is_zimage_model_directory(nested):
|
|
return nested
|
|
|
|
return None
|
|
|
|
|
|
def _download_repo_components(repo_id: str, local_config: str, patterns: list[str], missing: list[str]) -> None:
|
|
from huggingface_hub import snapshot_download
|
|
|
|
print(f"[Z-Image POC] Downloading missing components: {', '.join(missing)}")
|
|
snapshot_download(
|
|
repo_id=repo_id,
|
|
local_dir=local_config,
|
|
local_dir_use_symlinks=False,
|
|
allow_patterns=patterns,
|
|
)
|
|
|
|
|
|
def _ensure_single_file_component_dir(
|
|
flavor: str,
|
|
checkpoint_folders: list[str],
|
|
single_file_path: Optional[str] = None,
|
|
) -> tuple[str, bool]:
|
|
"""
|
|
Prepare a local component directory for single-file Z-Image loading.
|
|
Policy:
|
|
- tokenizer + text_encoder + vae + scheduler may be auto-downloaded if missing.
|
|
- when single-file already includes text-encoder weights, prefer config-only text_encoder bootstrap.
|
|
"""
|
|
local_config = _find_local_repo_components(flavor, checkpoint_folders)
|
|
repo_id = _repo_for_flavor(flavor)
|
|
repo = repo_id.split("/")[-1]
|
|
universal_root = _universal_zimage_root()
|
|
preferred_local_config = os.path.join(universal_root, repo)
|
|
os.makedirs(preferred_local_config, exist_ok=True)
|
|
|
|
if local_config is None:
|
|
# Use universal repo folder as canonical storage even before it is complete.
|
|
local_config = preferred_local_config
|
|
|
|
need_tokenizer = not os.path.isdir(os.path.join(local_config, "tokenizer"))
|
|
need_text_encoder = not os.path.isdir(os.path.join(local_config, "text_encoder"))
|
|
need_vae = not os.path.isdir(os.path.join(local_config, "vae"))
|
|
need_scheduler = not os.path.isdir(os.path.join(local_config, "scheduler"))
|
|
|
|
tokenizer_json = os.path.join(local_config, "tokenizer", "tokenizer.json")
|
|
expected_tokenizer_sha = _TOKENIZER_JSON_SHA256.get(repo_id)
|
|
if (not need_tokenizer) and expected_tokenizer_sha and os.path.isfile(tokenizer_json):
|
|
try:
|
|
current_sha = _sha256_file(tokenizer_json)
|
|
if current_sha.lower() != expected_tokenizer_sha.lower():
|
|
print("[Z-Image POC] Local tokenizer checksum mismatch, refreshing tokenizer files.")
|
|
need_tokenizer = True
|
|
except Exception as e:
|
|
print(f"[Z-Image POC] Tokenizer checksum check failed, will refresh tokenizer files: {e}")
|
|
need_tokenizer = True
|
|
|
|
if not need_tokenizer and not need_text_encoder and not need_vae and not need_scheduler:
|
|
return local_config, False
|
|
|
|
try_config_only_text_encoder = bool(
|
|
need_text_encoder and single_file_path and _single_file_has_text_encoder_weights(single_file_path)
|
|
)
|
|
|
|
patterns = [
|
|
"model_index.json",
|
|
"transformer/config.json",
|
|
]
|
|
if need_tokenizer:
|
|
patterns.append("tokenizer/*")
|
|
if need_text_encoder:
|
|
if try_config_only_text_encoder:
|
|
patterns.append("text_encoder/*.json")
|
|
else:
|
|
patterns.append("text_encoder/*")
|
|
if need_vae:
|
|
patterns.append("vae/*")
|
|
if need_scheduler:
|
|
patterns.append("scheduler/*")
|
|
|
|
missing = []
|
|
if need_tokenizer:
|
|
missing.append("tokenizer")
|
|
if need_text_encoder:
|
|
if try_config_only_text_encoder:
|
|
missing.append("text_encoder(config)")
|
|
else:
|
|
missing.append("text_encoder")
|
|
if need_vae:
|
|
missing.append("vae")
|
|
if need_scheduler:
|
|
missing.append("scheduler")
|
|
_download_repo_components(repo_id, local_config, patterns, missing)
|
|
return local_config, try_config_only_text_encoder
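# Usage sketch for the helper above (flavor and paths are illustrative, not taken from config):
#     local_config, config_only_te = _ensure_single_file_component_dir(
#         "turbo",
#         ["/models/checkpoints"],
#         single_file_path="/models/checkpoints/z-image-turbo.safetensors",
#     )
#     # config_only_te is True when the checkpoint already bundles text-encoder weights,
#     # in which case only the text_encoder JSON configs were requested from the hub.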
|
|
|
|
|
|
def _split_single_file_state_dict(single_file_path: str, include_aux_weights: bool = False) -> dict[str, dict]:
|
|
from safetensors import safe_open
|
|
|
|
transformer_prefixes = ["model.diffusion_model.", "diffusion_model.", "transformer."]
|
|
text_prefixes = ["text_encoders.qwen3_4b.", "text_encoder.", "qwen3_4b.transformer.", "qwen3_4b."]
|
|
vae_prefixes = ["vae.", "first_stage_model."]
|
|
|
|
transformer_sd = {}
|
|
text_sd = {}
|
|
vae_sd = {}
|
|
|
|
with safe_open(single_file_path, framework="pt", device="cpu") as f:
|
|
for key in f.keys():
|
|
hit = False
|
|
for pref in transformer_prefixes:
|
|
if key.startswith(pref):
|
|
transformer_sd[key[len(pref):]] = f.get_tensor(key)
|
|
hit = True
|
|
break
|
|
if hit:
|
|
continue
|
|
|
|
if include_aux_weights:
|
|
for pref in text_prefixes:
|
|
if key.startswith(pref):
|
|
text_sd[key[len(pref):]] = f.get_tensor(key)
|
|
hit = True
|
|
break
|
|
if hit:
|
|
continue
|
|
|
|
for pref in vae_prefixes:
|
|
if key.startswith(pref):
|
|
vae_sd[key[len(pref):]] = f.get_tensor(key)
|
|
hit = True
|
|
break
|
|
if hit:
|
|
continue
|
|
|
|
# Plain diffusion-model safetensors often contain bare transformer keys.
|
|
# Keep all non-text/non-vae tensors as transformer weights.
|
|
if any(key.startswith(pref) for pref in text_prefixes):
|
|
if include_aux_weights:
|
|
text_sd[key] = f.get_tensor(key)
|
|
continue
|
|
if any(key.startswith(pref) for pref in vae_prefixes):
|
|
if include_aux_weights:
|
|
vae_sd[key] = f.get_tensor(key)
|
|
continue
|
|
transformer_sd[key] = f.get_tensor(key)
|
|
|
|
return {
|
|
"transformer": transformer_sd,
|
|
"text_encoder": text_sd,
|
|
"vae": vae_sd,
|
|
}
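# Usage sketch (path and tensor counts are illustrative): split a single-file checkpoint
# into per-component state dicts.
#     parts = _split_single_file_state_dict("/models/checkpoints/z-image-turbo.safetensors",
#                                           include_aux_weights=True)
#     print({name: len(sd) for name, sd in parts.items()})
#     # e.g. {'transformer': 600, 'text_encoder': 0, 'vae': 0} for a transformer-only export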
|
|
|
|
|
|
def _convert_z_image_transformer_checkpoint_to_diffusers(checkpoint: dict) -> dict:
|
|
# Ported from diffusers single-file converter for ZImageTransformer2DModel.
|
|
renamed = {}
|
|
for key, value in checkpoint.items():
|
|
new_key = key
|
|
new_key = new_key.replace("final_layer.", "all_final_layer.2-1.")
|
|
new_key = new_key.replace("x_embedder.", "all_x_embedder.2-1.")
|
|
new_key = new_key.replace(".attention.out.bias", ".attention.to_out.0.bias")
|
|
new_key = new_key.replace(".attention.k_norm.weight", ".attention.norm_k.weight")
|
|
new_key = new_key.replace(".attention.q_norm.weight", ".attention.norm_q.weight")
|
|
new_key = new_key.replace(".attention.out.weight", ".attention.to_out.0.weight")
|
|
renamed[new_key] = value
|
|
|
|
converted = {}
|
|
for key, value in renamed.items():
|
|
if ".attention.qkv.weight" in key:
|
|
to_q_weight, to_k_weight, to_v_weight = value.chunk(3, dim=0)
|
|
converted[key.replace(".attention.qkv.weight", ".attention.to_q.weight")] = to_q_weight
|
|
converted[key.replace(".attention.qkv.weight", ".attention.to_k.weight")] = to_k_weight
|
|
converted[key.replace(".attention.qkv.weight", ".attention.to_v.weight")] = to_v_weight
|
|
continue
|
|
if ".attention.qkv.bias" in key:
|
|
to_q_bias, to_k_bias, to_v_bias = value.chunk(3, dim=0)
|
|
converted[key.replace(".attention.qkv.bias", ".attention.to_q.bias")] = to_q_bias
|
|
converted[key.replace(".attention.qkv.bias", ".attention.to_k.bias")] = to_k_bias
|
|
converted[key.replace(".attention.qkv.bias", ".attention.to_v.bias")] = to_v_bias
|
|
continue
|
|
converted[key] = value
|
|
|
|
return converted
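# Illustration of the mapping above (key name is hypothetical): a fused attention tensor such as
#     "layers.0.attention.qkv.weight"  with shape [3 * dim, dim]
# is emitted as three diffusers-style tensors,
#     "layers.0.attention.to_q.weight", "layers.0.attention.to_k.weight", "layers.0.attention.to_v.weight",
# each one third of the fused tensor chunked along dim 0; non-attention keys only receive the
# final_layer / x_embedder renames.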
|
|
|
|
|
|
def _convert_z_image_transformer_key_to_diffusers(src_key: str) -> list[str]:
|
|
key = src_key
|
|
key = key.replace("final_layer.", "all_final_layer.2-1.")
|
|
key = key.replace("x_embedder.", "all_x_embedder.2-1.")
|
|
key = key.replace(".attention.out.bias", ".attention.to_out.0.bias")
|
|
key = key.replace(".attention.k_norm.weight", ".attention.norm_k.weight")
|
|
key = key.replace(".attention.q_norm.weight", ".attention.norm_q.weight")
|
|
key = key.replace(".attention.out.weight", ".attention.to_out.0.weight")
|
|
if ".attention.qkv.weight" in key:
|
|
return [
|
|
key.replace(".attention.qkv.weight", ".attention.to_q.weight"),
|
|
key.replace(".attention.qkv.weight", ".attention.to_k.weight"),
|
|
key.replace(".attention.qkv.weight", ".attention.to_v.weight"),
|
|
]
|
|
if ".attention.qkv.bias" in key:
|
|
return [
|
|
key.replace(".attention.qkv.bias", ".attention.to_q.bias"),
|
|
key.replace(".attention.qkv.bias", ".attention.to_k.bias"),
|
|
key.replace(".attention.qkv.bias", ".attention.to_v.bias"),
|
|
]
|
|
return [key]
|
|
|
|
|
|
def _transformer_match_score_from_keys(src_keys, model_keys: set[str]) -> int:
|
|
if not src_keys or not model_keys:
|
|
return 0
|
|
|
|
strip_prefixes = [
|
|
"model.diffusion_model.",
|
|
"diffusion_model.",
|
|
"transformer.",
|
|
"model.",
|
|
]
|
|
|
|
matched = 0
|
|
for src_key in src_keys:
|
|
if src_key in model_keys:
|
|
matched += 1
|
|
continue
|
|
for pref in strip_prefixes:
|
|
if src_key.startswith(pref) and src_key[len(pref):] in model_keys:
|
|
matched += 1
|
|
break
|
|
return matched
|
|
|
|
|
|
def _transformer_match_score(state_dict: dict, model_keys: set[str]) -> int:
|
|
return _transformer_match_score_from_keys(state_dict.keys() if state_dict else [], model_keys)
|
|
|
|
|
|
def _choose_transformer_mapping(single_file_path: str, state_dict: dict, model_keys: set[str]) -> dict:
|
|
if not state_dict or not model_keys:
|
|
return {"use_forge_mapping": False, "base_score": 0, "converted_score": 0, "cache_hit": False}
|
|
|
|
cache_key = _transformer_mapping_cache_key(single_file_path, model_keys)
|
|
cached = _mapping_cache_get(cache_key)
|
|
if cached is not None:
|
|
return {
|
|
"use_forge_mapping": bool(cached.get("use_forge_mapping", False)),
|
|
"base_score": int(cached.get("base_score", 0)),
|
|
"converted_score": int(cached.get("converted_score", 0)),
|
|
"cache_hit": True,
|
|
}
|
|
|
|
base_score = _transformer_match_score(state_dict, model_keys)
|
|
converted_keys = []
|
|
for key in state_dict.keys():
|
|
converted_keys.extend(_convert_z_image_transformer_key_to_diffusers(key))
|
|
converted_score = _transformer_match_score_from_keys(converted_keys, model_keys)
|
|
use_forge_mapping = converted_score > base_score
|
|
|
|
_mapping_cache_put(
|
|
cache_key,
|
|
{
|
|
"use_forge_mapping": use_forge_mapping,
|
|
"base_score": base_score,
|
|
"converted_score": converted_score,
|
|
},
|
|
)
|
|
return {
|
|
"use_forge_mapping": use_forge_mapping,
|
|
"base_score": base_score,
|
|
"converted_score": converted_score,
|
|
"cache_hit": False,
|
|
}
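# Decision sketch: the raw checkpoint keys and their converted forms are both scored against the
# target model's state_dict keys, and the Forge-style mapping is used only when it matches strictly
# more keys. Example return value (scores are illustrative):
#     decision = _choose_transformer_mapping(path, state_dict, model_keys)
#     # {"use_forge_mapping": True, "base_score": 12, "converted_score": 980, "cache_hit": False}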
|
|
|
|
|
|
def _maybe_convert_transformer_checkpoint(single_file_path: str, state_dict: dict, model_keys: set[str]) -> dict:
|
|
if not state_dict or not model_keys:
|
|
return state_dict
|
|
|
|
decision = _choose_transformer_mapping(single_file_path, state_dict, model_keys)
|
|
if not decision["use_forge_mapping"]:
|
|
return state_dict
|
|
|
|
persisted_sd = _load_persisted_converted_transformer(single_file_path)
|
|
if persisted_sd is not None:
|
|
return persisted_sd
|
|
|
|
if decision["cache_hit"]:
|
|
print(
|
|
f"[Z-Image POC] Reusing cached Forge-style Z-Image transformer key mapping "
|
|
f"(score {decision['base_score']}->{decision['converted_score']})."
|
|
)
|
|
else:
|
|
print(
|
|
f"[Z-Image POC] Using Forge-style Z-Image transformer key mapping "
|
|
f"(score {decision['base_score']}->{decision['converted_score']})."
|
|
)
|
|
converted_sd = _convert_z_image_transformer_checkpoint_to_diffusers(state_dict)
|
|
_save_persisted_converted_transformer(single_file_path, converted_sd)
|
|
return converted_sd
|
|
|
|
|
|
def _load_transformer_weights_from_single_file(single_file_path: str, pipeline) -> None:
|
|
parts = _split_single_file_state_dict(single_file_path, include_aux_weights=True)
|
|
transformer_sd = parts["transformer"]
|
|
transformer_component = getattr(pipeline, "transformer", None)
|
|
transformer_model_keys = (
|
|
set(transformer_component.state_dict().keys()) if transformer_component is not None else set()
|
|
)
|
|
|
|
if transformer_sd and transformer_model_keys:
|
|
try:
|
|
transformer_sd = _maybe_convert_transformer_checkpoint(
|
|
single_file_path, transformer_sd, transformer_model_keys
|
|
)
|
|
except Exception as e:
|
|
print(f"[Z-Image POC] Z-Image transformer conversion skipped: {e}")
|
|
|
|
if not transformer_sd:
|
|
raise RuntimeError(
|
|
"Single-file Z-Image checkpoint does not contain transformer weights in expected format."
|
|
)
|
|
|
|
_load_component_override_from_file(
|
|
pipeline,
|
|
"transformer",
|
|
single_file_path,
|
|
state_dict_override=transformer_sd,
|
|
source_label=f"{os.path.basename(single_file_path)}::transformer",
|
|
)
|
|
|
|
# Optional: if single-file includes text encoder / VAE weights, prefer them.
|
|
if hasattr(pipeline, "text_encoder") and pipeline.text_encoder is not None:
|
|
text_sd = parts["text_encoder"]
|
|
if text_sd:
|
|
try:
|
|
pipeline.text_encoder.load_state_dict(text_sd, strict=False)
|
|
print(f"[Z-Image POC] Loaded text_encoder weights from single-file ({len(text_sd)} tensors).")
|
|
except RuntimeError as e:
|
|
print(f"[Z-Image POC] Skipped text_encoder weights from single-file: {e}")
|
|
|
|
if hasattr(pipeline, "vae") and pipeline.vae is not None:
|
|
vae_sd = parts["vae"]
|
|
if vae_sd:
|
|
try:
|
|
pipeline.vae.load_state_dict(vae_sd, strict=False)
|
|
print(f"[Z-Image POC] Loaded VAE weights from single-file ({len(vae_sd)} tensors).")
|
|
except Exception as e:
|
|
print(f"[Z-Image POC] Skipped VAE weights from single-file: {e}")
|
|
parts["transformer"].clear()
|
|
parts["text_encoder"].clear()
|
|
parts["vae"].clear()
|
|
del parts
|
|
_cleanup_memory(cuda=False)
|
|
|
|
|
|
def _apply_component_state_dict(
|
|
component,
|
|
state_dict: dict,
|
|
label: str,
|
|
missing_limit: Optional[int] = None,
|
|
unexpected_limit: Optional[int] = None,
|
|
):
|
|
if component is None or not state_dict:
|
|
return [], []
|
|
missing, unexpected = component.load_state_dict(state_dict, strict=False)
|
|
if unexpected_limit is not None and len(unexpected) > unexpected_limit:
|
|
raise RuntimeError(f"{label} weight mismatch too large (unexpected={len(unexpected)}).")
|
|
if missing_limit is not None and len(missing) > missing_limit:
|
|
raise RuntimeError(f"{label} weight mismatch too large (missing={len(missing)}).")
|
|
if len(unexpected) > 0 or len(missing) > 0:
|
|
print(f"[Z-Image POC] {label} non-strict load: missing={len(missing)}, unexpected={len(unexpected)}")
|
|
print(f"[Z-Image POC] Loaded {label} from single-file ({len(state_dict)} tensors).")
|
|
return missing, unexpected
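# Usage sketch (limit value is illustrative): load VAE tensors non-strictly, but abort if the
# mismatch is suspiciously large.
#     _apply_component_state_dict(getattr(pipeline, "vae", None), vae_sd, label="vae", unexpected_limit=8)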
|
|
|
|
|
|
def _remap_state_dict_to_model_keys(state_dict: dict, model_keys: set[str], label: str, verbose: bool = True) -> dict:
|
|
if not state_dict:
|
|
return state_dict
|
|
|
|
strip_prefixes = [
|
|
"",
|
|
"model.diffusion_model.",
|
|
"diffusion_model.",
|
|
"transformer.",
|
|
"model.",
|
|
]
|
|
quant_side_suffixes = (
|
|
".comfy_quant",
|
|
".weight_scale",
|
|
".weight_scale_2",
|
|
".input_scale",
|
|
".scale_input",
|
|
".scale_weight",
|
|
)
|
|
|
|
def _map_primary(src_key: str) -> Optional[str]:
|
|
if src_key in model_keys:
|
|
return src_key
|
|
for pref in strip_prefixes[1:]:
|
|
if src_key.startswith(pref):
|
|
cand = src_key[len(pref):]
|
|
if cand in model_keys:
|
|
return cand
|
|
return None
|
|
|
|
remapped = {}
|
|
matched = 0
|
|
for src_key, value in state_dict.items():
|
|
dst_key = _map_primary(src_key)
|
|
if dst_key is None:
|
|
for suffix in quant_side_suffixes:
|
|
if not src_key.endswith(suffix):
|
|
continue
|
|
base_key = src_key[: -len(suffix)]
|
|
mapped_weight = _map_primary(f"{base_key}.weight")
|
|
if mapped_weight is None or not mapped_weight.endswith(".weight"):
|
|
continue
|
|
dst_key = mapped_weight[: -len(".weight")] + suffix
|
|
break
|
|
|
|
if dst_key is not None:
|
|
remapped[dst_key] = value
|
|
matched += 1
|
|
else:
|
|
remapped[src_key] = value
|
|
|
|
total = max(len(state_dict), 1)
|
|
ratio = matched / float(total)
|
|
if verbose and matched != len(state_dict):
|
|
print(f"[Z-Image POC] {label} remap coverage: matched={matched}/{len(state_dict)} ({ratio:.2%})")
|
|
return remapped
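# Illustration (keys are hypothetical): "model.diffusion_model.layers.0.mlp.weight" is remapped to
# "layers.0.mlp.weight" when the stripped key exists in the model, and a quantizer side tensor like
# "model.diffusion_model.layers.0.mlp.weight_scale" follows its primary ".weight" to
# "layers.0.mlp.weight_scale". Unmatched keys are kept unchanged and show up in the coverage log.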
|
|
|
|
|
|
def _call_with_dtype_compat(callable_obj, dtype, kwargs: dict, label: str):
|
|
errors = []
|
|
|
|
def _invoke(kwargs_try: dict):
|
|
try:
|
|
return callable_obj(**kwargs_try)
|
|
except TypeError as e:
|
|
# Some loaders reject low_cpu_mem_usage; retry once without it. dtype/torch_dtype fallbacks are handled by the outer attempts below.
|
|
if "low_cpu_mem_usage" in kwargs_try:
|
|
fallback = dict(kwargs_try)
|
|
fallback.pop("low_cpu_mem_usage", None)
|
|
return callable_obj(**fallback)
|
|
raise e
|
|
|
|
if dtype is not None:
|
|
try:
|
|
return _invoke({**kwargs, "dtype": dtype})
|
|
except TypeError as e:
|
|
errors.append(e)
|
|
except Exception:
|
|
raise
|
|
|
|
try:
|
|
return _invoke({**kwargs, "torch_dtype": dtype})
|
|
except TypeError as e:
|
|
errors.append(e)
|
|
except Exception:
|
|
raise
|
|
|
|
try:
|
|
return _invoke(kwargs)
|
|
except Exception as e:
|
|
if errors:
|
|
print(f"[Z-Image POC] {label} dtype-compat fallback after: {errors[-1]}")
|
|
raise e
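# Call-order sketch: the helper above tries `dtype=`, then `torch_dtype=`, then no dtype kwarg,
# dropping `low_cpu_mem_usage` on a TypeError at each step. Example (class and path illustrative):
#     vae = _call_with_dtype_compat(
#         AutoencoderKL.from_pretrained,
#         torch.bfloat16,
#         {"pretrained_model_name_or_path": vae_dir, "local_files_only": True, "low_cpu_mem_usage": True},
#         "vae.from_pretrained",
#     )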
|
|
|
|
|
|
def _build_pipeline_from_single_file_components(
|
|
local_config: str,
|
|
single_file_path: str,
|
|
dtype,
|
|
prefer_single_file_aux_weights: bool = False,
|
|
):
|
|
from diffusers import DiffusionPipeline
|
|
from transformers import AutoConfig, AutoModel
|
|
|
|
model_index = _read_json(os.path.join(local_config, "model_index.json"))
|
|
parts = _split_single_file_state_dict(single_file_path, include_aux_weights=prefer_single_file_aux_weights)
|
|
|
|
pipeline_class_name = str(model_index.get("_class_name", ""))
|
|
pipeline_cls = getattr(importlib.import_module("diffusers"), pipeline_class_name, None)
|
|
if pipeline_cls is None:
|
|
raise RuntimeError(
|
|
f"diffusers is missing {pipeline_class_name}. "
|
|
"Install/upgrade with: python -m pip install -U diffusers==0.36.0 transformers==4.56.2 safetensors accelerate"
|
|
)
|
|
|
|
components = {}
|
|
for component_name, spec in model_index.items():
|
|
if component_name.startswith("_"):
|
|
continue
|
|
if not (isinstance(spec, list) and len(spec) == 2):
|
|
continue
|
|
|
|
lib_name, cls_name = spec
|
|
comp_path = os.path.join(local_config, component_name)
|
|
lib = importlib.import_module(lib_name)
|
|
cls = getattr(lib, cls_name)
|
|
|
|
if component_name == "scheduler":
|
|
components[component_name] = cls.from_pretrained(comp_path, local_files_only=True)
|
|
continue
|
|
if component_name.startswith("tokenizer"):
|
|
components[component_name] = cls.from_pretrained(comp_path, local_files_only=True)
|
|
continue
|
|
|
|
want_single_file_weights = bool(parts.get(component_name))
|
|
if component_name == "text_encoder" and want_single_file_weights:
|
|
config = AutoConfig.from_pretrained(comp_path, local_files_only=True, trust_remote_code=True)
|
|
model = AutoModel.from_config(config, trust_remote_code=True)
|
|
elif want_single_file_weights:
|
|
if hasattr(cls, "load_config") and hasattr(cls, "from_config"):
|
|
model = cls.from_config(cls.load_config(comp_path))
|
|
else:
|
|
model = _call_with_dtype_compat(
|
|
cls.from_pretrained,
|
|
dtype,
|
|
{
|
|
"pretrained_model_name_or_path": comp_path,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
},
|
|
f"{component_name}.from_pretrained",
|
|
)
|
|
else:
|
|
model = _call_with_dtype_compat(
|
|
cls.from_pretrained,
|
|
dtype,
|
|
{
|
|
"pretrained_model_name_or_path": comp_path,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
},
|
|
f"{component_name}.from_pretrained",
|
|
)
|
|
|
|
if hasattr(model, "to"):
|
|
model = model.to(dtype=dtype)
|
|
components[component_name] = model
|
|
|
|
pipeline = pipeline_cls(**components)
|
|
|
|
transformer_component = getattr(pipeline, "transformer", None)
|
|
transformer_model_keys = set(transformer_component.state_dict().keys()) if transformer_component is not None else set()
|
|
raw_transformer_sd = parts["transformer"]
|
|
|
|
selected_transformer_sd = raw_transformer_sd
|
|
if raw_transformer_sd and transformer_model_keys:
|
|
try:
|
|
selected_transformer_sd = _maybe_convert_transformer_checkpoint(
|
|
single_file_path, raw_transformer_sd, transformer_model_keys
|
|
)
|
|
except Exception as e:
|
|
print(f"[Z-Image POC] Z-Image transformer conversion skipped: {e}")
|
|
|
|
_load_component_override_from_file(
|
|
pipeline,
|
|
"transformer",
|
|
single_file_path,
|
|
state_dict_override=selected_transformer_sd,
|
|
source_label=f"{os.path.basename(single_file_path)}::transformer",
|
|
)
|
|
if prefer_single_file_aux_weights:
|
|
_apply_component_state_dict(getattr(pipeline, "text_encoder", None), parts["text_encoder"], label="text_encoder")
|
|
_apply_component_state_dict(getattr(pipeline, "vae", None), parts["vae"], label="vae")
|
|
|
|
if not parts["transformer"]:
|
|
raise RuntimeError("Single-file Z-Image checkpoint does not contain transformer weights in expected format.")
|
|
|
|
if isinstance(pipeline, DiffusionPipeline):
|
|
pipeline.set_progress_bar_config(disable=True)
|
|
|
|
parts["transformer"].clear()
|
|
parts["text_encoder"].clear()
|
|
parts["vae"].clear()
|
|
del parts
|
|
_cleanup_memory(cuda=False)
|
|
return pipeline
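# Usage sketch (dtype and flags are illustrative): assemble a pipeline from the local component
# configs plus the single-file checkpoint, preferring any text-encoder/VAE tensors bundled in it.
#     pipe = _build_pipeline_from_single_file_components(
#         local_config, single_file_path, torch.bfloat16, prefer_single_file_aux_weights=True)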
|
|
|
|
|
|
def resolve_zimage_source(name: str, checkpoint_folders: list[str], auto_download_if_missing: bool = False) -> tuple[Optional[str], Optional[str], str]:
|
|
flavor = detect_zimage_flavor(name)
|
|
resolved = _resolve_named_path(name, checkpoint_folders)
|
|
|
|
if resolved is not None:
|
|
if os.path.isdir(resolved) and is_zimage_model_directory(resolved):
|
|
flavor = _detect_zimage_flavor_from_source("directory", resolved, fallback=flavor)
|
|
return "directory", resolved, flavor
|
|
if os.path.isfile(resolved):
|
|
if _is_likely_zimage_safetensors(resolved):
|
|
flavor = _detect_zimage_flavor_from_source("single_file", resolved, fallback=flavor)
|
|
return "single_file", resolved, flavor
|
|
|
|
return None, None, flavor
|
|
|
|
|
|
def _pick_device_and_dtype(zimage_allow_fp16: Optional[bool] = None):
|
|
import torch
|
|
|
|
dtype_override = _zimage_compute_dtype_mode()
|
|
allow_unsafe_fp16 = _truthy_env("FOOOCUS_ZIMAGE_ALLOW_FP16_UNSAFE", "0")
|
|
|
|
def _resolve_dtype(value: str):
|
|
if value in ("bf16", "bfloat16"):
|
|
return torch.bfloat16
|
|
if value in ("fp16", "float16", "half"):
|
|
return torch.float16
|
|
if value in ("fp32", "float32", "full"):
|
|
return torch.float32
|
|
return None
|
|
|
|
if torch.cuda.is_available():
|
|
requested = _resolve_dtype(dtype_override)
|
|
if requested is not None:
|
|
# Fall back to FP16 on GPUs that do not support BF16.
|
|
if requested == torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
|
print("[Z-Image POC] Requested BF16 but CUDA BF16 is unsupported; falling back to FP16.")
|
|
return "cuda", torch.float16
|
|
if requested == torch.float16 and zimage_allow_fp16 is False and not allow_unsafe_fp16:
|
|
print("[Z-Image POC] Requested FP16 but checkpoint appears fp16-unsafe; falling back to FP32.")
|
|
return "cuda", torch.float32
|
|
return "cuda", requested
|
|
if torch.cuda.is_bf16_supported():
|
|
return "cuda", torch.bfloat16
|
|
if zimage_allow_fp16 is False and not allow_unsafe_fp16:
|
|
print("[Z-Image POC] Checkpoint appears fp16-unsafe; using FP32.")
|
|
return "cuda", torch.float32
|
|
return "cuda", torch.float16
|
|
|
|
requested = _resolve_dtype(dtype_override)
|
|
if requested is not None:
|
|
return "cpu", requested
|
|
return "cpu", torch.float32
|
|
|
|
|
|
def _pipeline_has_meta_tensors(pipeline) -> bool:
|
|
for name in ("transformer", "text_encoder", "vae"):
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
has_offload_hook = bool(getattr(module, "_hf_hook", None))
|
|
if not has_offload_hook:
|
|
try:
|
|
has_offload_hook = any(bool(getattr(m, "_hf_hook", None)) for m in module.modules())
|
|
except Exception:
|
|
has_offload_hook = False
|
|
try:
|
|
for p in module.parameters():
|
|
if getattr(p, "is_meta", False):
|
|
# With accelerate/diffusers offload hooks, meta tensors are expected.
|
|
# Treat only unmanaged meta tensors as corrupted.
|
|
if not has_offload_hook:
|
|
return True
|
|
except Exception:
|
|
continue
|
|
return False
|
|
|
|
|
|
def _cuda_total_vram_gb() -> float:
|
|
import torch
|
|
|
|
if not torch.cuda.is_available():
|
|
return 0.0
|
|
try:
|
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
return float(props.total_memory) / float(1024**3)
|
|
except Exception:
|
|
return 0.0
|
|
|
|
|
|
def _choose_memory_mode(device: str, profile: str = "safe") -> tuple[str, float, float, float]:
|
|
import torch
|
|
|
|
if device != "cuda":
|
|
return "full_gpu", 0.0, 0.0, 0.0
|
|
|
|
total_vram_gb = _cuda_total_vram_gb()
|
|
free_vram_gb = 0.0
|
|
try:
|
|
free_bytes, total_bytes = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
free_vram_gb = float(free_bytes) / float(1024**3)
|
|
if total_vram_gb <= 0:
|
|
total_vram_gb = float(total_bytes) / float(1024**3)
|
|
except Exception:
|
|
pass
|
|
|
|
pressure = (free_vram_gb / total_vram_gb) if total_vram_gb > 0 else 0.0
|
|
|
|
forced_mode = _zimage_forced_memory_mode()
|
|
if forced_mode is not None:
|
|
return forced_mode, total_vram_gb, free_vram_gb, pressure
|
|
|
|
# Profile-based policy:
|
|
# safe: maximize stability under low/medium VRAM
|
|
# balanced: faster default on 10-16GB while keeping fallback room
|
|
# speed: prefer GPU residency/perf, accept higher OOM risk
|
|
if profile == "speed":
|
|
if total_vram_gb > 0 and total_vram_gb <= 8.0:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if total_vram_gb > 0 and total_vram_gb <= 12.0:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.10:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.25:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
return "full_gpu", total_vram_gb, free_vram_gb, pressure
|
|
|
|
if profile == "balanced":
|
|
# 11-12GB cards are still borderline for Z-Image Turbo; default to stricter offload.
|
|
if total_vram_gb > 0 and total_vram_gb <= 12.0:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if total_vram_gb > 0 and total_vram_gb <= 16.0:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.12:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.30:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
return "full_gpu", total_vram_gb, free_vram_gb, pressure
|
|
|
|
# safe profile (default)
|
|
if total_vram_gb > 0 and total_vram_gb <= 12.0:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if total_vram_gb > 0 and total_vram_gb <= 16.0:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.15:
|
|
return "sequential_offload", total_vram_gb, free_vram_gb, pressure
|
|
if pressure < 0.35:
|
|
return "model_offload", total_vram_gb, free_vram_gb, pressure
|
|
return "full_gpu", total_vram_gb, free_vram_gb, pressure
|
|
|
|
|
|
def _memory_mode_rank(mode: str) -> int:
|
|
return {"full_gpu": 0, "model_offload": 1, "sequential_offload": 2}.get(str(mode), 0)
|
|
|
|
|
|
def _stricter_memory_mode(lhs: str, rhs: str) -> str:
|
|
return lhs if _memory_mode_rank(lhs) >= _memory_mode_rank(rhs) else rhs
|
|
|
|
|
|
def _module_memory_gb(module) -> float:
|
|
if module is None:
|
|
return 0.0
|
|
total = 0
|
|
try:
|
|
for tensor in module.parameters():
|
|
total += tensor.nelement() * tensor.element_size()
|
|
except Exception:
|
|
pass
|
|
try:
|
|
for tensor in module.buffers():
|
|
total += tensor.nelement() * tensor.element_size()
|
|
except Exception:
|
|
pass
|
|
return float(total) / float(1024**3)
|
|
|
|
|
|
def _module_direct_storage_bytes(module) -> int:
|
|
total = 0
|
|
try:
|
|
for tensor in module.parameters(recurse=False):
|
|
total += tensor.nelement() * tensor.element_size()
|
|
except Exception:
|
|
pass
|
|
try:
|
|
for tensor in module.buffers(recurse=False):
|
|
total += tensor.nelement() * tensor.element_size()
|
|
except Exception:
|
|
pass
|
|
return int(total)
|
|
|
|
|
|
def _is_streamable_leaf_module(module) -> bool:
|
|
direct_bytes = _module_direct_storage_bytes(module)
|
|
if direct_bytes <= 0:
|
|
return False
|
|
|
|
try:
|
|
for name, _ in module.named_parameters(recurse=True):
|
|
if "." in name:
|
|
return False
|
|
except Exception:
|
|
pass
|
|
try:
|
|
for name, _ in module.named_buffers(recurse=True):
|
|
if "." in name:
|
|
return False
|
|
except Exception:
|
|
pass
|
|
return True
|
|
|
|
|
|
def _module_current_device(module) -> Optional[str]:
|
|
try:
|
|
for tensor in module.parameters(recurse=False):
|
|
return str(tensor.device)
|
|
except Exception:
|
|
pass
|
|
try:
|
|
for tensor in module.buffers(recurse=False):
|
|
return str(tensor.device)
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def _move_streamable_module_to(module, target_device: str) -> bool:
|
|
current = _module_current_device(module)
|
|
if current is not None and str(current) == str(target_device):
|
|
return False
|
|
module.to(target_device)
|
|
return True
|
|
|
|
|
|
def _collect_streamable_leaf_modules(component_name: str, component_module) -> list[tuple[str, object, int]]:
|
|
entries = []
|
|
if component_module is None:
|
|
return entries
|
|
|
|
min_bytes = int(max(0.0, _zimage_deep_patcher_min_module_mb()) * 1024 * 1024)
|
|
for sub_name, module in component_module.named_modules():
|
|
if not _is_streamable_leaf_module(module):
|
|
continue
|
|
storage_bytes = _module_direct_storage_bytes(module)
|
|
if storage_bytes < min_bytes:
|
|
continue
|
|
full_name = component_name if sub_name == "" else f"{component_name}.{sub_name}"
|
|
entries.append((full_name, module, storage_bytes))
|
|
entries.sort(key=lambda x: x[2], reverse=True)
|
|
return entries
|
|
|
|
|
|
def _deep_patcher_stage(pipeline) -> Optional[str]:
|
|
state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if not isinstance(state, dict):
|
|
return None
|
|
return str(state.get("stage", "active"))
|
|
|
|
|
|
def _normalize_deep_patcher_stage(stage: str) -> str:
|
|
value = str(stage or "").strip().lower()
|
|
if "idle" in value:
|
|
return "idle"
|
|
return value or "active"
|
|
|
|
|
|
def _set_deep_patcher_stage(pipeline, stage: str) -> None:
|
|
state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if not isinstance(state, dict):
|
|
return
|
|
state["stage"] = _normalize_deep_patcher_stage(stage)
|
|
|
|
|
|
def _disable_deep_patcher_offload(pipeline, target_device: Optional[str] = None) -> None:
|
|
state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if not isinstance(state, dict):
|
|
pipeline._zimage_deep_patcher_state = None
|
|
pipeline._zimage_deep_patcher_enabled = False
|
|
return
|
|
|
|
hooks = list(state.get("hooks", ()))
|
|
for handle in hooks:
|
|
try:
|
|
handle.remove()
|
|
except Exception:
|
|
pass
|
|
|
|
active_map = state.get("active_gpu_modules", {})
|
|
if isinstance(active_map, dict):
|
|
modules = list(active_map.values())
|
|
else:
|
|
modules = []
|
|
|
|
if target_device is not None:
|
|
for module in modules:
|
|
try:
|
|
_move_streamable_module_to(module, target_device)
|
|
except Exception:
|
|
pass
|
|
for _, module, _ in state.get("module_entries", ()):
|
|
try:
|
|
_move_streamable_module_to(module, target_device)
|
|
except Exception:
|
|
pass
|
|
|
|
pipeline._zimage_deep_patcher_state = None
|
|
pipeline._zimage_deep_patcher_enabled = False
|
|
|
|
|
|
def _enable_deep_patcher_offload(pipeline, device: str, target_mode: str) -> bool:
|
|
if device != "cuda":
|
|
return False
|
|
if target_mode != "sequential_offload":
|
|
return False
|
|
if not _zimage_deep_patcher_enabled():
|
|
return False
|
|
if bool(getattr(pipeline, "_zimage_deep_patcher_blocked", False)):
|
|
return False
|
|
|
|
managed_components = _collect_granular_offload_components(pipeline)
|
|
if not managed_components:
|
|
return False
|
|
|
|
pipeline._zimage_granular_offload_state = None
|
|
pipeline._zimage_granular_offload_enabled = False
|
|
_disable_deep_patcher_offload(pipeline, target_device="cpu")
|
|
_clear_accelerate_offload_hooks(pipeline)
|
|
if _pipeline_has_meta_tensors(pipeline):
|
|
print(
|
|
"[Z-Image POC] Deep patcher unavailable for current pipeline (meta tensors after hook cleanup); "
|
|
"falling back to non-deep offload."
|
|
)
|
|
return False
|
|
|
|
module_entries = []
|
|
for component_name in managed_components:
|
|
component = getattr(pipeline, component_name, None)
|
|
module_entries.extend(_collect_streamable_leaf_modules(component_name, component))
|
|
|
|
if not module_entries:
|
|
print("[Z-Image POC] Deep patcher found no streamable leaf modules; falling back to non-deep offload.")
|
|
return False
|
|
|
|
state = {
|
|
"device": str(device),
|
|
"mode": str(target_mode),
|
|
"managed_components": tuple(managed_components),
|
|
"module_entries": tuple(module_entries),
|
|
"hooks": [],
|
|
"active_gpu_modules": {},
|
|
"stage": "idle",
|
|
"moves_to_gpu": 0,
|
|
"moves_to_cpu": 0,
|
|
"bytes_streamed": 0,
|
|
}
|
|
|
|
for _, module, _ in module_entries:
|
|
module.to("cpu")
|
|
|
|
def _pre_hook(module, _inputs):
|
|
deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if not isinstance(deep_state, dict):
|
|
return
|
|
if deep_state.get("stage", "active") == "idle":
|
|
return
|
|
module_id = id(module)
|
|
if module_id in deep_state["active_gpu_modules"]:
|
|
return
|
|
moved = _move_streamable_module_to(module, deep_state["device"])
|
|
if moved:
|
|
deep_state["moves_to_gpu"] += 1
|
|
deep_state["active_gpu_modules"][module_id] = module
|
|
|
|
def _post_hook(module, _inputs, _output):
|
|
deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if not isinstance(deep_state, dict):
|
|
return _output
|
|
module_id = id(module)
|
|
if module_id not in deep_state["active_gpu_modules"]:
|
|
return _output
|
|
moved = _move_streamable_module_to(module, "cpu")
|
|
if moved:
|
|
deep_state["moves_to_cpu"] += 1
|
|
deep_state["active_gpu_modules"].pop(module_id, None)
|
|
return _output
|
|
|
|
for _, module, _ in module_entries:
|
|
state["hooks"].append(module.register_forward_pre_hook(_pre_hook))
|
|
state["hooks"].append(module.register_forward_hook(_post_hook))
|
|
|
|
pipeline._zimage_deep_patcher_state = state
|
|
pipeline._zimage_deep_patcher_enabled = True
|
|
total_streamable_gb = sum(entry[2] for entry in module_entries) / float(1024**3)
|
|
print(
|
|
f"[Z-Image POC] Enabled deep patcher offload ({target_mode}); "
|
|
f"streamable_modules={len(module_entries)}, streamable_total={total_streamable_gb:.2f}GB."
|
|
)
|
|
return True
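# Lifecycle sketch for the deep patcher: every streamable leaf module starts on CPU; its forward
# pre-hook moves it to the GPU just before use and the forward hook returns it to CPU afterwards,
# so steady-state VRAM is roughly one leaf module's weights plus activations rather than the whole
# component. The "idle" stage keeps the pre-hook from promoting modules, so cached pipelines stay
# parked on CPU between generations.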
|
|
|
|
|
|
def _collect_granular_offload_components(pipeline) -> list[str]:
|
|
names = []
|
|
for name in _ZIMAGE_GRANULAR_COMPONENT_NAMES:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
if not hasattr(module, "to"):
|
|
continue
|
|
names.append(name)
|
|
return names
|
|
|
|
|
|
def _clear_accelerate_offload_hooks(pipeline) -> None:
|
|
try:
|
|
if hasattr(pipeline, "maybe_free_model_hooks"):
|
|
pipeline.maybe_free_model_hooks()
|
|
except Exception:
|
|
pass
|
|
try:
|
|
remove_all_hooks = getattr(pipeline, "remove_all_hooks", None)
|
|
if callable(remove_all_hooks):
|
|
remove_all_hooks()
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _disable_granular_component_offload(pipeline, target_device: Optional[str] = None) -> None:
|
|
_disable_deep_patcher_offload(pipeline, target_device=target_device)
|
|
state = getattr(pipeline, "_zimage_granular_offload_state", None)
|
|
if not isinstance(state, dict):
|
|
pipeline._zimage_granular_offload_state = None
|
|
pipeline._zimage_granular_offload_enabled = False
|
|
return
|
|
|
|
managed = tuple(state.get("managed_components", ()))
|
|
if target_device is not None:
|
|
for name in managed:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None or not hasattr(module, "to"):
|
|
continue
|
|
try:
|
|
module.to(target_device)
|
|
except Exception:
|
|
pass
|
|
|
|
pipeline._zimage_granular_offload_state = None
|
|
pipeline._zimage_granular_offload_enabled = False
|
|
|
|
|
|
def _enable_granular_component_offload(pipeline, device: str, target_mode: str) -> bool:
|
|
if device != "cuda":
|
|
return False
|
|
if target_mode not in ("model_offload", "sequential_offload"):
|
|
return False
|
|
if not _zimage_granular_offload_enabled():
|
|
return False
|
|
|
|
managed = _collect_granular_offload_components(pipeline)
|
|
if not managed:
|
|
return False
|
|
|
|
_disable_deep_patcher_offload(pipeline, target_device="cpu")
|
|
_clear_accelerate_offload_hooks(pipeline)
|
|
if _pipeline_has_meta_tensors(pipeline):
|
|
print(
|
|
"[Z-Image POC] Granular offload unavailable for current pipeline (meta tensors after hook cleanup); "
|
|
"falling back to diffusers offload hooks."
|
|
)
|
|
return False
|
|
|
|
persistent = []
|
|
if target_mode == "model_offload" and "transformer" in managed:
|
|
# Keep denoiser warm on GPU in model_offload mode while parking large auxiliaries on CPU.
|
|
persistent.append("transformer")
|
|
|
|
try:
|
|
for name in managed:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
module.to("cpu")
|
|
|
|
for name in persistent:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
module.to(device)
|
|
except Exception:
|
|
_disable_granular_component_offload(pipeline, target_device="cpu")
|
|
raise
|
|
|
|
state = {
|
|
"device": device,
|
|
"mode": target_mode,
|
|
"managed_components": tuple(managed),
|
|
"persistent_components": tuple(persistent),
|
|
"current_gpu_components": tuple(persistent),
|
|
"last_stage": "init",
|
|
}
|
|
pipeline._zimage_granular_offload_state = state
|
|
pipeline._zimage_granular_offload_enabled = True
|
|
|
|
summaries = []
|
|
for name in managed:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
summaries.append(f"{name}={_module_memory_gb(module):.2f}GB")
|
|
persistent_text = ", ".join(persistent) if persistent else "none"
|
|
print(
|
|
f"[Z-Image POC] Enabled granular component offload ({target_mode}); "
|
|
f"persistent_gpu={persistent_text}; components: {', '.join(summaries)}."
|
|
)
|
|
return True
|
|
|
|
|
|
def _activate_granular_component_set(pipeline, required_components: tuple[str, ...], stage: str) -> None:
|
|
state = getattr(pipeline, "_zimage_granular_offload_state", None)
|
|
if not isinstance(state, dict):
|
|
return
|
|
|
|
managed = tuple(state.get("managed_components", ()))
|
|
if not managed:
|
|
return
|
|
|
|
mode = str(state.get("mode", "model_offload"))
|
|
current = set(state.get("current_gpu_components", ()))
|
|
target = set(required_components)
|
|
if mode == "model_offload":
|
|
target.update(state.get("persistent_components", ()))
|
|
target.intersection_update(managed)
|
|
|
|
to_gpu = [name for name in managed if name in target and name not in current]
|
|
to_cpu = [name for name in managed if name not in target and name in current]
|
|
|
|
for name in to_gpu:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
module.to(state.get("device", "cuda"))
|
|
for name in to_cpu:
|
|
module = getattr(pipeline, name, None)
|
|
if module is None:
|
|
continue
|
|
module.to("cpu")
|
|
|
|
if to_cpu:
|
|
_cleanup_memory(cuda=True, aggressive=False)
|
|
|
|
state["current_gpu_components"] = tuple(name for name in managed if name in target)
|
|
state["last_stage"] = stage
|
|
if to_gpu or to_cpu:
|
|
print(
|
|
f"[Z-Image POC] Granular offload stage={stage}: +{to_gpu or ['none']} / -{to_cpu or ['none']}."
|
|
)
|
|
|
|
|
|
def _prepare_granular_prompt_encode(pipeline, generator_device: str) -> None:
|
|
if generator_device != "cuda":
|
|
return
|
|
if _deep_patcher_stage(pipeline) is not None:
|
|
_set_deep_patcher_stage(pipeline, "prompt_encode")
|
|
return
|
|
_activate_granular_component_set(
|
|
pipeline,
|
|
required_components=("text_encoder", "text_encoder_2"),
|
|
stage="prompt_encode",
|
|
)
|
|
|
|
|
|
def _prepare_granular_pipeline_call(pipeline, generator_device: str, stage: str = "pipeline_call") -> None:
|
|
if generator_device != "cuda":
|
|
return
|
|
if _deep_patcher_stage(pipeline) is not None:
|
|
_set_deep_patcher_stage(pipeline, stage)
|
|
return
|
|
_activate_granular_component_set(
|
|
pipeline,
|
|
required_components=("transformer", "vae"),
|
|
stage=stage,
|
|
)
|
|
|
|
|
|
def _park_granular_components(pipeline, generator_device: str, stage: str = "idle") -> None:
|
|
if generator_device != "cuda":
|
|
return
|
|
deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
|
if isinstance(deep_state, dict):
|
|
deep_state["stage"] = _normalize_deep_patcher_stage(stage)
|
|
active_map = deep_state.get("active_gpu_modules", {})
|
|
if isinstance(active_map, dict) and active_map:
|
|
for module in list(active_map.values()):
|
|
try:
|
|
_move_streamable_module_to(module, "cpu")
|
|
deep_state["moves_to_cpu"] = int(deep_state.get("moves_to_cpu", 0)) + 1
|
|
except Exception:
|
|
pass
|
|
active_map.clear()
|
|
return
|
|
_activate_granular_component_set(pipeline, required_components=tuple(), stage=stage)
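# Stage sketch for granular offload: prompt encoding activates {text_encoder, text_encoder_2},
# the denoise/decode call activates {transformer, vae}, and parking moves every managed component
# back to CPU; in model_offload mode the transformer additionally stays resident between stages.
# When the deep patcher owns the pipeline, these helpers only update its stage label instead.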
|
|
|
|
|
|
def _estimate_generation_vram_need_gb(
|
|
width: int,
|
|
height: int,
|
|
max_sequence_length: int,
|
|
use_cfg: bool,
|
|
flavor: str,
|
|
) -> float:
|
|
megapixels = max(0.25, (max(64, int(width)) * max(64, int(height))) / 1_000_000.0)
|
|
base = 4.4 if flavor == "turbo" else 5.8
|
|
pixel_cost = megapixels * 2.0
|
|
seq_cost = max(0.0, float(max_sequence_length) / 256.0) * 1.2
|
|
cfg_cost = 0.4 if use_cfg else 0.0
|
|
estimated = base + pixel_cost + seq_cost + cfg_cost
|
|
return estimated * _zimage_vram_estimate_scale()
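# Worked example of the estimate above (turbo flavor, 1024x1024, max_sequence_length=512, no CFG,
# estimate scale 1.0):
#     4.4 (base) + 1.048576 * 2.0 (pixels) + (512 / 256) * 1.2 (sequence) + 0.0 (cfg) ~= 8.9 GB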
|
|
|
|
|
|
def _apply_memory_mode(
|
|
pipeline,
|
|
device: str,
|
|
target_mode: str,
|
|
total_vram_gb: float,
|
|
free_vram_gb: float,
|
|
pressure: float,
|
|
profile: str,
|
|
reason: str = "",
|
|
allow_relax: bool = False,
|
|
) -> tuple[str, bool]:
|
|
used_offload = False
|
|
current_mode = getattr(pipeline, "_zimage_memory_mode", "unset")
|
|
if (not allow_relax) and current_mode in ("sequential_offload", "model_offload", "full_gpu"):
|
|
target_mode = _stricter_memory_mode(current_mode, target_mode)
|
|
|
|
reason_suffix = f", reason={reason}" if reason else ""
|
|
if target_mode in ("model_offload", "sequential_offload") and device == "cuda":
|
|
try:
|
|
if _enable_deep_patcher_offload(pipeline, device=device, target_mode=target_mode):
|
|
pipeline._zimage_memory_mode = target_mode
|
|
used_offload = True
|
|
print(
|
|
f"[Z-Image POC] Using deep patcher {target_mode} "
|
|
f"(total={total_vram_gb:.2f}GB, free={free_vram_gb:.2f}GB, pressure={pressure:.2f}, profile={profile}{reason_suffix})."
|
|
)
|
|
return ("cuda" if device == "cuda" else "cpu"), used_offload
|
|
if _enable_granular_component_offload(pipeline, device=device, target_mode=target_mode):
|
|
pipeline._zimage_memory_mode = target_mode
|
|
used_offload = True
|
|
print(
|
|
f"[Z-Image POC] Using granular {target_mode} "
|
|
f"(total={total_vram_gb:.2f}GB, free={free_vram_gb:.2f}GB, pressure={pressure:.2f}, profile={profile}{reason_suffix})."
|
|
)
|
|
return ("cuda" if device == "cuda" else "cpu"), used_offload
|
|
except Exception as e:
|
|
print(f"[Z-Image POC] Granular offload setup failed, falling back to diffusers offload hooks: {e}")
|
|
|
|
if target_mode == "sequential_offload" and hasattr(pipeline, "enable_sequential_cpu_offload"):
|
|
_disable_granular_component_offload(pipeline, target_device="cpu")
|
|
pipeline.enable_sequential_cpu_offload()
|
|
pipeline._zimage_memory_mode = "sequential_offload"
|
|
used_offload = True
|
|
print(
|
|
f"[Z-Image POC] Using sequential CPU offload "
|
|
f"(total={total_vram_gb:.2f}GB, free={free_vram_gb:.2f}GB, pressure={pressure:.2f}, profile={profile}{reason_suffix})."
|
|
)
|
|
elif target_mode in ("model_offload", "sequential_offload") and hasattr(pipeline, "enable_model_cpu_offload"):
|
|
_disable_granular_component_offload(pipeline, target_device="cpu")
|
|
pipeline.enable_model_cpu_offload()
|
|
pipeline._zimage_memory_mode = "model_offload"
|
|
used_offload = True
|
|
print(
|
|
f"[Z-Image POC] Using model CPU offload "
|
|
f"(total={total_vram_gb:.2f}GB, free={free_vram_gb:.2f}GB, pressure={pressure:.2f}, profile={profile}{reason_suffix})."
|
|
)
|
|
else:
|
|
if current_mode in ("sequential_offload", "model_offload") and not (allow_relax and target_mode == "full_gpu"):
|
|
# Keep existing offload hooks instead of trying to force full-GPU.
|
|
used_offload = True
|
|
else:
|
|
_disable_granular_component_offload(pipeline, target_device=device)
|
|
_clear_accelerate_offload_hooks(pipeline)
|
|
pipeline.to(device)
|
|
pipeline._zimage_memory_mode = "full_gpu"
|
|
print(
|
|
f"[Z-Image POC] Using full-GPU mode "
|
|
f"(total={total_vram_gb:.2f}GB, free={free_vram_gb:.2f}GB, pressure={pressure:.2f}, profile={profile}{reason_suffix})."
|
|
)
|
|
|
|
return ("cuda" if device == "cuda" else "cpu"), used_offload
|
|
|
|
|
|
def _preflight_generation_memory_mode(
|
|
pipeline,
|
|
cache_key: str,
|
|
device: str,
|
|
generator_device: str,
|
|
used_offload: bool,
|
|
profile: str,
|
|
width: int,
|
|
height: int,
|
|
max_sequence_length: int,
|
|
use_cfg: bool,
|
|
flavor: str,
|
|
) -> tuple[str, bool]:
|
|
if device != "cuda" or generator_device != "cuda":
|
|
return generator_device, used_offload
|
|
|
|
base_mode, total_vram_gb, free_vram_gb, pressure = _choose_memory_mode(device, profile=profile)
|
|
target_mode = base_mode
|
|
forced_mode = _zimage_forced_memory_mode()
|
|
estimated_need_gb = _estimate_generation_vram_need_gb(
|
|
width=width,
|
|
height=height,
|
|
max_sequence_length=max_sequence_length,
|
|
use_cfg=use_cfg,
|
|
flavor=flavor,
|
|
)
|
|
reserve_vram_gb = _zimage_reserved_vram_gb(total_vram_gb=total_vram_gb)
|
|
estimate_scale = _zimage_vram_estimate_scale()
|
|
headroom_gb = {"safe": 1.75, "balanced": 1.35, "speed": 0.95}.get(profile, 1.35)
|
|
usable_free_gb = max(0.0, free_vram_gb - reserve_vram_gb)
|
|
gap_gb = usable_free_gb - estimated_need_gb
|
|
host_available_gb, host_total_gb = _system_ram_info_gb()
|
|
host_reserve_gb = _zimage_system_ram_reserve_gb()
|
|
host_usable_gb = max(0.0, host_available_gb - host_reserve_gb)
|
|
|
|
if forced_mode is None:
|
|
if gap_gb < max(0.35, headroom_gb * 0.50):
|
|
target_mode = "sequential_offload"
|
|
elif gap_gb < headroom_gb:
|
|
target_mode = _stricter_memory_mode(target_mode, "model_offload")
|
|
if flavor == "turbo" and target_mode == "model_offload":
|
|
min_gap_for_model_offload = _zimage_model_offload_min_gap_gb()
|
|
if gap_gb < min_gap_for_model_offload:
|
|
target_mode = "sequential_offload"
|
|
# Joint RAM+VRAM controller:
|
|
# If host RAM is tight, avoid the strictest offload mode when VRAM gap allows.
|
|
if host_available_gb > 0.0:
|
|
if target_mode == "sequential_offload" and host_usable_gb < 0.50 and gap_gb >= 0.70:
|
|
target_mode = "model_offload"
|
|
if target_mode == "model_offload" and host_usable_gb < 0.25 and gap_gb >= max(headroom_gb, 1.00):
|
|
target_mode = "full_gpu"
|
|
else:
|
|
target_mode = forced_mode
|
|
|
|
current_mode = getattr(pipeline, "_zimage_memory_mode", "unset")
|
|
should_reapply_mode = _memory_mode_rank(target_mode) > _memory_mode_rank(current_mode)
|
|
allow_relax = False
|
|
|
|
# Allow speed profile to recover from a prior OOM-induced sequential offload once
|
|
# we observe a stable run and enough preflight headroom.
|
|
if (
|
|
forced_mode is None
|
|
and profile == "speed"
|
|
and current_mode == "sequential_offload"
|
|
and target_mode == "model_offload"
|
|
and not bool(getattr(pipeline, "_zimage_last_run_had_oom", False))
|
|
and gap_gb >= max(headroom_gb, 0.9)
|
|
):
|
|
should_reapply_mode = True
|
|
allow_relax = True
|
|
|
|
if forced_mode is not None and target_mode != current_mode:
|
|
should_reapply_mode = True
|
|
allow_relax = True
|
|
|
|
if should_reapply_mode:
|
|
reason = f"preflight est={estimated_need_gb:.2f}GB gap={gap_gb:.2f}GB"
|
|
if forced_mode is not None:
|
|
reason = f"forced by env FOOOCUS_ZIMAGE_FORCE_MEMORY_MODE={forced_mode}"
|
|
elif allow_relax:
|
|
reason = f"preflight relax est={estimated_need_gb:.2f}GB gap={gap_gb:.2f}GB"
|
|
try:
|
|
generator_device, used_offload = _apply_memory_mode(
|
|
pipeline=pipeline,
|
|
device=device,
|
|
target_mode=target_mode,
|
|
total_vram_gb=total_vram_gb,
|
|
free_vram_gb=free_vram_gb,
|
|
pressure=pressure,
|
|
profile=profile,
|
|
reason=reason,
|
|
allow_relax=allow_relax,
|
|
)
|
|
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
|
|
except Exception as e:
|
|
print(
|
|
f"[Z-Image POC] Warning: failed to switch memory mode to '{target_mode}' during preflight: {e}. "
|
|
"Continuing with current mode."
|
|
)
|
|
if current_mode in ("sequential_offload", "model_offload", "full_gpu"):
|
|
pipeline._zimage_memory_mode = current_mode
|
|
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
|
|
|
|
forced_suffix = f", forced={forced_mode}" if forced_mode is not None else ""
|
|
print(
|
|
f"[Z-Image POC] Preflight VRAM budget: est={estimated_need_gb:.2f}GB, "
|
|
f"est_scale={estimate_scale:.2f}, "
|
|
f"free={free_vram_gb:.2f}GB, reserve={reserve_vram_gb:.2f}GB, usable={usable_free_gb:.2f}GB, "
|
|
f"gap={gap_gb:.2f}GB, base={base_mode}, "
|
|
f"active={getattr(pipeline, '_zimage_memory_mode', 'unset')}{forced_suffix}."
|
|
)
|
|
if host_available_gb > 0.0:
|
|
print(
|
|
f"[Z-Image POC] Host RAM budget: avail={host_available_gb:.2f}GB, total={host_total_gb:.2f}GB, "
|
|
f"reserve={host_reserve_gb:.2f}GB, usable={host_usable_gb:.2f}GB."
|
|
)
|
|
return generator_device, used_offload
|
|
|
|
|
|
def _is_zimage_pipeline(pipeline) -> bool:
|
|
try:
|
|
class_name = str(getattr(pipeline.__class__, "__name__", "")).lower()
|
|
if "zimage" in class_name:
|
|
return True
|
|
except Exception:
|
|
pass
|
|
try:
|
|
cfg = getattr(pipeline, "config", None)
|
|
if isinstance(cfg, dict):
|
|
cfg_name = str(cfg.get("_class_name", "")).lower()
|
|
else:
|
|
cfg_name = str(getattr(cfg, "_class_name", "")).lower()
|
|
if "zimage" in cfg_name:
|
|
return True
|
|
except Exception:
|
|
pass
|
|
return False
|
|
|
|
|
|
def _disable_xformers_for_pipeline(pipeline, reason: str = "") -> bool:
|
|
changed = False
|
|
if _is_zimage_pipeline(pipeline):
|
|
transformer = getattr(pipeline, "transformer", None)
|
|
if transformer is not None:
|
|
try:
|
|
if hasattr(transformer, "reset_attention_backend"):
|
|
transformer.reset_attention_backend()
|
|
elif hasattr(transformer, "set_attention_backend"):
|
|
transformer.set_attention_backend("native")
|
|
changed = True
|
|
except Exception:
|
|
pass
|
|
if hasattr(pipeline, "disable_xformers_memory_efficient_attention"):
|
|
try:
|
|
pipeline.disable_xformers_memory_efficient_attention()
|
|
changed = True
|
|
except Exception:
|
|
pass
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
if changed:
|
|
suffix = f" ({reason})" if reason else ""
|
|
print(f"[Z-Image POC] Disabled xFormers attention{suffix}.")
|
|
return changed
|
|
|
|
|
|
def _maybe_enable_xformers(pipeline, profile: str) -> None:
|
|
mode = _zimage_xformers_mode()
|
|
backend_mode = _zimage_attention_backend_mode()
|
|
explicit_backend = backend_mode != "auto"
|
|
|
|
if backend_mode == "native":
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = "native"
|
|
return
|
|
|
|
if mode == "off" and backend_mode in ("auto", "xformers"):
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
return
|
|
if not explicit_backend and mode == "auto" and profile not in ("balanced", "speed"):
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
return
|
|
if getattr(pipeline, "_zimage_xformers_attempted", False):
|
|
return
|
|
pipeline._zimage_xformers_attempted = True
|
|
|
|
if _is_zimage_pipeline(pipeline):
|
|
transformer = getattr(pipeline, "transformer", None)
|
|
if transformer is not None and hasattr(transformer, "set_attention_backend"):
|
|
base_candidates = _zimage_attention_backend_candidates(
|
|
backend_mode,
|
|
allow_xformers=(mode != "off"),
|
|
)
|
|
discovered = _discover_transformer_attention_backends(transformer)
|
|
candidates = _remap_attention_backend_candidates(base_candidates, discovered)
|
|
if discovered:
|
|
print(f"[Z-Image POC] Attention backend capabilities detected: {discovered}")
|
|
if candidates != base_candidates:
|
|
print(f"[Z-Image POC] Attention backend alias remap: {base_candidates} -> {candidates}")
|
|
print(
|
|
f"[Z-Image POC] Attention backend probe start: mode={mode}, backend={backend_mode}, "
|
|
f"candidates={candidates}"
|
|
)
|
|
if any(_is_flash_attention_backend(c) for c in candidates):
|
|
print("[Z-Image POC] Flash attention initiation: probing flash-compatible backends.")
|
|
last_error = None
|
|
remapped_from_error = False
|
|
i = 0
|
|
while i < len(candidates):
|
|
candidate = candidates[i]
|
|
if _is_flash_attention_backend(candidate):
|
|
print(f"[Z-Image POC] Trying flash attention backend '{candidate}'...")
|
|
try:
|
|
transformer.set_attention_backend(candidate)
|
|
pipeline._zimage_xformers_enabled = candidate != "native"
|
|
pipeline._zimage_xformers_strategy = f"dispatch_backend:{candidate}"
|
|
if candidate == "native":
|
|
print(
|
|
f"[Z-Image POC] Using native attention backend for Z-Image "
|
|
f"(mode={mode}, backend={backend_mode})."
|
|
)
|
|
else:
|
|
print(
|
|
f"[Z-Image POC] Enabled attention backend '{candidate}' for Z-Image "
|
|
f"(mode={mode}, backend={backend_mode})."
|
|
)
|
|
return
|
|
except Exception as e:
|
|
last_error = e
|
|
if _is_flash_attention_backend(candidate):
|
|
print(f"[Z-Image POC] Flash attention backend '{candidate}' unavailable: {e}")
|
|
discovered_from_error = _parse_backend_names_from_error(str(e))
|
|
if discovered_from_error and not remapped_from_error:
|
|
remapped_from_error = True
|
|
remapped = _remap_attention_backend_candidates(base_candidates, discovered_from_error)
|
|
if remapped != candidates:
|
|
print(
|
|
f"[Z-Image POC] Attention backend alias remap from runtime error: "
|
|
f"{base_candidates} -> {remapped}"
|
|
)
|
|
candidates = remapped
|
|
i = 0
|
|
continue
|
|
i += 1
|
|
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
if mode == "on" or explicit_backend:
|
|
print(
|
|
f"[Z-Image POC] Failed to enable requested attention backend "
|
|
f"(backend={backend_mode}, mode={mode}): {last_error}"
|
|
)
|
|
return
|
|
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
if mode == "on" or explicit_backend:
|
|
print(
|
|
"[Z-Image POC] Accelerated attention requested but this Z-Image backend lacks "
|
|
"transformer.set_attention_backend(); using native attention."
|
|
)
|
|
return
|
|
|
|
if not hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
|
|
if mode == "on":
|
|
print("[Z-Image POC] xFormers requested but pipeline does not expose xFormers attention API.")
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
return
|
|
|
|
if explicit_backend and backend_mode not in ("xformers",):
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
print(
|
|
f"[Z-Image POC] Attention backend '{backend_mode}' is only supported on Z-Image dispatcher backends; "
|
|
"using native attention."
|
|
)
|
|
return
|
|
|
|
try:
|
|
pipeline.enable_xformers_memory_efficient_attention()
|
|
pipeline._zimage_xformers_enabled = True
|
|
pipeline._zimage_xformers_strategy = "processor_swap"
|
|
print(f"[Z-Image POC] Enabled xFormers memory efficient attention (mode={mode}).")
|
|
except Exception as e:
|
|
pipeline._zimage_xformers_enabled = False
|
|
pipeline._zimage_xformers_strategy = None
|
|
if mode == "on":
|
|
print(f"[Z-Image POC] Failed to enable xFormers attention: {e}")
|
|
|
|
|
|
def _should_cleanup_cuda_cache(profile: str, had_oom: bool, pipeline) -> bool:
|
|
if had_oom:
|
|
return True
|
|
if profile == "safe":
|
|
return True
|
|
|
|
mode = getattr(pipeline, "_zimage_memory_mode", "unset")
|
|
free_gb, _ = _cuda_mem_info_gb()
|
|
low_free = free_gb > 0 and free_gb < 0.9
|
|
|
|
if profile == "balanced":
|
|
# Keep throughput high for offload modes, but still clean up under pressure.
|
|
if mode == "full_gpu" or low_free:
|
|
return True
|
|
return False
|
|
|
|
# speed profile: only clean if VRAM is very tight.
|
|
return free_gb > 0 and free_gb < 0.5
|
|
|
|
|
|
def _maybe_preemptive_cuda_cleanup_before_generation(pipeline, profile: str) -> None:
|
|
if not _zimage_preemptive_cuda_cleanup_enabled():
|
|
return
|
|
|
|
mode = str(getattr(pipeline, "_zimage_memory_mode", "unset"))
|
|
if mode not in ("model_offload", "sequential_offload", "full_gpu"):
|
|
return
|
|
|
|
# Keep speed profile light unless explicitly requested.
|
|
aggressive = _zimage_preemptive_cuda_cleanup_aggressive()
|
|
if profile != "safe" and not aggressive:
|
|
aggressive = False
|
|
|
|
free_before, total_gb = _cuda_mem_info_gb()
|
|
try:
|
|
if hasattr(pipeline, "maybe_free_model_hooks"):
|
|
pipeline.maybe_free_model_hooks()
|
|
except Exception:
|
|
pass
|
|
_cleanup_memory(cuda=True, aggressive=aggressive)
|
|
free_after, _ = _cuda_mem_info_gb()
|
|
|
|
if free_before > 0 and free_after > 0:
|
|
print(
|
|
f"[Z-Image POC] Pre-run CUDA cleanup: free={free_before:.2f}GB->{free_after:.2f}GB "
|
|
f"(total={total_gb:.2f}GB, mode={mode}, aggressive={aggressive})."
|
|
)
|
|
|
|
|
|
def _prepare_pipeline_memory_mode(pipeline, device: str) -> tuple[str, bool]:
|
|
"""
|
|
Returns (generator_device, used_offload): used_offload is True when a CPU-offload mode is active.
|
|
generator_device is used to seed torch.Generator.
|
|
"""
|
|
import torch
|
|
|
|
used_offload = False
|
|
profile = _zimage_perf_profile()
|
|
forced_mode = _zimage_forced_memory_mode()
|
|
pipeline._zimage_perf_profile = profile
|
|
|
|
if device == "cuda":
|
|
try:
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
except Exception:
|
|
pass
|
|
_maybe_enable_xformers(pipeline, profile)
|
|
|
|
if profile == "safe":
|
|
if hasattr(pipeline, "enable_attention_slicing"):
|
|
pipeline.enable_attention_slicing("max")
|
|
if hasattr(pipeline, "enable_vae_slicing"):
|
|
pipeline.enable_vae_slicing()
|
|
if hasattr(pipeline, "enable_vae_tiling"):
|
|
pipeline.enable_vae_tiling()
|
|
elif profile == "balanced":
|
|
if hasattr(pipeline, "enable_attention_slicing"):
|
|
try:
|
|
pipeline.enable_attention_slicing("auto")
|
|
except Exception:
|
|
pipeline.enable_attention_slicing("max")
|
|
if hasattr(pipeline, "enable_vae_slicing"):
|
|
pipeline.enable_vae_slicing()
|
|
else:
|
|
# speed profile: avoid forcing slicing/tiling if possible
|
|
if hasattr(pipeline, "disable_attention_slicing"):
|
|
try:
|
|
pipeline.disable_attention_slicing()
|
|
except Exception:
|
|
pass
|
|
|
|
target_mode, total_vram_gb, free_vram_gb, pressure = _choose_memory_mode(device, profile=profile)
|
|
reason = ""
|
|
if forced_mode is not None:
|
|
reason = f"forced by env FOOOCUS_ZIMAGE_FORCE_MEMORY_MODE={forced_mode}"
|
|
_, used_offload = _apply_memory_mode(
|
|
pipeline=pipeline,
|
|
device=device,
|
|
target_mode=target_mode,
|
|
total_vram_gb=total_vram_gb,
|
|
free_vram_gb=free_vram_gb,
|
|
pressure=pressure,
|
|
profile=profile,
|
|
reason=reason,
|
|
allow_relax=(forced_mode is not None),
|
|
)
|
|
else:
|
|
_disable_granular_component_offload(pipeline, target_device=device)
|
|
pipeline.to(device)
|
|
|
|
return ("cuda" if device == "cuda" else "cpu"), used_offload
|
|
|
|
|
|
def _ensure_zimage_runtime_compatibility() -> None:
|
|
missing = []
|
|
try:
|
|
import diffusers
|
|
except Exception as e:
|
|
raise RuntimeError(
|
|
f"Z-Image runtime missing diffusers ({e}). "
|
|
"Install/upgrade with: python -m pip install -U diffusers==0.36.0 transformers==4.56.2 safetensors accelerate"
|
|
) from e
|
|
|
|
try:
|
|
import transformers
|
|
except Exception as e:
|
|
raise RuntimeError(
|
|
f"Z-Image runtime missing transformers ({e}). "
|
|
"Install/upgrade with: python -m pip install -U transformers==4.56.2"
|
|
) from e
|
|
|
|
if not hasattr(diffusers, "ZImagePipeline"):
|
|
missing.append("diffusers.ZImagePipeline")
|
|
if not hasattr(diffusers, "ZImageTransformer2DModel"):
|
|
missing.append("diffusers.ZImageTransformer2DModel")
|
|
if not hasattr(transformers, "Qwen3Model"):
|
|
missing.append("transformers.Qwen3Model")
|
|
|
|
if missing:
|
|
dv = getattr(diffusers, "__version__", "unknown")
|
|
tv = getattr(transformers, "__version__", "unknown")
|
|
raise RuntimeError(
|
|
"Z-Image backend is too old for this model. Missing: "
|
|
+ ", ".join(missing)
|
|
+ f". Current versions: diffusers={dv}, transformers={tv}. "
|
|
"Install/upgrade with: python -m pip install -U diffusers==0.36.0 transformers==4.56.2 safetensors accelerate"
|
|
)
|
|
|
|
|
|
def _load_component_override(pipeline, component_name: str, component_path: str, dtype) -> None:
|
|
component_path = os.path.abspath(component_path)
|
|
component = getattr(pipeline, component_name, None)
|
|
if component is None:
|
|
print(f"[Z-Image POC] Pipeline has no '{component_name}' component; ignoring override.")
|
|
return
|
|
if os.path.isfile(component_path):
|
|
_load_component_override_from_file(pipeline, component_name, component_path)
|
|
return
|
|
|
|
cls = component.__class__
|
|
kwargs = {
|
|
"pretrained_model_name_or_path": component_path,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
}
|
|
if component_name == "text_encoder":
|
|
kwargs["trust_remote_code"] = True
|
|
|
|
model = _call_with_dtype_compat(
|
|
cls.from_pretrained,
|
|
dtype,
|
|
kwargs,
|
|
f"{component_name}.override.from_pretrained",
|
|
)
|
|
if hasattr(model, "to"):
|
|
model = model.to(dtype=dtype)
|
|
|
|
if hasattr(pipeline, "register_modules"):
|
|
pipeline.register_modules(**{component_name: model})
|
|
else:
|
|
setattr(pipeline, component_name, model)
|
|
print(f"[Z-Image POC] Using override {component_name}: {component_path}")
|
|
|
|
|
|
def _load_component_override_from_file(
|
|
pipeline,
|
|
component_name: str,
|
|
component_file: str,
|
|
state_dict_override: Optional[dict] = None,
|
|
source_label: Optional[str] = None,
|
|
) -> None:
|
|
import torch
|
|
|
|
component = getattr(pipeline, component_name, None)
|
|
if component is None:
|
|
print(f"[Z-Image POC] Pipeline has no '{component_name}' component; ignoring file override.")
|
|
return
|
|
|
|
file_path = os.path.abspath(component_file)
|
|
if state_dict_override is not None:
|
|
state_dict = dict(state_dict_override)
|
|
else:
|
|
state_dict = None
|
|
if file_path.lower().endswith(".safetensors"):
|
|
from safetensors.torch import load_file as safetensors_load_file
|
|
state_dict = safetensors_load_file(file_path, device="cpu")
|
|
else:
|
|
raw = torch.load(file_path, map_location="cpu")
|
|
if isinstance(raw, dict) and isinstance(raw.get("state_dict"), dict):
|
|
state_dict = raw["state_dict"]
|
|
elif isinstance(raw, dict):
|
|
state_dict = raw
|
|
else:
|
|
raise RuntimeError(f"Unsupported override weights format: {file_path}")
|
|
|
|
quant_side_suffixes = (
|
|
".comfy_quant",
|
|
".weight_scale",
|
|
".weight_scale_2",
|
|
".input_scale",
|
|
".scale_input",
|
|
".scale_weight",
|
|
)
|
|
|
|
def _decode_fp4_e2m1_packed_u8(packed: torch.Tensor) -> torch.Tensor:
|
|
# Comfy packs two fp4 values per byte as:
|
|
# packed = (fp4_even << 4) | fp4_odd
|
|
if packed.dtype != torch.uint8:
|
|
packed = packed.to(torch.uint8)
|
|
hi = (packed >> 4) & 0x0F
|
|
lo = packed & 0x0F
|
|
unpacked = torch.empty((packed.shape[0], packed.shape[1] * 2), dtype=torch.uint8, device=packed.device)
|
|
unpacked[:, 0::2] = hi
|
|
unpacked[:, 1::2] = lo
|
|
|
|
# fp4 e2m1 decode table mirrored from Comfy float quantizer behavior.
|
|
table = torch.tensor(
|
|
[
|
|
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
|
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
|
],
|
|
dtype=torch.float32,
|
|
device=packed.device,
|
|
)
|
|
return table[unpacked.long()]
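# Worked example (illustrative): a single packed byte 0x2B yields hi=0x2 and
# lo=0xB, which the table above maps to 1.0 and -1.5, so a (1, 1) uint8 input
# expands to the (1, 2) float32 tensor [[1.0, -1.5]].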
|
|
|
|
def _from_blocked_scales(blocked: torch.Tensor) -> torch.Tensor:
|
|
# Inverse of Comfy's to_blocked(...) layout used for NVFP4 block scales.
|
|
if blocked.ndim != 2:
|
|
return blocked
|
|
rows, cols = blocked.shape
|
|
if rows % 128 != 0 or cols % 4 != 0:
|
|
return blocked
|
|
n_row_blocks = rows // 128
|
|
n_col_blocks = cols // 4
|
|
e = blocked.reshape(n_row_blocks * n_col_blocks, 32, 16)
|
|
d = e.reshape(n_row_blocks * n_col_blocks, 32, 4, 4)
|
|
c = d.transpose(1, 2)
|
|
b = c.reshape(n_row_blocks, n_col_blocks, 128, 4)
|
|
a = b.permute(0, 2, 1, 3)
|
|
return a.reshape(rows, cols)
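# Shape contract (as implemented above): only tensors with rows % 128 == 0 and
# cols % 4 == 0 are un-swizzled; each 128x4 output block is gathered from one
# 32x16 tile of the blocked layout, and anything else is returned unchanged.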
|
|
|
|
def _decode_comfy_quant_entry(raw: torch.Tensor) -> Optional[dict]:
|
|
try:
|
|
return json.loads(raw.detach().cpu().numpy().tobytes().decode("utf-8"))
|
|
except Exception:
|
|
return None
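# Note: the comfy_quant entry is a uint8 tensor of UTF-8 JSON bytes (e.g. the
# payloads synthesized later in this file encode {"format": "float8_e4m3fn"});
# any decode failure is treated as "no quant metadata" for that layer.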
|
|
|
|
class _ComfyRuntimeQuantLinear(torch.nn.Module):
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = True, compute_dtype=torch.bfloat16):
|
|
super().__init__()
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.compute_dtype = compute_dtype
|
|
self.quant_format: Optional[str] = None
|
|
self.full_precision_mm = False
|
|
|
|
if bias:
|
|
self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=compute_dtype), requires_grad=False)
|
|
else:
|
|
self.bias = None
|
|
|
|
# Keep optional tensors as plain attrs to avoid None-valued module
|
|
# params/buffers confusing offload hooks.
|
|
self._dense_weight = None
|
|
self._quant_weight = None
|
|
self._weight_scale = None
|
|
self._weight_scale_2 = None
|
|
self._input_scale = None
|
|
|
|
self._cached_weight = None
|
|
self._cached_weight_device: Optional[str] = None
|
|
self._cached_weight_dtype: Optional[str] = None
|
|
|
|
@property
|
|
def weight(self) -> torch.Tensor:
|
|
# Compatibility shim: some external/offload paths still probe
|
|
# module.weight.{device,dtype} even for custom Linear-like modules.
|
|
if self._dense_weight is not None:
|
|
return self._dense_weight
|
|
if self._quant_weight is not None:
|
|
return self._quant_weight
|
|
# Keep attribute contract even during transient init states.
|
|
return torch.empty(0, dtype=self.compute_dtype)
|
|
|
|
@classmethod
|
|
def from_linear(cls, linear_module):
|
|
weight = getattr(linear_module, "weight", None)
|
|
if weight is None or not isinstance(weight, torch.Tensor) or weight.ndim != 2:
|
|
raise RuntimeError("Module is not linear-like (missing 2D weight tensor).")
|
|
|
|
in_features = int(getattr(linear_module, "in_features", weight.shape[1]))
|
|
out_features = int(getattr(linear_module, "out_features", weight.shape[0]))
|
|
bias = getattr(linear_module, "bias", None)
|
|
compute_dtype = getattr(weight, "dtype", torch.bfloat16)
|
|
|
|
layer = cls(
|
|
in_features=in_features,
|
|
out_features=out_features,
|
|
bias=bias is not None,
|
|
compute_dtype=compute_dtype,
|
|
)
|
|
if bias is not None and layer.bias is not None:
|
|
layer.bias.data.copy_(bias.detach())
|
|
layer._dense_weight = weight.detach()
|
|
return layer
|
|
|
|
def _clear_cache(self):
|
|
self._cached_weight = None
|
|
self._cached_weight_device = None
|
|
self._cached_weight_dtype = None
|
|
|
|
def _cache_enabled(self) -> bool:
|
|
return _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_CACHE", "0")
|
|
|
|
def _set_dense_weight(self, weight: torch.Tensor):
|
|
self.quant_format = None
|
|
self._dense_weight = weight.detach()
|
|
self._quant_weight = None
|
|
self._weight_scale = None
|
|
self._weight_scale_2 = None
|
|
self._input_scale = None
|
|
self.full_precision_mm = False
|
|
self._clear_cache()
|
|
|
|
def _set_quant_state(
|
|
self,
|
|
fmt: str,
|
|
weight: torch.Tensor,
|
|
weight_scale: Optional[torch.Tensor],
|
|
weight_scale_2: Optional[torch.Tensor],
|
|
input_scale: Optional[torch.Tensor],
|
|
full_precision_mm: bool,
|
|
):
|
|
self.quant_format = fmt
|
|
self.full_precision_mm = full_precision_mm
|
|
self._quant_weight = weight.detach()
|
|
self._weight_scale = None if weight_scale is None else weight_scale.detach()
|
|
self._weight_scale_2 = None if weight_scale_2 is None else weight_scale_2.detach()
|
|
self._input_scale = None if input_scale is None else input_scale.detach()
|
|
self._dense_weight = None
|
|
self._clear_cache()
|
|
|
|
def _load_from_state_dict(
|
|
self,
|
|
state_dict,
|
|
prefix,
|
|
local_metadata,
|
|
strict,
|
|
missing_keys,
|
|
unexpected_keys,
|
|
error_msgs,
|
|
):
|
|
weight_key = f"{prefix}weight"
|
|
bias_key = f"{prefix}bias"
|
|
comfy_quant_key = f"{prefix}comfy_quant"
|
|
weight_scale_key = f"{prefix}weight_scale"
|
|
weight_scale_2_key = f"{prefix}weight_scale_2"
|
|
input_scale_key = f"{prefix}input_scale"
|
|
scale_input_key = f"{prefix}scale_input"
|
|
scale_weight_key = f"{prefix}scale_weight"
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
bias = state_dict.pop(bias_key, None)
|
|
comfy_quant_raw = state_dict.pop(comfy_quant_key, None)
|
|
weight_scale = state_dict.pop(weight_scale_key, None)
|
|
weight_scale_2 = state_dict.pop(weight_scale_2_key, None)
|
|
input_scale = state_dict.pop(input_scale_key, None)
|
|
if input_scale is None:
|
|
input_scale = state_dict.pop(scale_input_key, None)
|
|
state_dict.pop(scale_weight_key, None)
|
|
|
|
if bias is not None and self.bias is not None:
|
|
self.bias.data = bias.detach().to(device=self.bias.device, dtype=self.bias.dtype)
|
|
|
|
layer_conf = _decode_comfy_quant_entry(comfy_quant_raw) if comfy_quant_raw is not None else None
|
|
|
|
if layer_conf is None:
|
|
if weight is not None:
|
|
self._set_dense_weight(weight.to(dtype=self.compute_dtype))
|
|
return
|
|
|
|
if weight is None:
|
|
raise RuntimeError(f"Quantized layer at '{prefix}' is missing weight tensor.")
|
|
|
|
fmt = str(layer_conf.get("format", "")).lower()
|
|
if fmt not in ("float8_e4m3fn", "float8_e5m2", "nvfp4"):
|
|
raise RuntimeError(f"Unsupported Comfy quant format '{fmt}' for layer '{prefix}'.")
|
|
|
|
full_precision_mm = bool(layer_conf.get("full_precision_matrix_mult", False))
|
|
self._set_quant_state(
|
|
fmt=fmt,
|
|
weight=weight,
|
|
weight_scale=weight_scale,
|
|
weight_scale_2=weight_scale_2,
|
|
input_scale=input_scale,
|
|
full_precision_mm=full_precision_mm,
|
|
)
|
|
|
|
def _dequantize_weight(self, device, dtype):
|
|
if self.quant_format is None:
|
|
if self._dense_weight is None:
|
|
raise RuntimeError("Dense weight is unavailable.")
|
|
return self._dense_weight.to(device=device, dtype=dtype)
|
|
|
|
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
|
if self._quant_weight is None:
|
|
raise RuntimeError("FP8 quantized weight is unavailable.")
|
|
weight = self._quant_weight.to(device=device).float()
|
|
if self._weight_scale is not None:
|
|
weight = weight * self._weight_scale.to(device=device).float()
|
|
return weight.to(dtype=dtype)
|
|
|
|
if self.quant_format == "nvfp4":
|
|
if self._quant_weight is None or self._weight_scale is None or self._weight_scale_2 is None:
|
|
raise RuntimeError("NVFP4 quantized weight/scales are unavailable.")
|
|
deq_values = _decode_fp4_e2m1_packed_u8(self._quant_weight.to(device=device))
|
|
block_scale = _from_blocked_scales(self._weight_scale.to(device=device).float())
|
|
expanded_scale = block_scale.repeat_interleave(16, dim=1) * self._weight_scale_2.to(device=device).float()
|
|
if deq_values.shape != expanded_scale.shape:
|
|
raise RuntimeError(
|
|
f"NVFP4 shape mismatch in runtime linear: values={tuple(deq_values.shape)} "
|
|
f"scales={tuple(expanded_scale.shape)}"
|
|
)
|
|
return (deq_values * expanded_scale).to(dtype=dtype)
|
|
|
|
raise RuntimeError(f"Unknown quant format '{self.quant_format}'.")
|
|
|
|
def _resolve_scalar_scale(self, scale_tensor: Optional[torch.Tensor], device: torch.device) -> Optional[torch.Tensor]:
|
|
if scale_tensor is None:
|
|
return torch.ones((), device=device, dtype=torch.float32)
|
|
try:
|
|
scale = scale_tensor.to(device=device, dtype=torch.float32)
|
|
except Exception:
|
|
return None
|
|
if scale.numel() != 1:
|
|
return None
|
|
return scale.reshape(())
|
|
|
|
def _apply_weight_scale(self, output: torch.Tensor, device: torch.device) -> Optional[torch.Tensor]:
|
|
scale = self._weight_scale
|
|
if scale is None:
|
|
return output
|
|
try:
|
|
scale = scale.to(device=device, dtype=output.dtype)
|
|
except Exception:
|
|
return None
|
|
if scale.ndim == 0:
|
|
return output * scale
|
|
if scale.ndim == 1 and scale.shape[0] == self.out_features:
|
|
return output * scale.view(1, -1)
|
|
if scale.ndim == 2 and scale.shape == (self.out_features, 1):
|
|
return output * scale.view(1, -1)
|
|
return None
|
|
|
|
def _try_fp8_scaled_mm_linear(self, x: torch.Tensor):
|
|
if not _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_FAST_FP8", "1"):
|
|
return None
|
|
if self.quant_format not in ("float8_e4m3fn", "float8_e5m2"):
|
|
return None
|
|
if self.full_precision_mm:
|
|
return None
|
|
if self._quant_weight is None:
|
|
return None
|
|
if not x.is_cuda:
|
|
return None
|
|
if x.ndim != 2:
|
|
return None
|
|
if x.shape[1] != self.in_features:
|
|
return None
|
|
|
|
scaled_mm = getattr(torch, "_scaled_mm", None)
|
|
if scaled_mm is None:
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_COMFY_RUNTIME_FAST_FP8",
|
|
"[Z-Image POC] torch._scaled_mm is unavailable; falling back to standard FP8 linear path.",
|
|
)
|
|
return None
|
|
|
|
# cuBLASLt fp8 path requires K and N to be multiples of 16.
|
|
if (self.in_features % 16) != 0 or (self.out_features % 16) != 0:
|
|
return None
|
|
|
|
try:
|
|
weight = self._quant_weight.to(device=x.device)
|
|
if weight.ndim != 2:
|
|
return None
|
|
if weight.shape != (self.out_features, self.in_features):
|
|
return None
|
|
|
|
fp8_dtype = weight.dtype
|
|
if fp8_dtype not in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e5m2", None)):
|
|
return None
|
|
|
|
# Keep values in fp8 finite range before conversion.
|
|
fp8_max = torch.finfo(fp8_dtype).max
|
|
x_fp8 = torch.clamp(x, min=-fp8_max, max=fp8_max).to(dtype=fp8_dtype).contiguous()
|
|
|
|
# Use column-major RHS as required by _scaled_mm/cuBLASLt.
|
|
w_t = weight.t()
|
|
if w_t.stride(0) != 1:
|
|
w_t = weight.contiguous().t()
|
|
if w_t.stride(0) != 1:
|
|
return None
|
|
|
|
out_dtype = x.dtype if x.dtype in (torch.float16, torch.bfloat16, torch.float32) else self.compute_dtype
|
|
scale_a = self._resolve_scalar_scale(self._input_scale, x.device)
|
|
if scale_a is None:
|
|
return None
|
|
|
|
# _scaled_mm tensor-wise mode accepts only singleton scale_b.
|
|
# For non-singleton weight scales, run matmul with scale_b=1 and apply
|
|
# the per-channel scale on the output (same math as current fallback path).
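# Net effect (sketch): y ~= (x_fp8 * scale_a) @ (w_fp8 * scale_b).T + bias, with
# non-scalar weight scales applied to the output after the matmul instead.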
|
|
scale_b = self._resolve_scalar_scale(self._weight_scale, x.device)
|
|
apply_weight_scale_after = scale_b is None and self._weight_scale is not None
|
|
if scale_b is None:
|
|
scale_b = torch.ones((), device=x.device, dtype=torch.float32)
|
|
|
|
bias = self.bias.to(device=x.device, dtype=out_dtype) if self.bias is not None else None
|
|
|
|
try:
|
|
output = scaled_mm(
|
|
x_fp8,
|
|
w_t,
|
|
out_dtype=out_dtype,
|
|
bias=bias,
|
|
scale_a=scale_a,
|
|
scale_b=scale_b,
|
|
)
|
|
except TypeError:
|
|
# Older torch builds may not accept bias argument in _scaled_mm.
|
|
output = scaled_mm(
|
|
x_fp8,
|
|
w_t,
|
|
out_dtype=out_dtype,
|
|
scale_a=scale_a,
|
|
scale_b=scale_b,
|
|
)
|
|
if bias is not None:
|
|
output = output + bias
|
|
|
|
if isinstance(output, tuple):
|
|
output = output[0]
|
|
|
|
if apply_weight_scale_after:
|
|
output = self._apply_weight_scale(output, x.device)
|
|
if output is None:
|
|
return None
|
|
_warn_once_env(
|
|
"FOOOCUS_ZIMAGE_COMFY_RUNTIME_FAST_FP8",
|
|
"[Z-Image POC] Using torch._scaled_mm fast FP8 path for runtime quantized linear layers.",
|
|
)
|
|
return output
|
|
except Exception:
|
|
return None
|
|
|
|
def _try_fp8_direct_linear(self, x: torch.Tensor):
|
|
if self.quant_format not in ("float8_e4m3fn", "float8_e5m2"):
|
|
return None
|
|
if self.full_precision_mm:
|
|
return None
|
|
if self._quant_weight is None:
|
|
return None
|
|
|
|
# Fast path: torch._scaled_mm for FP8 when supported.
|
|
output = self._try_fp8_scaled_mm_linear(x)
|
|
if output is not None:
|
|
return output
|
|
|
|
try:
|
|
weight = self._quant_weight.to(device=x.device)
|
|
output = torch.nn.functional.linear(x, weight, None)
|
|
# Some torch builds can return FP8 here when either operand is FP8.
|
|
# Keep runtime activations in a compute-friendly float dtype.
|
|
if output.dtype not in (torch.float16, torch.bfloat16, torch.float32):
|
|
output = output.to(dtype=x.dtype if x.dtype in (torch.float16, torch.bfloat16, torch.float32) else torch.bfloat16)
|
|
output = self._apply_weight_scale(output, x.device)
|
|
if output is None:
|
|
return None
|
|
|
|
if self.bias is not None:
|
|
output = output + self.bias.to(device=x.device, dtype=output.dtype)
|
|
return output
|
|
except Exception:
|
|
return None
|
|
|
|
def _runtime_weight(self, x: torch.Tensor):
|
|
force_lowp = _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_LOWP", "1")
|
|
if self.quant_format is not None and force_lowp:
|
|
dtype = self.compute_dtype if self.compute_dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
|
|
else:
|
|
dtype = x.dtype if x.dtype in (torch.float16, torch.bfloat16, torch.float32) else self.compute_dtype
|
|
|
|
# FP16 + FP8 runtime-quant can become numerically unstable on some stacks.
|
|
# Allow stable accumulation dtype while keeping outer runtime in FP16 mode.
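# Accumulation dtype mapping used below (only for quantized layers running in
# fp16): "bf16" -> bfloat16 when CUDA bf16 is supported (else keep fp16),
# "fp32" -> float32, "auto" -> bfloat16 when supported, otherwise float32.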
|
|
if self.quant_format is not None and dtype == torch.float16:
|
|
mode = _zimage_fp16_quant_accum_mode()
|
|
if mode == "bf16":
|
|
if x.is_cuda and torch.cuda.is_bf16_supported():
|
|
dtype = torch.bfloat16
|
|
elif mode == "fp32":
|
|
dtype = torch.float32
|
|
elif mode == "auto":
|
|
if x.is_cuda and torch.cuda.is_bf16_supported():
|
|
dtype = torch.bfloat16
|
|
else:
|
|
dtype = torch.float32
|
|
|
|
device_key = str(x.device)
|
|
dtype_key = str(dtype)
|
|
if self._cache_enabled():
|
|
if (
|
|
self._cached_weight is not None
|
|
and self._cached_weight_device == device_key
|
|
and self._cached_weight_dtype == dtype_key
|
|
):
|
|
return self._cached_weight
|
|
|
|
weight = self._dequantize_weight(device=x.device, dtype=dtype)
|
|
if self._cache_enabled():
|
|
self._cached_weight = weight
|
|
self._cached_weight_device = device_key
|
|
self._cached_weight_dtype = dtype_key
|
|
return weight
|
|
|
|
def forward(self, input: torch.Tensor):
|
|
input_shape = input.shape
|
|
x = input.reshape(-1, input_shape[-1]) if input.ndim > 2 else input
|
|
input_runtime_dtype = x.dtype
|
|
force_lowp = _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_LOWP", "1")
|
|
if self.quant_format is not None and force_lowp:
|
|
compute_dtype = self.compute_dtype if self.compute_dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
|
|
else:
|
|
compute_dtype = x.dtype if x.dtype in (torch.float16, torch.bfloat16, torch.float32) else self.compute_dtype
|
|
|
|
if self.quant_format is not None and compute_dtype == torch.float16:
|
|
mode = _zimage_fp16_quant_accum_mode()
|
|
if mode == "bf16":
|
|
if x.is_cuda and torch.cuda.is_bf16_supported():
|
|
compute_dtype = torch.bfloat16
|
|
elif mode == "fp32":
|
|
compute_dtype = torch.float32
|
|
elif mode == "auto":
|
|
if x.is_cuda and torch.cuda.is_bf16_supported():
|
|
compute_dtype = torch.bfloat16
|
|
else:
|
|
compute_dtype = torch.float32
|
|
|
|
x = x.to(dtype=compute_dtype)
|
|
|
|
output = self._try_fp8_direct_linear(x)
|
|
if output is None:
|
|
weight = self._runtime_weight(x)
|
|
bias = self.bias.to(device=x.device, dtype=compute_dtype) if self.bias is not None else None
|
|
output = torch.nn.functional.linear(x, weight, bias)
|
|
|
|
# Keep module interface dtype stable, but never propagate FP8 activations.
|
|
if self.quant_format is not None:
|
|
target_dtype = input_runtime_dtype
|
|
if target_dtype not in (torch.float16, torch.bfloat16, torch.float32):
|
|
target_dtype = compute_dtype if compute_dtype in (torch.float16, torch.bfloat16, torch.float32) else torch.bfloat16
|
|
if output.dtype != target_dtype:
|
|
output = output.to(dtype=target_dtype)
|
|
|
|
if input.ndim > 2:
|
|
output = output.reshape(*input_shape[:-1], self.out_features)
|
|
return output
|
|
|
|
def _resolve_module(root_module, module_path: str):
|
|
current = root_module
|
|
for part in module_path.split("."):
|
|
if part.isdigit():
|
|
current = current[int(part)]
|
|
else:
|
|
current = getattr(current, part)
|
|
return current
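# Example (illustrative path): _resolve_module(model, "layers.0.mlp.gate_proj")
# walks model.layers[0].mlp.gate_proj, indexing containers for numeric parts and
# using getattr for the rest.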
|
|
|
|
def _set_module(root_module, module_path: str, new_module):
|
|
parts = module_path.split(".")
|
|
current = root_module
|
|
for part in parts[:-1]:
|
|
if part.isdigit():
|
|
current = current[int(part)]
|
|
else:
|
|
current = getattr(current, part)
|
|
leaf = parts[-1]
|
|
if leaf.isdigit():
|
|
current[int(leaf)] = new_module
|
|
else:
|
|
setattr(current, leaf, new_module)
|
|
|
|
def _is_linear_like_module(module) -> bool:
|
|
weight = getattr(module, "weight", None)
|
|
if weight is None or not isinstance(weight, torch.Tensor) or weight.ndim != 2:
|
|
return False
|
|
in_features = getattr(module, "in_features", None)
|
|
out_features = getattr(module, "out_features", None)
|
|
if in_features is None or out_features is None:
|
|
out_features, in_features = weight.shape[0], weight.shape[1]
|
|
try:
|
|
in_features = int(in_features)
|
|
out_features = int(out_features)
|
|
except Exception:
|
|
return False
|
|
return in_features > 0 and out_features > 0
|
|
|
|
def _install_comfy_runtime_quant_modules(component_module, remapped_sd: dict) -> dict:
|
|
bases = sorted(
|
|
{
|
|
key[: -len(".comfy_quant")]
|
|
for key in remapped_sd.keys()
|
|
if key.endswith(".comfy_quant") and f"{key[: -len('.comfy_quant')]}.weight" in remapped_sd
|
|
}
|
|
)
|
|
if not bases:
|
|
return {"layers": 0, "replaced": 0, "skipped": 0, "float8": 0, "nvfp4": 0}
|
|
|
|
replaced = 0
|
|
skipped = 0
|
|
float8_layers = 0
|
|
nvfp4_layers = 0
|
|
replaced_bases = set()
|
|
skipped_type_counts = {}
|
|
skipped_unresolved = 0
|
|
|
|
def _mark_skipped_type(name: str):
|
|
skipped_type_counts[name] = skipped_type_counts.get(name, 0) + 1
|
|
|
|
for base in bases:
|
|
conf = _decode_comfy_quant_entry(remapped_sd.get(f"{base}.comfy_quant"))
|
|
fmt = str(conf.get("format", "")).lower() if isinstance(conf, dict) else ""
|
|
if fmt in ("float8_e4m3fn", "float8_e5m2"):
|
|
float8_layers += 1
|
|
elif fmt == "nvfp4":
|
|
nvfp4_layers += 1
|
|
|
|
try:
|
|
target = _resolve_module(component_module, base)
|
|
except Exception:
|
|
skipped += 1
|
|
skipped_unresolved += 1
|
|
continue
|
|
|
|
if isinstance(target, _ComfyRuntimeQuantLinear):
|
|
replaced += 1
|
|
replaced_bases.add(base)
|
|
continue
|
|
if not _is_linear_like_module(target):
|
|
skipped += 1
|
|
_mark_skipped_type(type(target).__name__)
|
|
continue
|
|
try:
|
|
replacement = _ComfyRuntimeQuantLinear.from_linear(target)
|
|
except Exception:
|
|
skipped += 1
|
|
_mark_skipped_type(type(target).__name__)
|
|
continue
|
|
_set_module(component_module, base, replacement)
|
|
replaced += 1
|
|
replaced_bases.add(base)
|
|
|
|
return {
|
|
"layers": len(bases),
|
|
"replaced": replaced,
|
|
"skipped": skipped,
|
|
"float8": float8_layers,
|
|
"nvfp4": nvfp4_layers,
|
|
"replaced_bases": replaced_bases,
|
|
"skipped_unresolved": skipped_unresolved,
|
|
"skipped_type_counts": skipped_type_counts,
|
|
}
|
|
|
|
def _normalize_legacy_scaled_fp8_weights(sd: dict) -> tuple[dict, dict]:
|
|
converted = dict(sd)
|
|
migrated = 0
|
|
quant_bases = set()
|
|
for key in list(converted.keys()):
|
|
if key.endswith(".scale_weight"):
|
|
base = key[: -len(".scale_weight")]
|
|
converted[f"{base}.weight_scale"] = converted.pop(key)
|
|
quant_bases.add(base)
|
|
migrated += 1
|
|
continue
|
|
if key.endswith(".scale_input"):
|
|
base = key[: -len(".scale_input")]
|
|
converted[f"{base}.input_scale"] = converted.pop(key)
|
|
quant_bases.add(base)
|
|
migrated += 1
|
|
|
|
if migrated == 0:
|
|
return sd, {"migrated": 0, "created_quant_entries": 0}
|
|
|
|
created_quant_entries = 0
|
|
for base in quant_bases:
|
|
weight_key = f"{base}.weight"
|
|
scale_key = f"{base}.weight_scale"
|
|
if weight_key not in converted or scale_key not in converted:
|
|
continue
|
|
if f"{base}.comfy_quant" in converted:
|
|
continue
|
|
weight = converted[weight_key]
|
|
fmt = None
|
|
fp8_e4m3 = getattr(torch, "float8_e4m3fn", None)
|
|
fp8_e5m2 = getattr(torch, "float8_e5m2", None)
|
|
if fp8_e4m3 is not None and getattr(weight, "dtype", None) == fp8_e4m3:
|
|
fmt = "float8_e4m3fn"
|
|
elif fp8_e5m2 is not None and getattr(weight, "dtype", None) == fp8_e5m2:
|
|
fmt = "float8_e5m2"
|
|
if fmt is None:
|
|
continue
|
|
payload = torch.tensor(list(json.dumps({"format": fmt}).encode("utf-8")), dtype=torch.uint8)
|
|
converted[f"{base}.comfy_quant"] = payload
|
|
created_quant_entries += 1
|
|
|
|
converted.pop("scaled_fp8", None)
|
|
return converted, {"migrated": migrated, "created_quant_entries": created_quant_entries}
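# Migration sketch: a legacy pair "foo.weight" (fp8 dtype) + "foo.scale_weight"
# becomes "foo.weight" + "foo.weight_scale", plus a synthesized "foo.comfy_quant"
# JSON entry so the runtime-quant path can treat it like a Comfy-style export.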
|
|
|
|
def _synthesize_native_fp8_quant_entries(sd: dict, component_module=None) -> tuple[dict, dict]:
|
|
converted = dict(sd)
|
|
created = 0
|
|
fmt_counts = {"float8_e4m3fn": 0, "float8_e5m2": 0}
|
|
skipped_non_linear = 0
|
|
skipped_unresolved = 0
|
|
fp8_e4m3 = getattr(torch, "float8_e4m3fn", None)
|
|
fp8_e5m2 = getattr(torch, "float8_e5m2", None)
|
|
payload_cache = {}
|
|
|
|
for key, value in list(converted.items()):
|
|
if not key.endswith(".weight"):
|
|
continue
|
|
base = key[: -len(".weight")]
|
|
if f"{base}.comfy_quant" in converted:
|
|
continue
|
|
dtype = getattr(value, "dtype", None)
|
|
fmt = None
|
|
if fp8_e4m3 is not None and dtype == fp8_e4m3:
|
|
fmt = "float8_e4m3fn"
|
|
elif fp8_e5m2 is not None and dtype == fp8_e5m2:
|
|
fmt = "float8_e5m2"
|
|
if fmt is None:
|
|
continue
|
|
|
|
if component_module is not None:
|
|
try:
|
|
target = _resolve_module(component_module, base)
|
|
except Exception:
|
|
skipped_unresolved += 1
|
|
continue
|
|
if not _is_linear_like_module(target):
|
|
skipped_non_linear += 1
|
|
continue
|
|
|
|
payload = payload_cache.get(fmt)
|
|
if payload is None:
|
|
payload = torch.tensor(list(json.dumps({"format": fmt}).encode("utf-8")), dtype=torch.uint8)
|
|
payload_cache[fmt] = payload
|
|
converted[f"{base}.comfy_quant"] = payload
|
|
created += 1
|
|
fmt_counts[fmt] += 1
|
|
|
|
return converted, {
|
|
"created": created,
|
|
"skipped_non_linear": skipped_non_linear,
|
|
"skipped_unresolved": skipped_unresolved,
|
|
**fmt_counts,
|
|
}
|
|
|
|
def _dequantize_comfy_mixed_weights(sd: dict) -> tuple[dict, dict]:
|
|
quant_entries = {}
|
|
for key in list(sd.keys()):
|
|
if not key.endswith(".comfy_quant"):
|
|
continue
|
|
base = key[: -len(".comfy_quant")]
|
|
conf = _decode_comfy_quant_entry(sd[key])
|
|
if isinstance(conf, dict):
|
|
quant_entries[base] = conf
|
|
|
|
if not quant_entries:
|
|
return sd, {"layers": 0, "float8": 0, "nvfp4": 0}
|
|
|
|
converted = dict(sd)
|
|
stats = {"layers": 0, "float8": 0, "nvfp4": 0}
|
|
for base, conf in quant_entries.items():
|
|
fmt = str(conf.get("format", "")).lower()
|
|
weight_key = f"{base}.weight"
|
|
if weight_key not in converted:
|
|
continue
|
|
|
|
if fmt in ("float8_e4m3fn", "float8_e5m2"):
|
|
weight = converted[weight_key].float()
|
|
scale = converted.get(f"{base}.weight_scale", None)
|
|
if scale is not None:
|
|
weight = weight * scale.float()
|
|
converted[weight_key] = weight.to(torch.bfloat16)
|
|
stats["float8"] += 1
|
|
stats["layers"] += 1
|
|
elif fmt == "nvfp4":
|
|
packed = converted[weight_key]
|
|
block_scale = converted.get(f"{base}.weight_scale", None)
|
|
tensor_scale = converted.get(f"{base}.weight_scale_2", None)
|
|
if block_scale is None or tensor_scale is None:
|
|
raise RuntimeError(
|
|
f"NVFP4 layer '{base}' is missing weight_scale/weight_scale_2."
|
|
)
|
|
deq = _decode_fp4_e2m1_packed_u8(packed)
|
|
per_block = _from_blocked_scales(block_scale.float())
|
|
expanded_scale = per_block.repeat_interleave(16, dim=1) * tensor_scale.float()
|
|
if deq.shape != expanded_scale.shape:
|
|
raise RuntimeError(
|
|
f"NVFP4 shape mismatch in '{base}': values={tuple(deq.shape)} scales={tuple(expanded_scale.shape)}"
|
|
)
|
|
converted[weight_key] = (deq * expanded_scale).to(torch.bfloat16)
|
|
stats["nvfp4"] += 1
|
|
stats["layers"] += 1
|
|
else:
|
|
raise RuntimeError(f"Unsupported Comfy quant format '{fmt}' in '{base}'.")
|
|
|
|
# Remove quant side tensors after dequantization.
|
|
for suffix in quant_side_suffixes:
|
|
converted.pop(f"{base}{suffix}", None)
|
|
|
|
# Remove legacy global scaled-fp8 marker if present.
|
|
converted.pop("scaled_fp8", None)
|
|
return converted, stats
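# NVFP4 dequantization sketch used above: each unpacked weight column j reuses
# block-scale column j // 16, i.e.
#   w ~= decode_fp4(packed) * block_scale.repeat_interleave(16, dim=1) * tensor_scale
# and the result is stored back as a dense bfloat16 weight.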
|
|
|
|
def _dequantize_selected_comfy_bases(sd: dict, selected_bases: set[str]) -> tuple[dict, dict]:
|
|
if not selected_bases:
|
|
return sd, {"selected": 0, "dequantized": 0, "float8": 0, "nvfp4": 0, "skipped": 0}
|
|
|
|
converted = dict(sd)
|
|
stats = {
|
|
"selected": len(selected_bases),
|
|
"dequantized": 0,
|
|
"float8": 0,
|
|
"nvfp4": 0,
|
|
"skipped": 0,
|
|
}
|
|
|
|
for base in sorted(selected_bases):
|
|
conf = _decode_comfy_quant_entry(converted.get(f"{base}.comfy_quant"))
|
|
weight_key = f"{base}.weight"
|
|
if not isinstance(conf, dict) or weight_key not in converted:
|
|
stats["skipped"] += 1
|
|
continue
|
|
|
|
fmt = str(conf.get("format", "")).lower()
|
|
try:
|
|
if fmt in ("float8_e4m3fn", "float8_e5m2"):
|
|
weight = converted[weight_key].float()
|
|
scale = converted.get(f"{base}.weight_scale", None)
|
|
if scale is not None:
|
|
weight = weight * scale.float()
|
|
converted[weight_key] = weight.to(torch.bfloat16)
|
|
stats["float8"] += 1
|
|
stats["dequantized"] += 1
|
|
elif fmt == "nvfp4":
|
|
packed = converted[weight_key]
|
|
block_scale = converted.get(f"{base}.weight_scale", None)
|
|
tensor_scale = converted.get(f"{base}.weight_scale_2", None)
|
|
if block_scale is None or tensor_scale is None:
|
|
stats["skipped"] += 1
|
|
continue
|
|
deq = _decode_fp4_e2m1_packed_u8(packed)
|
|
per_block = _from_blocked_scales(block_scale.float())
|
|
expanded_scale = per_block.repeat_interleave(16, dim=1) * tensor_scale.float()
|
|
if deq.shape != expanded_scale.shape:
|
|
stats["skipped"] += 1
|
|
continue
|
|
converted[weight_key] = (deq * expanded_scale).to(torch.bfloat16)
|
|
stats["nvfp4"] += 1
|
|
stats["dequantized"] += 1
|
|
else:
|
|
stats["skipped"] += 1
|
|
continue
|
|
except Exception:
|
|
stats["skipped"] += 1
|
|
continue
|
|
|
|
# Remove quant-side tensors for layers now running eager dense weights.
|
|
for suffix in quant_side_suffixes:
|
|
converted.pop(f"{base}{suffix}", None)
|
|
|
|
return converted, stats
|
|
|
|
state_dict, legacy_stats = _normalize_legacy_scaled_fp8_weights(state_dict)
|
|
if legacy_stats.get("migrated", 0) > 0:
|
|
print(
|
|
f"[Z-Image POC] Normalized legacy scaled FP8 keys for {component_name}: "
|
|
f"migrated={legacy_stats['migrated']}, quant_entries={legacy_stats['created_quant_entries']}."
|
|
)
|
|
|
|
model_keys = set(component.state_dict().keys())
|
|
probe_source_key_count = len(state_dict)
|
|
direct_match_count = 0
|
|
for key in state_dict.keys():
|
|
if key in model_keys:
|
|
direct_match_count += 1
|
|
print(
|
|
f"[Z-Image POC] {component_name} file override key probe: "
|
|
f"source={probe_source_key_count}, direct_matches={direct_match_count}, model_keys={len(model_keys)}"
|
|
)
|
|
remapped = None
|
|
runtime_quant_enabled = (
|
|
component_name in ("text_encoder", "transformer")
|
|
and _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_QUANT", "1")
|
|
)
|
|
runtime_stats = {"layers": 0, "replaced": 0, "skipped": 0, "float8": 0, "nvfp4": 0}
|
|
if runtime_quant_enabled:
|
|
remapped_candidate = _remap_state_dict_to_model_keys(
|
|
state_dict,
|
|
model_keys,
|
|
f"{component_name}-file-override-runtime",
|
|
verbose=True,
|
|
)
|
|
remapped_candidate, synth_stats = _synthesize_native_fp8_quant_entries(
|
|
remapped_candidate, component_module=component
|
|
)
|
|
if synth_stats["created"] > 0:
|
|
print(
|
|
f"[Z-Image POC] Synthesized native FP8 quant entries for {component_name}: "
|
|
f"layers={synth_stats['created']}, fp8_e4m3={synth_stats['float8_e4m3fn']}, "
|
|
f"fp8_e5m2={synth_stats['float8_e5m2']}."
|
|
)
|
|
elif synth_stats.get("skipped_non_linear", 0) > 0:
|
|
print(
|
|
f"[Z-Image POC] Native FP8 synth skipped non-linear layers for {component_name}: "
|
|
f"skipped_non_linear={synth_stats['skipped_non_linear']}."
|
|
)
|
|
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
|
|
runtime_stats["backend"] = "runtime"
|
|
if runtime_stats["layers"] > 0 and runtime_stats["replaced"] > 0:
|
|
replaced_bases = runtime_stats.get("replaced_bases", set())
|
|
quant_bases = {
|
|
key[: -len(".comfy_quant")]
|
|
for key in remapped_candidate.keys()
|
|
if key.endswith(".comfy_quant")
|
|
}
|
|
unmapped_bases = quant_bases - set(replaced_bases)
|
|
remapped_candidate, unmapped_stats = _dequantize_selected_comfy_bases(
|
|
remapped_candidate, unmapped_bases
|
|
)
|
|
runtime_stats["unmapped"] = len(unmapped_bases)
|
|
runtime_stats["unmapped_dequantized"] = unmapped_stats.get("dequantized", 0)
|
|
runtime_stats["unmapped_skipped"] = unmapped_stats.get("skipped", 0)
|
|
remapped = remapped_candidate
|
|
if runtime_stats["replaced"] >= runtime_stats["layers"]:
|
|
print(
|
|
f"[Z-Image POC] Runtime Comfy quant enabled for {component_name}: "
|
|
f"layers={runtime_stats['layers']}, fp8={runtime_stats['float8']}, "
|
|
f"nvfp4={runtime_stats['nvfp4']}, backend={runtime_stats.get('backend', 'runtime')}."
|
|
)
|
|
else:
|
|
skipped_types = runtime_stats.get("skipped_type_counts", {})
|
|
skipped_type_summary = ""
|
|
if skipped_types:
|
|
top = sorted(skipped_types.items(), key=lambda kv: kv[1], reverse=True)[:3]
|
|
skipped_type_summary = ", skipped_types=" + ",".join(f"{k}:{v}" for k, v in top)
|
|
print(
|
|
f"[Z-Image POC] Runtime Comfy quant partially mapped for {component_name} "
|
|
f"(replaced={runtime_stats['replaced']}/{runtime_stats['layers']}, "
|
|
f"unmapped_dequantized={runtime_stats.get('unmapped_dequantized', 0)}, "
|
|
f"skipped_unresolved={runtime_stats.get('skipped_unresolved', 0)}"
|
|
f"{skipped_type_summary}, "
|
|
"unmapped layers keep eager load path)."
|
|
)
|
|
|
|
if remapped is None:
|
|
state_dict, quant_stats = _dequantize_comfy_mixed_weights(state_dict)
|
|
if quant_stats.get("layers", 0) > 0:
|
|
print(
|
|
f"[Z-Image POC] Dequantized Comfy mixed weights for {component_name}: "
|
|
f"layers={quant_stats['layers']}, fp8={quant_stats['float8']}, nvfp4={quant_stats['nvfp4']}."
|
|
)
|
|
remapped = _remap_state_dict_to_model_keys(
|
|
state_dict,
|
|
model_keys,
|
|
f"{component_name}-file-override",
|
|
verbose=True,
|
|
)
|
|
|
|
remapped_match_count = 0
|
|
for key in remapped.keys():
|
|
if key in model_keys:
|
|
remapped_match_count += 1
|
|
continue
|
|
for suffix in quant_side_suffixes:
|
|
if key.endswith(suffix):
|
|
weight_key = f"{key[: -len(suffix)]}.weight"
|
|
if weight_key in model_keys:
|
|
remapped_match_count += 1
|
|
break
|
|
source_key_count = len(remapped)
|
|
|
|
precheck_ratio = remapped_match_count / float(max(source_key_count, 1))
|
|
if component_name in ("vae", "text_encoder"):
|
|
if precheck_ratio < 0.35:
|
|
print(
|
|
f"[Z-Image POC] Skipping incompatible {component_name} override file '{source_label or os.path.basename(file_path)}' "
|
|
f"(precheck remap_match={precheck_ratio:.1%}); using model default {component_name}."
|
|
)
|
|
state_dict.clear()
|
|
remapped.clear()
|
|
return
|
|
|
|
missing, unexpected = _apply_component_state_dict(
|
|
component,
|
|
remapped,
|
|
label=f"{component_name} file override ({source_label or os.path.basename(file_path)})",
|
|
missing_limit=None,
|
|
unexpected_limit=None,
|
|
)
|
|
# Guardrail: avoid silent garbage generations when incompatible quantized files are selected.
|
|
if model_keys:
|
|
missing_ratio = len(missing) / float(len(model_keys))
|
|
else:
|
|
missing_ratio = 0.0
|
|
unexpected_ratio = len(unexpected) / float(max(len(remapped), 1))
|
|
remap_ratio = remapped_match_count / float(max(source_key_count, 1))
|
|
if remap_ratio < 0.65 or missing_ratio > 0.35 or unexpected_ratio > 0.35:
|
|
if component_name in ("vae", "text_encoder"):
|
|
print(
|
|
f"[Z-Image POC] Skipping incompatible {component_name} override file '{source_label or os.path.basename(file_path)}' "
|
|
f"(remap_match={remap_ratio:.1%}, missing={len(missing)} ({missing_ratio:.1%}), "
|
|
f"unexpected={len(unexpected)} ({unexpected_ratio:.1%})); using model default {component_name}."
|
|
)
|
|
state_dict.clear()
|
|
remapped.clear()
|
|
return
|
|
raise RuntimeError(
|
|
f"Incompatible {component_name} override file '{source_label or os.path.basename(file_path)}': "
|
|
f"remap_match={remap_ratio:.1%}, missing={len(missing)} ({missing_ratio:.1%}), "
|
|
f"unexpected={len(unexpected)} ({unexpected_ratio:.1%}). "
|
|
"Use a compatible full-precision component or the default component folder."
|
|
)
|
|
print(f"[Z-Image POC] Using override {component_name} file: {source_label or file_path}")
|
|
state_dict.clear()
|
|
remapped.clear()
|
|
|
|
|
|
def _load_pipeline(
|
|
source_kind: str,
|
|
source_path: str,
|
|
flavor: str,
|
|
checkpoint_folders: list[str],
|
|
text_encoder_override: Optional[str] = None,
|
|
vae_override: Optional[str] = None,
|
|
):
|
|
cache_key = _pipeline_cache_key(
|
|
source_kind,
|
|
source_path,
|
|
text_encoder_override=text_encoder_override,
|
|
vae_override=vae_override,
|
|
)
|
|
if cache_key in _PIPELINE_CACHE:
|
|
pipeline, generator_device, used_offload = _PIPELINE_CACHE[cache_key]
|
|
if _pipeline_has_meta_tensors(pipeline):
|
|
print("[Z-Image POC] Cached pipeline has meta tensors, rebuilding pipeline.")
|
|
_drop_cache_entry(cache_key)
|
|
return _load_pipeline(
|
|
source_kind,
|
|
source_path,
|
|
flavor,
|
|
checkpoint_folders,
|
|
text_encoder_override=text_encoder_override,
|
|
vae_override=vae_override,
|
|
)
|
|
current_profile = _zimage_perf_profile()
|
|
cached_profile = getattr(pipeline, "_zimage_perf_profile", "safe")
|
|
if current_profile != cached_profile:
|
|
device, _ = _pick_device_and_dtype()
|
|
generator_device, used_offload = _prepare_pipeline_memory_mode(pipeline, device)
|
|
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
|
|
# Keep cached memory mode for throughput; mode hardening is handled at load/OOM time.
|
|
return _PIPELINE_CACHE[cache_key]
|
|
|
|
from diffusers import DiffusionPipeline
|
|
|
|
_ensure_zimage_runtime_compatibility()
|
|
prefer_single_file_aux_weights = _truthy_env("FOOOCUS_ZIMAGE_LOAD_AIO_AUX", "0")
|
|
|
|
zimage_allow_fp16 = _detect_zimage_allow_fp16(source_kind, source_path)
|
|
device, dtype = _pick_device_and_dtype(zimage_allow_fp16=zimage_allow_fp16)
|
|
|
|
if source_kind == "directory":
|
|
pipeline = _call_with_dtype_compat(
|
|
DiffusionPipeline.from_pretrained,
|
|
dtype,
|
|
{
|
|
"pretrained_model_name_or_path": source_path,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
},
|
|
"DiffusionPipeline.from_pretrained(directory)",
|
|
)
|
|
elif source_kind == "single_file":
|
|
local_config, tried_config_only_text_encoder = _ensure_single_file_component_dir(
|
|
flavor, checkpoint_folders, source_path
|
|
)
|
|
|
|
split_error = None
|
|
native_error = None
|
|
pipeline = None
|
|
is_fp8_single_file = _is_likely_fp8_single_file(source_path)
|
|
|
|
# Forge-like priority: let the framework load the single-file checkpoint natively first.
|
|
try:
|
|
if hasattr(DiffusionPipeline, "from_single_file") and not is_fp8_single_file:
|
|
native_kwargs = dict(
|
|
config=local_config,
|
|
local_files_only=True,
|
|
low_cpu_mem_usage=True,
|
|
)
|
|
pipeline = _call_with_dtype_compat(
|
|
lambda **kwargs: DiffusionPipeline.from_single_file(source_path, **kwargs),
|
|
dtype,
|
|
native_kwargs,
|
|
"DiffusionPipeline.from_single_file",
|
|
)
|
|
if pipeline is not None and _pipeline_has_meta_tensors(pipeline):
|
|
raise RuntimeError("native single-file produced meta tensors")
|
|
elif is_fp8_single_file:
|
|
print("[Z-Image POC] FP8 checkpoint detected, skipping native single-file loader to reduce RAM spikes.")
|
|
except Exception as e:
|
|
native_error = e
|
|
print(f"[Z-Image POC] Native single-file loader fallback due to: {e}")
|
|
_cleanup_memory(cuda=True)
|
|
|
|
# Fallback: split-loader assembly.
|
|
if pipeline is None:
|
|
try:
|
|
pipeline = _build_pipeline_from_single_file_components(
|
|
local_config,
|
|
source_path,
|
|
dtype,
|
|
prefer_single_file_aux_weights=prefer_single_file_aux_weights,
|
|
)
|
|
if pipeline is not None and _pipeline_has_meta_tensors(pipeline):
|
|
raise RuntimeError("split-loader produced meta tensors")
|
|
except Exception as e:
|
|
split_error = e
|
|
print(f"[Z-Image POC] Split-loader fallback due to: {e}")
|
|
_cleanup_memory(cuda=True)
|
|
|
|
if pipeline is None and split_error is not None:
|
|
# Legacy fallback path.
|
|
try:
|
|
pipeline = _call_with_dtype_compat(
|
|
DiffusionPipeline.from_pretrained,
|
|
dtype,
|
|
{
|
|
"pretrained_model_name_or_path": local_config,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
},
|
|
"DiffusionPipeline.from_pretrained(local_config)",
|
|
)
|
|
except Exception:
|
|
if not tried_config_only_text_encoder:
|
|
raise split_error
|
|
# Fallback: some backends require full text_encoder files even when AIO contains weights.
|
|
repo_id = _repo_for_flavor(flavor)
|
|
_download_repo_components(
|
|
repo_id,
|
|
local_config,
|
|
patterns=["text_encoder/*"],
|
|
missing=["text_encoder(fallback-full)"],
|
|
)
|
|
try:
|
|
pipeline = _call_with_dtype_compat(
|
|
DiffusionPipeline.from_pretrained,
|
|
dtype,
|
|
{
|
|
"pretrained_model_name_or_path": local_config,
|
|
"local_files_only": True,
|
|
"low_cpu_mem_usage": True,
|
|
},
|
|
"DiffusionPipeline.from_pretrained(local_config-fallback)",
|
|
)
|
|
except Exception:
|
|
raise split_error
|
|
_load_transformer_weights_from_single_file(source_path, pipeline)
|
|
if pipeline is None:
|
|
if split_error is not None:
|
|
raise split_error
|
|
if native_error is not None:
|
|
raise native_error
|
|
raise RuntimeError("Failed to build Z-Image pipeline from single-file checkpoint.")
|
|
else:
|
|
raise ValueError(f"Unsupported source kind: {source_kind}")
|
|
|
|
if _pipeline_has_meta_tensors(pipeline):
|
|
raise RuntimeError("Z-Image pipeline contains meta tensors after load.")
|
|
|
|
if text_encoder_override is not None:
|
|
_load_component_override(pipeline, "text_encoder", text_encoder_override, dtype)
|
|
if vae_override is not None:
|
|
_load_component_override(pipeline, "vae", vae_override, dtype)
|
|
|
|
pipeline.set_progress_bar_config(disable=True)
|
|
generator_device, used_offload = _prepare_pipeline_memory_mode(pipeline, device)
|
|
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
|
|
return _PIPELINE_CACHE[cache_key]
|
|
|
|
|
|
def _run_pipeline_call(pipeline, call_kwargs: dict):
|
|
kwargs = dict(call_kwargs)
|
|
optional_drop_order = [
|
|
"cfg_normalization",
|
|
"cfg_truncation",
|
|
"max_sequence_length",
|
|
"negative_prompt",
|
|
]
|
|
for _ in range(len(optional_drop_order) + 1):
|
|
try:
|
|
return pipeline(**kwargs)
|
|
except TypeError:
|
|
dropped = False
|
|
for key in optional_drop_order:
|
|
if key in kwargs:
|
|
kwargs.pop(key, None)
|
|
dropped = True
|
|
break
|
|
if not dropped:
|
|
raise
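# Behavior sketch: each TypeError drops the first still-present optional kwarg in
# the order cfg_normalization, cfg_truncation, max_sequence_length,
# negative_prompt, then the call is retried; other exceptions propagate.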
|
|
|
|
|
|
def _maybe_prewarm_pipeline(
|
|
pipeline,
|
|
generator_device: str,
|
|
flavor: str,
|
|
prewarm_width: Optional[int] = None,
|
|
prewarm_height: Optional[int] = None,
|
|
prewarm_max_sequence_length: Optional[int] = None,
|
|
) -> None:
|
|
if not _zimage_prewarm_enabled():
|
|
return
|
|
if bool(getattr(pipeline, "_zimage_prewarm_done", False)):
|
|
return
|
|
|
|
# Mark as attempted to avoid repeated startup penalties if warmup fails.
|
|
pipeline._zimage_prewarm_done = True
|
|
pipeline._zimage_prewarm_error = None
|
|
|
|
import torch
|
|
|
|
try:
|
|
steps = _zimage_prewarm_steps()
|
|
width, height = _zimage_prewarm_size(
|
|
default_width=max(256, int(prewarm_width)) if prewarm_width is not None else 832,
|
|
default_height=max(256, int(prewarm_height)) if prewarm_height is not None else 1216,
|
|
)
|
|
default_max_seq = 64 if flavor == "turbo" else 128
|
|
if prewarm_max_sequence_length is None:
|
|
max_sequence_length = default_max_seq
|
|
else:
|
|
max_sequence_length = max(32, int(prewarm_max_sequence_length))
|
|
prompt = os.environ.get("FOOOCUS_ZIMAGE_PREWARM_PROMPT", "").strip() or "portrait photo"
|
|
negative_prompt = os.environ.get("FOOOCUS_ZIMAGE_PREWARM_NEGATIVE_PROMPT", "").strip()
|
|
guidance = 1.0
|
|
use_cfg = guidance > 1.0
|
|
|
|
if generator_device == "cuda":
|
|
profile = _zimage_perf_profile()
|
|
_maybe_preemptive_cuda_cleanup_before_generation(pipeline, profile=profile)
|
|
|
|
started = time.perf_counter()
|
|
print(
|
|
f"[Z-Image POC] Prewarm start: steps={steps}, size={width}x{height}, "
|
|
f"max_seq={max_sequence_length}, device={generator_device}."
|
|
)
|
|
|
|
with torch.inference_mode():
|
|
_prepare_granular_prompt_encode(pipeline, generator_device=generator_device)
|
|
pos, neg = pipeline.encode_prompt(
|
|
prompt=prompt,
|
|
negative_prompt=negative_prompt if use_cfg else None,
|
|
do_classifier_free_guidance=use_cfg,
|
|
device=generator_device,
|
|
max_sequence_length=max_sequence_length,
|
|
)
|
|
prompt_embeds = [x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in pos]
|
|
negative_prompt_embeds = (
|
|
[x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in neg] if neg else []
|
|
)
|
|
generator = torch.Generator(device=generator_device).manual_seed(1)
|
|
_prepare_granular_pipeline_call(pipeline, generator_device=generator_device, stage="prewarm_call")
|
|
output = _run_pipeline_call(
|
|
pipeline,
|
|
dict(
|
|
prompt=None,
|
|
width=width,
|
|
height=height,
|
|
num_inference_steps=steps,
|
|
guidance_scale=guidance,
|
|
generator=generator,
|
|
num_images_per_prompt=1,
|
|
cfg_normalization=False,
|
|
cfg_truncation=1.0,
|
|
max_sequence_length=max_sequence_length,
|
|
prompt_embeds=prompt_embeds,
|
|
negative_prompt_embeds=negative_prompt_embeds,
|
|
),
|
|
)
|
|
# Force lazy decode/materialization.
|
|
_ = output.images[0]
|
|
del output
|
|
del prompt_embeds
|
|
del negative_prompt_embeds
|
|
del generator
|
|
_park_granular_components(pipeline, generator_device=generator_device, stage="prewarm_idle")
|
|
|
|
elapsed = time.perf_counter() - started
|
|
print(f"[Z-Image POC] Prewarm complete in {elapsed:.2f}s.")
|
|
except Exception as e:
|
|
pipeline._zimage_prewarm_error = str(e)
|
|
print(f"[Z-Image POC] Prewarm failed (ignored): {e}")
|
|
finally:
|
|
try:
|
|
if hasattr(pipeline, "maybe_free_model_hooks"):
|
|
pipeline.maybe_free_model_hooks()
|
|
except Exception:
|
|
pass
|
|
_park_granular_components(pipeline, generator_device=generator_device, stage="prewarm_final")
|
|
if generator_device == "cuda":
|
|
_cleanup_memory(cuda=True, aggressive=False)
|
|
|
|
|
|
def _generate_zimage_impl(
|
|
source_kind: str,
|
|
source_path: str,
|
|
flavor: str,
|
|
checkpoint_folders: list[str],
|
|
prompt: str,
|
|
negative_prompt: str,
|
|
width: int,
|
|
height: int,
|
|
steps: int,
|
|
guidance_scale: float,
|
|
seed: int,
|
|
seeds: Optional[list[int]] = None,
|
|
shift: float = 3.0,
|
|
text_encoder_override: Optional[str] = None,
|
|
vae_override: Optional[str] = None,
|
|
return_images: bool = False,
|
|
_use_alt_path: bool = False,
|
|
):
|
|
import torch
|
|
|
|
stage_timers = _zimage_stage_timers_enabled()
|
|
total_start = time.perf_counter()
|
|
stage_times: dict[str, float] = {}
|
|
embed_cache_hit = False
|
|
generation_attempts = 0
|
|
error_name = ""
|
|
|
|
stage_start = time.perf_counter()
|
|
resolved_text_encoder_override = resolve_zimage_component_path(
|
|
text_encoder_override, "text_encoder", checkpoint_folders
|
|
)
|
|
resolved_vae_override = resolve_zimage_component_path(vae_override, "vae", checkpoint_folders)
|
|
stage_times["resolve_overrides"] = time.perf_counter() - stage_start
|
|
cache_key = _pipeline_cache_key(
|
|
source_kind,
|
|
source_path,
|
|
text_encoder_override=resolved_text_encoder_override,
|
|
vae_override=resolved_vae_override,
|
|
)
|
|
profile = _zimage_perf_profile()
|
|
stage_start = time.perf_counter()
|
|
pipeline, generator_device, used_offload = _load_pipeline(
|
|
source_kind,
|
|
source_path,
|
|
flavor,
|
|
checkpoint_folders,
|
|
text_encoder_override=resolved_text_encoder_override,
|
|
vae_override=resolved_vae_override,
|
|
)
|
|
if generator_device == "cuda" and _should_cleanup_cuda_cache(profile, had_oom=False, pipeline=pipeline):
|
|
_cleanup_memory(cuda=True, aggressive=False)
|
|
stage_times["pipeline_load"] = time.perf_counter() - stage_start
|
|
|
|
# Align scheduler shift with Forge-style "Shift" control when available.
|
|
try:
|
|
if hasattr(pipeline, "scheduler") and hasattr(pipeline.scheduler, "config"):
|
|
if hasattr(pipeline.scheduler.config, "shift"):
|
|
pipeline.scheduler.config.shift = float(shift)
|
|
if hasattr(pipeline.scheduler, "shift"):
|
|
pipeline.scheduler.shift = float(shift)
|
|
except Exception:
|
|
pass
|
|
|
|
# Keep turbo aligned with Z-Image pipeline defaults instead of a stricter local cap.
|
|
max_sequence_length = 512
|
|
use_cfg = guidance_scale > 1.0
|
|
allow_quality_fallback = _zimage_allow_quality_fallback()
|
|
|
|
hard_cap = max_sequence_length
|
|
if flavor == "turbo":
|
|
env_max_seq = os.environ.get("FOOOCUS_ZIMAGE_TURBO_MAX_SEQ", "").strip()
|
|
if env_max_seq:
|
|
try:
|
|
env_cap = max(64, int(env_max_seq))
|
|
hard_cap = min(hard_cap, env_cap)
|
|
except Exception:
|
|
pass
|
|
forced_max_seq = getattr(pipeline, "_zimage_forced_max_sequence_length", None)
|
|
if (not allow_quality_fallback) and forced_max_seq is not None:
|
|
forced_max_seq = None
|
|
pipeline._zimage_forced_max_sequence_length = None
|
|
if forced_max_seq is not None:
|
|
hard_cap = min(hard_cap, int(forced_max_seq))
|
|
|
|
stage_start = time.perf_counter()
|
|
max_sequence_length = hard_cap
|
|
if flavor == "turbo" and _truthy_env("FOOOCUS_ZIMAGE_DYNAMIC_MAX_SEQ", "1"):
|
|
max_sequence_length = _compute_auto_max_sequence_length(
|
|
pipeline=pipeline,
|
|
prompt=prompt,
|
|
negative_prompt=negative_prompt,
|
|
use_cfg=use_cfg,
|
|
hard_cap=hard_cap,
|
|
)
|
|
|
|
# RAM-aware max_seq tuning before preflight: keep quality as high as possible while
|
|
# reducing host-RAM pressure from aggressive offload modes.
|
|
host_available_gb, host_total_gb = _system_ram_info_gb()
|
|
host_reserve_gb = _zimage_system_ram_reserve_gb()
|
|
host_usable_gb = max(0.0, host_available_gb - host_reserve_gb)
|
|
if flavor == "turbo" and host_available_gb > 0.0:
|
|
ram_cap = None
|
|
if host_usable_gb < 0.25:
|
|
ram_cap = 128
|
|
elif host_usable_gb < 0.50:
|
|
ram_cap = 192
|
|
elif host_usable_gb < 0.75:
|
|
ram_cap = 256
|
|
if ram_cap is not None and max_sequence_length > ram_cap:
|
|
max_sequence_length = ram_cap
|
|
pipeline._zimage_forced_max_sequence_length = ram_cap
|
|
print(
|
|
f"[Z-Image POC] Host RAM pressure ({host_available_gb:.2f}GB free / {host_total_gb:.2f}GB total, "
|
|
f"reserve={host_reserve_gb:.2f}GB) -> using max_sequence_length={ram_cap}."
|
|
)
|
|
|
|
if generator_device == "cuda" and allow_quality_fallback:
|
|
free_gb, total_gb = _cuda_mem_info_gb()
|
|
if max_sequence_length > 192 and free_gb > 0 and free_gb < 0.40:
|
|
max_sequence_length = 192
|
|
pipeline._zimage_forced_max_sequence_length = 192
|
|
print(
|
|
f"[Z-Image POC] Low free VRAM before generation ({free_gb:.2f}GB/{total_gb:.2f}GB), "
|
|
"using max_sequence_length=192."
|
|
)
|
|
if max_sequence_length > 160 and free_gb > 0 and free_gb < 0.25:
|
|
max_sequence_length = 160
|
|
pipeline._zimage_forced_max_sequence_length = 160
|
|
print(
|
|
f"[Z-Image POC] Very low free VRAM before generation ({free_gb:.2f}GB/{total_gb:.2f}GB), "
|
|
"using max_sequence_length=160."
|
|
)
|
|
|
|
_maybe_prewarm_pipeline(
|
|
pipeline,
|
|
generator_device=generator_device,
|
|
flavor=flavor,
|
|
prewarm_width=width,
|
|
prewarm_height=height,
|
|
prewarm_max_sequence_length=min(max_sequence_length, 128 if flavor == "turbo" else 192),
|
|
)
|
|
|
|
if generator_device == "cuda":
|
|
_maybe_preemptive_cuda_cleanup_before_generation(pipeline, profile=profile)
|
|
|
|
try:
|
|
generator_device, used_offload = _preflight_generation_memory_mode(
|
|
pipeline=pipeline,
|
|
cache_key=cache_key,
|
|
device="cuda" if generator_device == "cuda" else "cpu",
|
|
generator_device=generator_device,
|
|
used_offload=used_offload,
|
|
profile=profile,
|
|
width=width,
|
|
height=height,
|
|
max_sequence_length=max_sequence_length,
|
|
use_cfg=use_cfg,
|
|
flavor=flavor,
|
|
)
|
|
except Exception:
|
|
# Ensure preflight failures don't leave a poisoned cached pipeline for the next image.
|
|
_PIPELINE_CACHE.pop(cache_key, None)
|
|
_clear_prompt_cache_for_pipeline(cache_key)
|
|
raise
|
|
seed_list = [int(seed)]
|
|
if seeds:
|
|
parsed = []
|
|
for s in seeds:
|
|
try:
|
|
parsed.append(int(s))
|
|
except Exception:
|
|
continue
|
|
if parsed:
|
|
seed_list = parsed
|
|
|
|
generator = None
|
|
alt_latents_device_candidates: list[str] = []
|
|
alt_latents_device_index = 0
|
|
if _use_alt_path:
|
|
_ensure_alt_path_prerequisites(
|
|
pipeline=pipeline,
|
|
width=width,
|
|
height=height,
|
|
)
|
|
for candidate in (generator_device, "cpu"):
|
|
if candidate not in alt_latents_device_candidates:
|
|
alt_latents_device_candidates.append(candidate)
|
|
else:
|
|
if len(seed_list) <= 1:
|
|
generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
|
|
else:
|
|
generator = [torch.Generator(device=generator_device).manual_seed(s) for s in seed_list]
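# One generator per requested seed (sketch): with seeds [1, 2, 3] the batched
# call receives three generators, so each of the three images can be reproduced
# from its own seed.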
|
|
stage_times["runtime_prep"] = time.perf_counter() - stage_start
|
|
|
|
neg_key = negative_prompt if use_cfg else ""
|
|
embed_cache_key = (
|
|
cache_key,
|
|
prompt,
|
|
neg_key,
|
|
int(max_sequence_length),
|
|
bool(use_cfg),
|
|
)
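# Cache-key sketch: prompt embeddings are reused only when the pipeline cache
# key, prompt, effective negative prompt, max_sequence_length and CFG flag all
# match; changing any of them forces a fresh encode_prompt pass.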
|
|
|
|
stage_start = time.perf_counter()
|
|
prompt_embeds = None
|
|
negative_prompt_embeds = None
|
|
cached_embeds = _PROMPT_EMBED_CACHE.get(embed_cache_key, None)
|
|
embed_cache_hit = cached_embeds is not None
|
|
if cached_embeds is not None:
|
|
cached_pos, cached_neg = cached_embeds
|
|
prompt_embeds = [x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in cached_pos]
|
|
if cached_neg:
|
|
negative_prompt_embeds = [x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in cached_neg]
|
|
else:
|
|
negative_prompt_embeds = []
|
|
else:
|
|
_prepare_granular_prompt_encode(pipeline, generator_device=generator_device)
|
|
try:
|
|
pos, neg = pipeline.encode_prompt(
|
|
prompt=prompt,
|
|
negative_prompt=negative_prompt if use_cfg else None,
|
|
do_classifier_free_guidance=use_cfg,
|
|
device=generator_device,
|
|
max_sequence_length=max_sequence_length,
|
|
)
|
|
finally:
|
|
_park_granular_components(pipeline, generator_device=generator_device, stage="prompt_idle")
|
|
cpu_pos = [x.detach().to("cpu", copy=True) for x in pos]
|
|
cpu_neg = [x.detach().to("cpu", copy=True) for x in neg] if neg else []
|
|
_put_prompt_cache(embed_cache_key, (cpu_pos, cpu_neg))
|
|
prompt_embeds = [x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in cpu_pos]
|
|
if cpu_neg:
|
|
negative_prompt_embeds = [x.to(device=generator_device, dtype=pipeline.transformer.dtype) for x in cpu_neg]
|
|
else:
|
|
negative_prompt_embeds = []
|
|
stage_times["prompt_encode"] = time.perf_counter() - stage_start
|
|
|
|
call_kwargs = dict(
|
|
prompt=None,
|
|
width=width,
|
|
height=height,
|
|
num_inference_steps=steps,
|
|
guidance_scale=guidance_scale,
|
|
num_images_per_prompt=max(1, len(seed_list)),
|
|
cfg_normalization=False,
|
|
cfg_truncation=1.0,
|
|
max_sequence_length=max_sequence_length,
|
|
prompt_embeds=prompt_embeds,
|
|
negative_prompt_embeds=negative_prompt_embeds,
|
|
)
|
|
if _use_alt_path:
|
|
_set_generation_random_source(
|
|
call_kwargs=call_kwargs,
|
|
seed_list=seed_list,
|
|
pipeline=pipeline,
|
|
generator_device=generator_device,
|
|
use_alt_path=True,
|
|
latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
|
|
)
|
|
else:
|
|
call_kwargs["generator"] = generator
|
|
print(
|
|
f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
|
|
f"size={call_kwargs['width']}x{call_kwargs['height']}, max_seq={max_sequence_length}, offload={used_offload}, "
|
|
f"batch={call_kwargs['num_images_per_prompt']}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}"
|
|
)
|
|
|
|
    output = None
    call_start = time.perf_counter()
    black_retry_used = False
    try:
        retry_caps = []
        retry_sizes = []
        if flavor == "turbo" and allow_quality_fallback:
            current_seq = int(call_kwargs.get("max_sequence_length", 256))
            for candidate in (192, 160, 128, 96, 64, 32):
                if current_seq > candidate:
                    retry_caps.append(candidate)
            current_w = int(call_kwargs.get("width", width))
            current_h = int(call_kwargs.get("height", height))
            for scale in (0.85, 0.75, 0.625):
                next_w = max(384, int((current_w * scale) // 64) * 64)
                next_h = max(384, int((current_h * scale) // 64) * 64)
                if next_w < current_w or next_h < current_h:
                    pair = (next_w, next_h)
                    if pair not in retry_sizes:
                        retry_sizes.append(pair)

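        # Turbo with quality fallback gets extra attempts because it can also shrink the
        # sequence length and resolution between retries.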
        max_attempts = 4 if (flavor == "turbo" and allow_quality_fallback) else (2 if flavor == "turbo" else 3)
        for attempt in range(max_attempts):
            generation_attempts = attempt + 1
            try:
                if _use_alt_path:
                    _set_generation_random_source(
                        call_kwargs=call_kwargs,
                        seed_list=seed_list,
                        pipeline=pipeline,
                        generator_device=generator_device,
                        use_alt_path=True,
                        latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
                    )
                _prepare_granular_pipeline_call(
                    pipeline,
                    generator_device=generator_device,
                    stage=f"pipeline_call_{attempt + 1}",
                )
                output = _run_pipeline_call(pipeline, call_kwargs)
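                # Optionally scan the decoded images for suspected all-black outputs and retry
                # once with a safer attention backend / dtype before accepting the result.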
                if _zimage_black_image_retry_enabled() and not black_retry_used:
                    try:
                        candidates = list(getattr(output, "images", []) or [])
                    except Exception:
                        candidates = []
                    black_entries = []
                    for idx, candidate in enumerate(candidates):
                        try:
                            is_black, black_info = _is_suspected_black_image(candidate)
                        except Exception:
                            is_black, black_info = False, None
                        if is_black and black_info is not None:
                            black_entries.append((idx, black_info))
                    if black_entries:
                        first_black_idx, first_black_info = black_entries[0]
                        is_batch_black = len(candidates) > 1 and len(black_entries) == len(candidates)
                        if len(candidates) == 1 or is_batch_black:
                            black_retry_used = True
                            strategy = str(getattr(pipeline, "_zimage_xformers_strategy", "unknown"))
                            transformer = getattr(pipeline, "transformer", None)
                            transformer_dtype_obj = getattr(transformer, "dtype", None)
                            transformer_dtype = str(transformer_dtype_obj)
                            strict_fp16 = _zimage_strict_fp16_mode() and transformer_dtype_obj == torch.float16
                            if strict_fp16:
                                print(
                                    f"[Z-Image POC] Suspected black output detected "
                                    f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
                                    f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
                                    f"attn={strategy}, dtype={transformer_dtype}). Strict FP16 mode enabled; no fallback."
                                )
                                raise RuntimeError(
                                    "Suspected black output in strict FP16 mode; refusing automatic fallback."
                                )
                            print(
                                f"[Z-Image POC] Suspected black output detected "
                                f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
                                f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
                                f"attn={strategy}, dtype={transformer_dtype}). Retrying once with safer runtime."
                            )

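                            # Track which safer-runtime knobs get applied so the retry log can
                            # report exactly what changed.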
                            changed = []
                            if transformer is not None and hasattr(transformer, "set_attention_backend"):
                                # Flash can occasionally produce pathological outputs on some builds.
                                if "flash" in strategy:
                                    for candidate_backend in ("xformers", "native"):
                                        try:
                                            transformer.set_attention_backend(candidate_backend)
                                            pipeline._zimage_xformers_strategy = f"dispatch_backend:{candidate_backend}"
                                            changed.append(f"attn={candidate_backend}")
                                            break
                                        except Exception:
                                            continue

                            # If user forced fp16 and model behaves badly, attempt one safer dtype retry.
                            if transformer_dtype_obj == torch.float16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                                quant_dtype_updates = 0
                                for module_name in ("transformer", "text_encoder", "vae"):
                                    module = getattr(pipeline, module_name, None)
                                    if module is not None and hasattr(module, "to"):
                                        try:
                                            module.to(dtype=torch.bfloat16)
                                        except Exception:
                                            pass
                                        quant_dtype_updates += _retune_runtime_quant_modules_dtype(module, torch.bfloat16)
                                new_dtype = torch.bfloat16
                                call_kwargs["prompt_embeds"] = [
                                    x.to(device=generator_device, dtype=new_dtype) for x in call_kwargs.get("prompt_embeds", [])
                                ]
                                call_kwargs["negative_prompt_embeds"] = [
                                    x.to(device=generator_device, dtype=new_dtype) for x in call_kwargs.get("negative_prompt_embeds", [])
                                ]
                                changed.append("dtype=bf16")
                                if quant_dtype_updates:
                                    changed.append(f"runtime_quant_dtype={quant_dtype_updates}")

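                            # Re-run the pipeline once with the adjusted runtime; if the retry
                            # itself fails, keep the original (near-black) output.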
                            if changed:
                                if generator_device == "cuda":
                                    _cleanup_memory(cuda=True, aggressive=True)
                                _set_generation_random_source(
                                    call_kwargs=call_kwargs,
                                    seed_list=seed_list,
                                    pipeline=pipeline,
                                    generator_device=generator_device,
                                    use_alt_path=_use_alt_path,
                                    latents_device=alt_latents_device_candidates[alt_latents_device_index] if _use_alt_path and alt_latents_device_candidates else None,
                                )
                                original_output = output
                                try:
                                    _prepare_granular_pipeline_call(
                                        pipeline,
                                        generator_device=generator_device,
                                        stage="black_retry",
                                    )
                                    output = _run_pipeline_call(pipeline, call_kwargs)
                                except Exception as retry_error:
                                    output = original_output
                                    print(
                                        f"[Z-Image POC] Black-image retry failed ({retry_error}); keeping original output."
                                    )
                                try:
                                    retry_candidates = list(getattr(output, "images", []) or [])
                                    retry_black_any = False
                                    retry_info = None
                                    for retry_image in retry_candidates:
                                        retry_black, retry_info = _is_suspected_black_image(retry_image)
                                        if retry_black:
                                            retry_black_any = True
                                            break
                                    if retry_black_any and retry_info is not None:
                                        print(
                                            f"[Z-Image POC] Black-image retry remained near-black "
                                            f"(mean={retry_info['mean']:.2f}, max={retry_info['max']:.0f})."
                                        )
                                    else:
                                        print(
                                            f"[Z-Image POC] Black-image retry recovered output using {', '.join(changed)}."
                                        )
                                except Exception:
                                    pass
                            else:
                                print("[Z-Image POC] No safe retry knobs available; keeping original output.")
                        elif black_entries:
                            # For batches, retry only if every output is black. Mixed batches are preserved.
                            print(
                                f"[Z-Image POC] Batch output has {len(black_entries)}/{len(candidates)} near-black images; "
                                "keeping batch output."
                            )
                break
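            # Classify the failure: device/generator mismatches and xFormers incompatibilities
            # get targeted retries, CUDA OOM escalates the offload mode, anything else re-raises.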
            except Exception as e:
                msg = str(e).lower()
                device_mismatch = (
                    "expected all tensors to be on the same device" in msg
                    and "cuda" in msg
                    and "cpu" in msg
                )
                if _use_alt_path and device_mismatch and attempt < (max_attempts - 1):
                    if alt_latents_device_candidates and (alt_latents_device_index + 1) < len(alt_latents_device_candidates):
                        alt_latents_device_index += 1
                        retry_latent_device = alt_latents_device_candidates[alt_latents_device_index]
                        print(
                            "[Z-Image POC] Alternate path device mismatch detected; "
                            f"retrying with latents on {retry_latent_device}."
                        )
                        continue

                    deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
                    if isinstance(deep_state, dict):
                        print(
                            "[Z-Image POC] Alternate path device mismatch persisted; "
                            "disabling deep patcher and retrying with non-deep offload."
                        )
                        pipeline._zimage_deep_patcher_blocked = True
                        _disable_deep_patcher_offload(pipeline, target_device="cpu")
                        try:
                            fallback_free_gb, fallback_total_gb = _cuda_mem_info_gb()
                            fallback_pressure = (fallback_free_gb / fallback_total_gb) if fallback_total_gb > 0 else 0.0
                            generator_device, used_offload = _apply_memory_mode(
                                pipeline=pipeline,
                                device="cuda",
                                target_mode="model_offload",
                                total_vram_gb=fallback_total_gb,
                                free_vram_gb=fallback_free_gb,
                                pressure=fallback_pressure,
                                profile=profile,
                                reason="alternate path device mismatch",
                                allow_relax=True,
                            )
                            _PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
                        except Exception as fallback_error:
                            print(
                                "[Z-Image POC] Failed to switch from deep patcher after alternate-path device mismatch: "
                                f"{fallback_error}"
                            )
                        alt_latents_device_candidates = []
                        for candidate in (generator_device, "cpu"):
                            if candidate not in alt_latents_device_candidates:
                                alt_latents_device_candidates.append(candidate)
                        alt_latents_device_index = 0
                        call_kwargs["prompt_embeds"] = [
                            x.to(device=generator_device, dtype=pipeline.transformer.dtype)
                            for x in call_kwargs.get("prompt_embeds", [])
                        ]
                        call_kwargs["negative_prompt_embeds"] = [
                            x.to(device=generator_device, dtype=pipeline.transformer.dtype)
                            for x in call_kwargs.get("negative_prompt_embeds", [])
                        ]
                        continue

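                # Legacy path: this error means a CUDA generator was asked to fill CPU-side
                # latents; disable the deep patcher and retry with regular model offload.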
                deep_generator_mismatch = (
                    "cannot generate a cpu tensor from a generator of type cuda" in msg
                    or ("generator of type cuda" in msg and "cpu tensor" in msg)
                )
                if (not _use_alt_path) and deep_generator_mismatch and generator_device == "cuda":
                    if attempt < (max_attempts - 1):
                        print(
                            "[Z-Image POC] Deep patcher runtime generator/device mismatch detected; "
                            "disabling deep patcher for this pipeline and retrying with non-deep offload."
                        )
                        pipeline._zimage_deep_patcher_blocked = True
                        _disable_deep_patcher_offload(pipeline, target_device="cpu")
                        try:
                            fallback_free_gb, fallback_total_gb = _cuda_mem_info_gb()
                            fallback_pressure = (fallback_free_gb / fallback_total_gb) if fallback_total_gb > 0 else 0.0
                            generator_device, used_offload = _apply_memory_mode(
                                pipeline=pipeline,
                                device="cuda",
                                target_mode="model_offload",
                                total_vram_gb=fallback_total_gb,
                                free_vram_gb=fallback_free_gb,
                                pressure=fallback_pressure,
                                profile=profile,
                                reason="deep patcher generator mismatch",
                                allow_relax=True,
                            )
                            _PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
                        except Exception as fallback_error:
                            print(
                                "[Z-Image POC] Failed to switch from deep patcher after generator mismatch: "
                                f"{fallback_error}"
                            )
                        _set_generation_random_source(
                            call_kwargs=call_kwargs,
                            seed_list=seed_list,
                            pipeline=pipeline,
                            generator_device=generator_device,
                            use_alt_path=False,
                        )
                        call_kwargs["prompt_embeds"] = [
                            x.to(device=generator_device, dtype=pipeline.transformer.dtype)
                            for x in call_kwargs.get("prompt_embeds", [])
                        ]
                        call_kwargs["negative_prompt_embeds"] = [
                            x.to(device=generator_device, dtype=pipeline.transformer.dtype)
                            for x in call_kwargs.get("negative_prompt_embeds", [])
                        ]
                        continue

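                # These substrings indicate the active xFormers attention processor cannot handle
                # Z-Image's attention kwargs; disable xFormers for this pipeline and retry.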
                xformers_mismatch = (
                    "xformersattnprocessor" in msg
                    or "cross_attention_kwargs" in msg
                    or "expanded size of the tensor" in msg
                    or "freqs_cis" in msg
                )
                if xformers_mismatch:
                    disabled = _disable_xformers_for_pipeline(
                        pipeline, reason="runtime mismatch with Z-Image attention kwargs"
                    )
                    if disabled and attempt < 2:
                        print("[Z-Image POC] Retrying after disabling xFormers for Z-Image compatibility.")
                        _cleanup_memory(cuda=(generator_device == "cuda"), aggressive=True)
                        _PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
                        continue

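                # Only CUDA out-of-memory errors are retried past this point; anything else,
                # or an OOM on the final attempt, propagates to the caller.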
if "out of memory" not in msg or generator_device != "cuda":
|
|
raise
|
|
if attempt >= (max_attempts - 1):
|
|
raise
|
|
|
|
print("[Z-Image POC] CUDA OOM detected, retrying with stricter offload mode.")
|
|
_cleanup_memory(cuda=True, aggressive=True)
|
|
pipeline._zimage_last_oom = True
|
|
|
|
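                # Escalate to sequential offload and enable attention/VAE slicing and tiling
                # where available to shrink the next attempt's peak VRAM usage.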
                try:
                    oom_free_gb, oom_total_gb = _cuda_mem_info_gb()
                    oom_pressure = (oom_free_gb / oom_total_gb) if oom_total_gb > 0 else 0.0
                    generator_device, used_offload = _apply_memory_mode(
                        pipeline=pipeline,
                        device="cuda",
                        target_mode="sequential_offload",
                        total_vram_gb=oom_total_gb,
                        free_vram_gb=oom_free_gb,
                        pressure=oom_pressure,
                        profile=profile,
                        reason="oom retry",
                        allow_relax=False,
                    )
                    _PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
                except Exception as offload_error:
                    print(f"[Z-Image POC] OOM retry could not switch offload mode: {offload_error}")
                if hasattr(pipeline, "enable_attention_slicing"):
                    pipeline.enable_attention_slicing("max")
                if hasattr(pipeline, "enable_vae_slicing"):
                    pipeline.enable_vae_slicing()
                if hasattr(pipeline, "enable_vae_tiling"):
                    pipeline.enable_vae_tiling()

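                # When quality fallback is allowed, also step down the token budget and/or
                # resolution before retrying; otherwise retry with unchanged settings.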
                if allow_quality_fallback:
                    lowered = False
                    if retry_caps:
                        next_cap = retry_caps.pop(0)
                        call_kwargs["max_sequence_length"] = next_cap
                        pipeline._zimage_forced_max_sequence_length = next_cap
                        print(
                            f"[Z-Image POC] Retrying with reduced max_sequence_length={next_cap} for lower VRAM usage."
                        )
                        lowered = True
                    if retry_sizes:
                        next_w, next_h = retry_sizes.pop(0)
                        call_kwargs["width"] = next_w
                        call_kwargs["height"] = next_h
                        print(
                            f"[Z-Image POC] Retrying with reduced resolution {next_w}x{next_h} for lower VRAM usage."
                        )
                        lowered = True
                    if not lowered:
                        print("[Z-Image POC] Retrying with same sequence length after memory cleanup.")
                else:
                    print("[Z-Image POC] Retrying with same quality settings after memory cleanup.")
                continue

        if output is None:
            raise RuntimeError("Z-Image generation failed after OOM retries.")
        stage_times["pipeline_call"] = time.perf_counter() - call_start

        stage_start = time.perf_counter()
        images = list(getattr(output, "images", []) or [])
        del output
        stage_times["extract_image"] = time.perf_counter() - stage_start
        if return_images:
            return images
        if not images:
            raise RuntimeError("Z-Image pipeline returned no images.")
        return images[0]
    except Exception as e:
        error_name = type(e).__name__
        if "pipeline_call" not in stage_times:
            stage_times["pipeline_call"] = time.perf_counter() - call_start
        # Prevent poisoned/corrupted cache from breaking next generation request.
        _PIPELINE_CACHE.pop(cache_key, None)
        _clear_prompt_cache_for_pipeline(cache_key)
        raise
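    # The finally block always runs: it parks offloaded components, releases accelerate hooks,
    # clears large locals, and (when stage timers are enabled) prints a per-stage timing report.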
    finally:
        cleanup_start = time.perf_counter()
        _park_granular_components(pipeline, generator_device=generator_device, stage="final_idle")
        try:
            # Ensure accelerate offload hooks release device-resident weights between images.
            if hasattr(pipeline, "maybe_free_model_hooks"):
                pipeline.maybe_free_model_hooks()
            del prompt_embeds
            del negative_prompt_embeds
            del generator
            call_kwargs.clear()
        except Exception:
            pass
        had_oom = bool(getattr(pipeline, "_zimage_last_oom", False))
        pipeline._zimage_last_run_had_oom = had_oom
        if hasattr(pipeline, "_zimage_last_oom"):
            pipeline._zimage_last_oom = False
        if generator_device == "cuda":
            if _should_cleanup_cuda_cache(profile, had_oom=had_oom, pipeline=pipeline):
                _cleanup_memory(cuda=True, aggressive=had_oom)
        else:
            _cleanup_memory(cuda=False, aggressive=had_oom)
        deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
        if stage_timers and isinstance(deep_state, dict):
            print(
                "[Z-Image POC] Deep patcher stats: "
                f"moves_to_gpu={int(deep_state.get('moves_to_gpu', 0))}, "
                f"moves_to_cpu={int(deep_state.get('moves_to_cpu', 0))}, "
                f"modules={len(deep_state.get('module_entries', ()))}."
            )
        stage_times["cleanup"] = time.perf_counter() - cleanup_start
        if stage_timers:
            total_elapsed = time.perf_counter() - total_start
            status = "ok" if not error_name else f"error={error_name}"
            embed_status = "hit" if embed_cache_hit else "miss"
            print(
                f"[Z-Image POC] Stage timings ({status}, embed_cache={embed_status}, attempts={generation_attempts}): "
                f"resolve={_format_timing_ms(stage_times.get('resolve_overrides'))}, "
                f"load={_format_timing_ms(stage_times.get('pipeline_load'))}, "
                f"prepare={_format_timing_ms(stage_times.get('runtime_prep'))}, "
                f"encode={_format_timing_ms(stage_times.get('prompt_encode'))}, "
                f"infer={_format_timing_ms(stage_times.get('pipeline_call'))}, "
                f"extract={_format_timing_ms(stage_times.get('extract_image'))}, "
                f"cleanup={_format_timing_ms(stage_times.get('cleanup'))}, "
                f"total={_format_timing_ms(total_elapsed)}"
            )


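# Thin wrappers that pin _use_alt_path for the legacy and alternate runtime paths.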
def _generate_zimage_legacy(
    source_kind: str,
    source_path: str,
    flavor: str,
    checkpoint_folders: list[str],
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    seeds: Optional[list[int]] = None,
    shift: float = 3.0,
    text_encoder_override: Optional[str] = None,
    vae_override: Optional[str] = None,
    return_images: bool = False,
):
    return _generate_zimage_impl(
        source_kind=source_kind,
        source_path=source_path,
        flavor=flavor,
        checkpoint_folders=checkpoint_folders,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
        seeds=seeds,
        shift=shift,
        text_encoder_override=text_encoder_override,
        vae_override=vae_override,
        return_images=return_images,
        _use_alt_path=False,
    )


def _generate_zimage_alt(
    source_kind: str,
    source_path: str,
    flavor: str,
    checkpoint_folders: list[str],
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    seeds: Optional[list[int]] = None,
    shift: float = 3.0,
    text_encoder_override: Optional[str] = None,
    vae_override: Optional[str] = None,
    return_images: bool = False,
):
    return _generate_zimage_impl(
        source_kind=source_kind,
        source_path=source_path,
        flavor=flavor,
        checkpoint_folders=checkpoint_folders,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
        seeds=seeds,
        shift=shift,
        text_encoder_override=text_encoder_override,
        vae_override=vae_override,
        return_images=return_images,
        _use_alt_path=True,
    )


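# Public entry point: routes to the alternate runtime path when _zimage_alt_path_enabled()
# reports it is turned on, otherwise uses the legacy generator-based path.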
def generate_zimage(
    source_kind: str,
    source_path: str,
    flavor: str,
    checkpoint_folders: list[str],
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    seeds: Optional[list[int]] = None,
    shift: float = 3.0,
    text_encoder_override: Optional[str] = None,
    vae_override: Optional[str] = None,
    return_images: bool = False,
):
    if _zimage_alt_path_enabled():
        return _generate_zimage_alt(
            source_kind=source_kind,
            source_path=source_path,
            flavor=flavor,
            checkpoint_folders=checkpoint_folders,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            guidance_scale=guidance_scale,
            seed=seed,
            seeds=seeds,
            shift=shift,
            text_encoder_override=text_encoder_override,
            vae_override=vae_override,
            return_images=return_images,
        )
    return _generate_zimage_legacy(
        source_kind=source_kind,
        source_path=source_path,
        flavor=flavor,
        checkpoint_folders=checkpoint_folders,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
        seeds=seeds,
        shift=shift,
        text_encoder_override=text_encoder_override,
        vae_override=vae_override,
        return_images=return_images,
    )