alternative path

pull/4166/head
Developer 2026-02-19 14:52:29 +02:00
parent 3875e5ad4f
commit 0e1bd3e1de
2 changed files with 405 additions and 51 deletions

View File

@ -912,6 +912,16 @@ def worker():
def _truthy_env(name: str, default: str = '0') -> bool:
return os.environ.get(name, default).strip().lower() in ('1', 'true', 'yes', 'on')
def _int_env(name: str, default: int, min_value: int = 1, max_value: int = 64) -> int:
raw = os.environ.get(name, str(default)).strip()
try:
value = int(raw)
except Exception:
return default
value = max(min_value, value)
value = min(max_value, value)
return value
def maybe_unload_standard_models_for_zimage():
if not _truthy_env('FOOOCUS_ZIMAGE_UNLOAD_STANDARD_MODELS', '0'):
return False
@ -966,8 +976,10 @@ def worker():
progressbar(async_task, current_progress, 'Running Z-Image POC generation ...')
release_cache_after_run = _truthy_env('FOOOCUS_ZIMAGE_RELEASE_CACHE_AFTER_RUN', '0')
batch_size = _int_env('FOOOCUS_ZIMAGE_BATCH_SIZE', 1, min_value=1, max_value=8)
try:
tasks = []
for i in range(async_task.image_number):
if async_task.last_stop is not False:
break
@ -987,48 +999,159 @@ def worker():
task_rng,
)
progressbar(async_task, current_progress, f'Generating Z-Image {i + 1}/{async_task.image_number} ...')
try:
pil_image = modules.zimage_poc.generate_zimage(
source_kind=source_kind,
source_path=source_path,
flavor=flavor,
checkpoint_folders=modules.config.paths_checkpoints,
tasks.append(
dict(
index=i,
task_seed=task_seed,
prompt=task_prompt,
negative_prompt=task_negative_prompt,
width=width,
height=height,
steps=async_task.steps,
guidance_scale=float(async_task.cfg_scale),
seed=task_seed,
shift=float(async_task.adaptive_cfg),
text_encoder_override=async_task.zit_text_encoder,
vae_override=async_task.zit_vae,
)
np_image = np.array(pil_image)
)
task_pos = 0
while task_pos < len(tasks):
if async_task.last_stop is not False:
break
batch = [tasks[task_pos]]
if batch_size > 1:
anchor_prompt = batch[0]['prompt']
anchor_negative_prompt = batch[0]['negative_prompt']
scan = task_pos + 1
while scan < len(tasks) and len(batch) < batch_size:
cand = tasks[scan]
if cand['prompt'] != anchor_prompt or cand['negative_prompt'] != anchor_negative_prompt:
break
batch.append(cand)
scan += 1
def _save_task_result(task_item, np_image):
task = dict(
task_seed=task_item['task_seed'],
log_positive_prompt=task_item['prompt'],
log_negative_prompt=task_item['negative_prompt'],
expansion='',
styles=[],
positive=[task_item['prompt']],
negative=[task_item['negative_prompt']],
)
progress = int((task_item['index'] + 1) * 100 / float(max(async_task.image_number, 1)))
img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True)
yield_result(
async_task,
img_paths,
progress,
async_task.black_out_nsfw,
False,
do_not_show_finished_images=async_task.disable_intermediate_results,
)
try:
if len(batch) > 1:
range_start = batch[0]['index'] + 1
range_end = batch[-1]['index'] + 1
progressbar(
async_task,
current_progress,
f'Generating Z-Image {range_start}-{range_end}/{async_task.image_number} (batched) ...',
)
pil_images = modules.zimage_poc.generate_zimage(
source_kind=source_kind,
source_path=source_path,
flavor=flavor,
checkpoint_folders=modules.config.paths_checkpoints,
prompt=batch[0]['prompt'],
negative_prompt=batch[0]['negative_prompt'],
width=width,
height=height,
steps=async_task.steps,
guidance_scale=float(async_task.cfg_scale),
seed=batch[0]['task_seed'],
seeds=[item['task_seed'] for item in batch],
shift=float(async_task.adaptive_cfg),
text_encoder_override=async_task.zit_text_encoder,
vae_override=async_task.zit_vae,
return_images=True,
)
if not isinstance(pil_images, list) or len(pil_images) != len(batch):
raise RuntimeError(
f'Batched Z-Image returned {len(pil_images) if isinstance(pil_images, list) else "invalid"} '
f'images for expected batch size {len(batch)}.'
)
for task_item, pil_image in zip(batch, pil_images):
_save_task_result(task_item, np.array(pil_image))
else:
task_item = batch[0]
idx = task_item['index']
progressbar(
async_task,
current_progress,
f'Generating Z-Image {idx + 1}/{async_task.image_number} ...',
)
pil_image = modules.zimage_poc.generate_zimage(
source_kind=source_kind,
source_path=source_path,
flavor=flavor,
checkpoint_folders=modules.config.paths_checkpoints,
prompt=task_item['prompt'],
negative_prompt=task_item['negative_prompt'],
width=width,
height=height,
steps=async_task.steps,
guidance_scale=float(async_task.cfg_scale),
seed=task_item['task_seed'],
shift=float(async_task.adaptive_cfg),
text_encoder_override=async_task.zit_text_encoder,
vae_override=async_task.zit_vae,
)
_save_task_result(task_item, np.array(pil_image))
except ModuleNotFoundError as e:
progressbar(async_task, 100, f'Z-Image POC requires missing dependency: {e}')
print('[Z-Image POC] Install required dependencies with:')
print(' python -m pip install -r requirements_versions.txt')
return
except Exception as e:
progressbar(async_task, 100, f'Z-Image generation failed: {e}')
print(f'[Z-Image POC] Generation failed: {e}')
return
if len(batch) > 1:
print(
f"[Z-Image POC] Batched generation failed ({e}); retrying batch as single-image runs."
)
for task_item in batch:
if async_task.last_stop is not False:
break
try:
idx = task_item['index']
progressbar(
async_task,
current_progress,
f'Generating Z-Image {idx + 1}/{async_task.image_number} ...',
)
pil_image = modules.zimage_poc.generate_zimage(
source_kind=source_kind,
source_path=source_path,
flavor=flavor,
checkpoint_folders=modules.config.paths_checkpoints,
prompt=task_item['prompt'],
negative_prompt=task_item['negative_prompt'],
width=width,
height=height,
steps=async_task.steps,
guidance_scale=float(async_task.cfg_scale),
seed=task_item['task_seed'],
shift=float(async_task.adaptive_cfg),
text_encoder_override=async_task.zit_text_encoder,
vae_override=async_task.zit_vae,
)
_save_task_result(task_item, np.array(pil_image))
except Exception as single_error:
progressbar(async_task, 100, f'Z-Image generation failed: {single_error}')
print(f'[Z-Image POC] Generation failed: {single_error}')
return
else:
progressbar(async_task, 100, f'Z-Image generation failed: {e}')
print(f'[Z-Image POC] Generation failed: {e}')
return
task = dict(
task_seed=task_seed,
log_positive_prompt=task_prompt,
log_negative_prompt=task_negative_prompt,
expansion='',
styles=[],
positive=[task_prompt],
negative=[task_negative_prompt],
)
progress = int((i + 1) * 100 / float(max(async_task.image_number, 1)))
img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True)
yield_result(async_task, img_paths, progress, async_task.black_out_nsfw, False,
do_not_show_finished_images=async_task.disable_intermediate_results)
task_pos += len(batch)
finally:
if unloaded_standard_models:
try:

View File

@ -461,6 +461,17 @@ def _zimage_fp16_quant_accum_mode() -> str:
return mode
def _zimage_quant_backend() -> str:
raw = os.environ.get("FOOOCUS_ZIMAGE_QUANT_BACKEND", "runtime").strip().lower()
if raw in ("runtime", "comfy_ops"):
return raw
_warn_once_env(
"FOOOCUS_ZIMAGE_QUANT_BACKEND",
f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_QUANT_BACKEND='{raw}'. Expected: runtime|comfy_ops.",
)
return "runtime"
def _zimage_prewarm_enabled() -> bool:
return _truthy_env("FOOOCUS_ZIMAGE_PREWARM", "0")
@ -3247,6 +3258,158 @@ def _load_component_override_from_file(
"skipped_type_counts": skipped_type_counts,
}
def _import_comfy_ops_module():
import sys
comfy_root = os.environ.get("FOOOCUS_ZIMAGE_COMFY_ROOT", "").strip()
tried_root = None
initial_error = None
try:
import comfy.ops as comfy_ops # type: ignore
return comfy_ops, None
except Exception as e:
initial_error = e
if comfy_root:
tried_root = os.path.abspath(os.path.expanduser(comfy_root))
if os.path.isdir(tried_root) and tried_root not in sys.path:
sys.path.insert(0, tried_root)
try:
import comfy.ops as comfy_ops # type: ignore
print(f"[Z-Image POC] Loaded Comfy ops from FOOOCUS_ZIMAGE_COMFY_ROOT={tried_root}")
return comfy_ops, None
except Exception as e:
return None, f"{initial_error}; retry_with_root={tried_root}: {e}"
return None, str(initial_error)
def _install_comfy_ops_quant_modules(component_module, remapped_sd: dict, compute_dtype) -> dict:
bases = sorted(
{
key[: -len(".comfy_quant")]
for key in remapped_sd.keys()
if key.endswith(".comfy_quant") and f"{key[: -len('.comfy_quant')]}.weight" in remapped_sd
}
)
if not bases:
return {
"layers": 0,
"replaced": 0,
"skipped": 0,
"float8": 0,
"nvfp4": 0,
"backend": "comfy_ops",
}
comfy_ops, import_error = _import_comfy_ops_module()
if comfy_ops is None:
return {
"layers": len(bases),
"replaced": 0,
"skipped": len(bases),
"float8": 0,
"nvfp4": 0,
"replaced_bases": set(),
"skipped_unresolved": len(bases),
"skipped_type_counts": {"import_error": len(bases)},
"backend": "comfy_ops",
"error": f"Unable to import comfy.ops ({import_error})",
}
try:
mixed_ops = comfy_ops.mixed_precision_ops(
quant_config={},
compute_dtype=compute_dtype,
full_precision_mm=False,
disabled=[],
)
linear_cls = getattr(mixed_ops, "Linear", None)
if linear_cls is None:
raise RuntimeError("mixed_precision_ops did not expose Linear class.")
except Exception as e:
return {
"layers": len(bases),
"replaced": 0,
"skipped": len(bases),
"float8": 0,
"nvfp4": 0,
"replaced_bases": set(),
"skipped_unresolved": len(bases),
"skipped_type_counts": {"init_error": len(bases)},
"backend": "comfy_ops",
"error": f"Failed to initialize Comfy mixed_precision_ops ({e})",
}
replaced = 0
skipped = 0
float8_layers = 0
nvfp4_layers = 0
replaced_bases = set()
skipped_type_counts = {}
skipped_unresolved = 0
def _mark_skipped_type(name: str):
skipped_type_counts[name] = skipped_type_counts.get(name, 0) + 1
for base in bases:
conf = _decode_comfy_quant_entry(remapped_sd.get(f"{base}.comfy_quant"))
fmt = str(conf.get("format", "")).lower() if isinstance(conf, dict) else ""
if fmt in ("float8_e4m3fn", "float8_e5m2"):
float8_layers += 1
elif fmt == "nvfp4":
nvfp4_layers += 1
try:
target = _resolve_module(component_module, base)
except Exception:
skipped += 1
skipped_unresolved += 1
continue
if isinstance(target, linear_cls):
replaced += 1
replaced_bases.add(base)
continue
if not _is_linear_like_module(target):
skipped += 1
_mark_skipped_type(type(target).__name__)
continue
try:
weight = getattr(target, "weight", None)
bias = getattr(target, "bias", None)
in_features = int(getattr(target, "in_features", weight.shape[1]))
out_features = int(getattr(target, "out_features", weight.shape[0]))
target_device = getattr(weight, "device", None)
replacement = linear_cls(
in_features=in_features,
out_features=out_features,
bias=bias is not None,
device=target_device,
dtype=compute_dtype,
)
except Exception:
skipped += 1
_mark_skipped_type(type(target).__name__)
continue
_set_module(component_module, base, replacement)
replaced += 1
replaced_bases.add(base)
return {
"layers": len(bases),
"replaced": replaced,
"skipped": skipped,
"float8": float8_layers,
"nvfp4": nvfp4_layers,
"replaced_bases": replaced_bases,
"skipped_unresolved": skipped_unresolved,
"skipped_type_counts": skipped_type_counts,
"backend": "comfy_ops",
}
def _normalize_legacy_scaled_fp8_weights(sd: dict) -> tuple[dict, dict]:
converted = dict(sd)
migrated = 0
@ -3481,6 +3644,7 @@ def _load_component_override_from_file(
component_name in ("text_encoder", "transformer")
and _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_QUANT", "1")
)
runtime_quant_backend = _zimage_quant_backend()
runtime_stats = {"layers": 0, "replaced": 0, "skipped": 0, "float8": 0, "nvfp4": 0}
if runtime_quant_enabled:
remapped_candidate = _remap_state_dict_to_model_keys(
@ -3503,7 +3667,26 @@ def _load_component_override_from_file(
f"[Z-Image POC] Native FP8 synth skipped non-linear layers for {component_name}: "
f"skipped_non_linear={synth_stats['skipped_non_linear']}."
)
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
if runtime_quant_backend == "comfy_ops":
compute_dtype = getattr(component, "dtype", None)
if compute_dtype not in (torch.float16, torch.bfloat16, torch.float32):
compute_dtype = torch.bfloat16
runtime_stats = _install_comfy_ops_quant_modules(component, remapped_candidate, compute_dtype)
if runtime_stats.get("error"):
print(
f"[Z-Image POC] Comfy quant backend unavailable for {component_name}: "
f"{runtime_stats['error']}. Falling back to runtime backend."
)
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
runtime_stats["backend"] = "runtime"
else:
print(
f"[Z-Image POC] Using experimental Comfy mixed_precision_ops backend for {component_name} "
f"(compute_dtype={compute_dtype})."
)
else:
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
runtime_stats["backend"] = "runtime"
if runtime_stats["layers"] > 0 and runtime_stats["replaced"] > 0:
replaced_bases = runtime_stats.get("replaced_bases", set())
quant_bases = {
@ -3523,7 +3706,7 @@ def _load_component_override_from_file(
print(
f"[Z-Image POC] Runtime Comfy quant enabled for {component_name}: "
f"layers={runtime_stats['layers']}, fp8={runtime_stats['float8']}, "
f"nvfp4={runtime_stats['nvfp4']}."
f"nvfp4={runtime_stats['nvfp4']}, backend={runtime_stats.get('backend', 'runtime')}."
)
else:
skipped_types = runtime_stats.get("skipped_type_counts", {})
@ -3536,7 +3719,7 @@ def _load_component_override_from_file(
f"(replaced={runtime_stats['replaced']}/{runtime_stats['layers']}, "
f"unmapped_dequantized={runtime_stats.get('unmapped_dequantized', 0)}, "
f"skipped_unresolved={runtime_stats.get('skipped_unresolved', 0)}"
f"{skipped_type_summary}, "
f"{skipped_type_summary}, backend={runtime_stats.get('backend', 'runtime')}, "
"unmapped layers keep eager load path)."
)
@ -3903,9 +4086,11 @@ def generate_zimage(
steps: int,
guidance_scale: float,
seed: int,
seeds: Optional[list[int]] = None,
shift: float = 3.0,
text_encoder_override: Optional[str] = None,
vae_override: Optional[str] = None,
return_images: bool = False,
):
import torch
@ -4022,7 +4207,21 @@ def generate_zimage(
_PIPELINE_CACHE.pop(cache_key, None)
_clear_prompt_cache_for_pipeline(cache_key)
raise
generator = torch.Generator(device=generator_device).manual_seed(seed)
seed_list = [int(seed)]
if seeds:
parsed = []
for s in seeds:
try:
parsed.append(int(s))
except Exception:
continue
if parsed:
seed_list = parsed
if len(seed_list) <= 1:
generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
else:
generator = [torch.Generator(device=generator_device).manual_seed(s) for s in seed_list]
stage_times["runtime_prep"] = time.perf_counter() - stage_start
neg_key = negative_prompt if use_cfg else ""
@ -4071,7 +4270,7 @@ def generate_zimage(
num_inference_steps=steps,
guidance_scale=guidance_scale,
generator=generator,
num_images_per_prompt=1,
num_images_per_prompt=max(1, len(seed_list)),
cfg_normalization=False,
cfg_truncation=1.0,
max_sequence_length=max_sequence_length,
@ -4081,7 +4280,7 @@ def generate_zimage(
print(
f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
f"size={call_kwargs['width']}x{call_kwargs['height']}, max_seq={max_sequence_length}, offload={used_offload}, "
f"dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}"
f"batch={call_kwargs['num_images_per_prompt']}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}"
)
output = None
@ -4112,12 +4311,21 @@ def generate_zimage(
output = _run_pipeline_call(pipeline, call_kwargs)
if _zimage_black_image_retry_enabled() and not black_retry_used:
try:
candidate = output.images[0]
candidates = list(getattr(output, "images", []) or [])
except Exception:
candidate = None
if candidate is not None:
is_black, black_info = _is_suspected_black_image(candidate)
if is_black:
candidates = []
black_entries = []
for idx, candidate in enumerate(candidates):
try:
is_black, black_info = _is_suspected_black_image(candidate)
except Exception:
is_black, black_info = False, None
if is_black and black_info is not None:
black_entries.append((idx, black_info))
if black_entries:
first_black_idx, first_black_info = black_entries[0]
is_batch_black = len(candidates) > 1 and len(black_entries) == len(candidates)
if len(candidates) == 1 or is_batch_black:
black_retry_used = True
strategy = str(getattr(pipeline, "_zimage_xformers_strategy", "unknown"))
transformer = getattr(pipeline, "transformer", None)
@ -4127,7 +4335,8 @@ def generate_zimage(
if strict_fp16:
print(
f"[Z-Image POC] Suspected black output detected "
f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, "
f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
f"attn={strategy}, dtype={transformer_dtype}). Strict FP16 mode enabled; no fallback."
)
raise RuntimeError(
@ -4135,7 +4344,8 @@ def generate_zimage(
)
print(
f"[Z-Image POC] Suspected black output detected "
f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, "
f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
f"attn={strategy}, dtype={transformer_dtype}). Retrying once with safer runtime."
)
@ -4177,7 +4387,12 @@ def generate_zimage(
if changed:
if generator_device == "cuda":
_cleanup_memory(cuda=True, aggressive=True)
call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed)
if len(seed_list) <= 1:
call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed_list[0])
else:
call_kwargs["generator"] = [
torch.Generator(device=generator_device).manual_seed(s) for s in seed_list
]
original_output = output
try:
output = _run_pipeline_call(pipeline, call_kwargs)
@ -4187,9 +4402,15 @@ def generate_zimage(
f"[Z-Image POC] Black-image retry failed ({retry_error}); keeping original output."
)
try:
retry_image = output.images[0]
retry_black, retry_info = _is_suspected_black_image(retry_image)
if retry_black:
retry_candidates = list(getattr(output, "images", []) or [])
retry_black_any = False
retry_info = None
for retry_image in retry_candidates:
retry_black, retry_info = _is_suspected_black_image(retry_image)
if retry_black:
retry_black_any = True
break
if retry_black_any and retry_info is not None:
print(
f"[Z-Image POC] Black-image retry remained near-black "
f"(mean={retry_info['mean']:.2f}, max={retry_info['max']:.0f})."
@ -4202,6 +4423,12 @@ def generate_zimage(
pass
else:
print("[Z-Image POC] No safe retry knobs available; keeping original output.")
elif black_entries:
# For batches, retry only if every output is black. Mixed batches are preserved.
print(
f"[Z-Image POC] Batch output has {len(black_entries)}/{len(candidates)} near-black images; "
"keeping batch output."
)
break
except RuntimeError as e:
msg = str(e).lower()
@ -4272,10 +4499,14 @@ def generate_zimage(
stage_times["pipeline_call"] = time.perf_counter() - call_start
stage_start = time.perf_counter()
image = output.images[0]
images = list(getattr(output, "images", []) or [])
del output
stage_times["extract_image"] = time.perf_counter() - stage_start
return image
if return_images:
return images
if not images:
raise RuntimeError("Z-Image pipeline returned no images.")
return images[0]
except Exception as e:
error_name = type(e).__name__
if "pipeline_call" not in stage_times: