Add Z-Image dtype and turbo max-seq runtime overrides

pull/4166/head
Developer 2026-02-17 20:30:35 +02:00
parent 19e52ed092
commit 8af0646df3
1 changed files with 31 additions and 1 deletions

View File

@ -832,10 +832,32 @@ def resolve_zimage_source(name: str, checkpoint_folders: list[str], auto_downloa
def _pick_device_and_dtype():
    """Select the torch device and dtype used for Z-Image inference.

    The ``FOOOCUS_ZIMAGE_DTYPE`` environment variable may override the
    automatic dtype choice. Matching ignores case and surrounding
    whitespace. Accepted values:

    * ``bf16`` / ``bfloat16``       -> ``torch.bfloat16``
    * ``fp16`` / ``float16`` / ``half`` -> ``torch.float16``
    * ``fp32`` / ``float32`` / ``full`` -> ``torch.float32``
    * ``auto`` (default) or empty   -> automatic selection

    Any other value is reported on stdout and treated as ``auto`` rather
    than being dropped silently.

    Returns:
        A ``(device, dtype)`` tuple where ``device`` is ``"cuda"`` or
        ``"cpu"`` and ``dtype`` is a ``torch.dtype``.
    """
    import torch

    dtype_aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "half": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
        "full": torch.float32,
    }

    dtype_override = os.environ.get("FOOOCUS_ZIMAGE_DTYPE", "auto").strip().lower()
    requested = dtype_aliases.get(dtype_override)
    if requested is None and dtype_override not in ("", "auto"):
        # A typo in the override would otherwise be indistinguishable from
        # "auto"; tell the user their setting was not applied.
        print(
            f"[Z-Image POC] Ignoring unrecognized FOOOCUS_ZIMAGE_DTYPE={dtype_override!r}; "
            "using automatic dtype selection."
        )

    if not torch.cuda.is_available():
        # On CPU an explicit override is honored as-is; float32 is the default.
        if requested is not None:
            return "cpu", requested
        return "cpu", torch.float32

    bf16_supported = torch.cuda.is_bf16_supported()
    if requested is not None:
        # BF16 fallback on GPUs that do not support it.
        if requested == torch.bfloat16 and not bf16_supported:
            print("[Z-Image POC] Requested BF16 but CUDA BF16 is unsupported; falling back to FP16.")
            return "cuda", torch.float16
        return "cuda", requested

    # Automatic selection on CUDA: prefer BF16 when the GPU supports it.
    if bf16_supported:
        return "cuda", torch.bfloat16
    return "cuda", torch.float16
@ -1203,6 +1225,14 @@ def generate_zimage(
pass
max_sequence_length = 256 if flavor == "turbo" else 512
if flavor == "turbo":
env_max_seq = os.environ.get("FOOOCUS_ZIMAGE_TURBO_MAX_SEQ", "").strip()
if env_max_seq:
try:
env_cap = max(64, int(env_max_seq))
max_sequence_length = min(max_sequence_length, env_cap)
except Exception:
pass
forced_max_seq = getattr(pipeline, "_zimage_forced_max_sequence_length", None)
if forced_max_seq is not None:
max_sequence_length = min(max_sequence_length, int(forced_max_seq))
@ -1276,7 +1306,7 @@ def generate_zimage(
)
print(
f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
f"max_seq={max_sequence_length}, offload={used_offload}"
f"max_seq={max_sequence_length}, offload={used_offload}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}"
)
output = None