mirror of https://github.com/lllyasviel/Fooocus
Add Z-Image dtype and turbo max-seq runtime overrides
parent
19e52ed092
commit
8af0646df3
|
|
@ -832,10 +832,32 @@ def resolve_zimage_source(name: str, checkpoint_folders: list[str], auto_downloa
|
|||
def _pick_device_and_dtype():
|
||||
import torch
|
||||
|
||||
dtype_override = os.environ.get("FOOOCUS_ZIMAGE_DTYPE", "auto").strip().lower()
|
||||
|
||||
def _resolve_dtype(value: str):
|
||||
if value in ("bf16", "bfloat16"):
|
||||
return torch.bfloat16
|
||||
if value in ("fp16", "float16", "half"):
|
||||
return torch.float16
|
||||
if value in ("fp32", "float32", "full"):
|
||||
return torch.float32
|
||||
return None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
requested = _resolve_dtype(dtype_override)
|
||||
if requested is not None:
|
||||
# BF16 fallback on GPUs that do not support it.
|
||||
if requested == torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
||||
print("[Z-Image POC] Requested BF16 but CUDA BF16 is unsupported; falling back to FP16.")
|
||||
return "cuda", torch.float16
|
||||
return "cuda", requested
|
||||
if torch.cuda.is_bf16_supported():
|
||||
return "cuda", torch.bfloat16
|
||||
return "cuda", torch.float16
|
||||
|
||||
requested = _resolve_dtype(dtype_override)
|
||||
if requested is not None:
|
||||
return "cpu", requested
|
||||
return "cpu", torch.float32
|
||||
|
||||
|
||||
|
|
@ -1203,6 +1225,14 @@ def generate_zimage(
|
|||
pass
|
||||
|
||||
max_sequence_length = 256 if flavor == "turbo" else 512
|
||||
if flavor == "turbo":
|
||||
env_max_seq = os.environ.get("FOOOCUS_ZIMAGE_TURBO_MAX_SEQ", "").strip()
|
||||
if env_max_seq:
|
||||
try:
|
||||
env_cap = max(64, int(env_max_seq))
|
||||
max_sequence_length = min(max_sequence_length, env_cap)
|
||||
except Exception:
|
||||
pass
|
||||
forced_max_seq = getattr(pipeline, "_zimage_forced_max_sequence_length", None)
|
||||
if forced_max_seq is not None:
|
||||
max_sequence_length = min(max_sequence_length, int(forced_max_seq))
|
||||
|
|
@ -1276,7 +1306,7 @@ def generate_zimage(
|
|||
)
|
||||
print(
|
||||
f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
|
||||
f"max_seq={max_sequence_length}, offload={used_offload}"
|
||||
f"max_seq={max_sequence_length}, offload={used_offload}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}"
|
||||
)
|
||||
|
||||
output = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue