From 8af0646df397489298d77887d1c28131b6c2aa92 Mon Sep 17 00:00:00 2001
From: Developer
Date: Tue, 17 Feb 2026 20:30:35 +0200
Subject: [PATCH] Add Z-Image dtype and turbo max-seq runtime overrides

---
 modules/zimage_poc.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/modules/zimage_poc.py b/modules/zimage_poc.py
index 698a641d..43dca269 100644
--- a/modules/zimage_poc.py
+++ b/modules/zimage_poc.py
@@ -832,10 +832,45 @@ def resolve_zimage_source(name: str, checkpoint_folders: list[str], auto_downloa
 def _pick_device_and_dtype():
+    """Choose the torch device and dtype used for the Z-Image pipeline.
+
+    The FOOOCUS_ZIMAGE_DTYPE environment variable may force bf16, fp16 or
+    fp32 (a few aliases are accepted). "auto", an empty value, or an
+    unrecognized value keeps the hardware-based default. A BF16 request on
+    a CUDA device without BF16 support falls back to FP16.
+
+    Returns:
+        A ``(device, dtype)`` tuple where device is "cuda" or "cpu".
+    """
     import torch
 
+    dtype_override = os.environ.get("FOOOCUS_ZIMAGE_DTYPE", "auto").strip().lower()
+
+    def _resolve_dtype(value: str):
+        if value in ("bf16", "bfloat16"):
+            return torch.bfloat16
+        if value in ("fp16", "float16", "half"):
+            return torch.float16
+        if value in ("fp32", "float32", "full"):
+            return torch.float32
+        return None
+
+    requested = _resolve_dtype(dtype_override)
+    if requested is None and dtype_override not in ("", "auto"):
+        # Surface typos instead of silently ignoring the override.
+        print(f"[Z-Image POC] Unrecognized FOOOCUS_ZIMAGE_DTYPE={dtype_override!r}; using automatic dtype selection.")
+
     if torch.cuda.is_available():
+        if requested is not None:
+            # BF16 fallback on GPUs that do not support it.
+            if requested == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+                print("[Z-Image POC] Requested BF16 but CUDA BF16 is unsupported; falling back to FP16.")
+                return "cuda", torch.float16
+            return "cuda", requested
         if torch.cuda.is_bf16_supported():
             return "cuda", torch.bfloat16
         return "cuda", torch.float16
+
+    if requested is not None:
+        return "cpu", requested
     return "cpu", torch.float32
 
 
@@ -1203,6 +1238,16 @@ def generate_zimage(
         pass
 
     max_sequence_length = 256 if flavor == "turbo" else 512
+    if flavor == "turbo":
+        # Optional cap for turbo: it can only lower the default and is floored at 64.
+        env_max_seq = os.environ.get("FOOOCUS_ZIMAGE_TURBO_MAX_SEQ", "").strip()
+        if env_max_seq:
+            try:
+                env_cap = max(64, int(env_max_seq))
+            except ValueError:
+                print(f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_TURBO_MAX_SEQ={env_max_seq!r}; expected an integer.")
+            else:
+                max_sequence_length = min(max_sequence_length, env_cap)
     forced_max_seq = getattr(pipeline, "_zimage_forced_max_sequence_length", None)
     if forced_max_seq is not None:
         max_sequence_length = min(max_sequence_length, int(forced_max_seq))
@@ -1276,7 +1321,8 @@ def generate_zimage(
     )
     print(
         f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
-        f"max_seq={max_sequence_length}, offload={used_offload}"
+        f"max_seq={max_sequence_length}, offload={used_offload}, "
+        f"dtype={getattr(getattr(pipeline, 'transformer', None), 'dtype', 'n/a')}"
     )
 
     output = None