diff --git a/modules/async_worker.py b/modules/async_worker.py index 23cd4fe7..77d80e65 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -912,6 +912,16 @@ def worker(): def _truthy_env(name: str, default: str = '0') -> bool: return os.environ.get(name, default).strip().lower() in ('1', 'true', 'yes', 'on') + def _int_env(name: str, default: int, min_value: int = 1, max_value: int = 64) -> int: + raw = os.environ.get(name, str(default)).strip() + try: + value = int(raw) + except Exception: + return default + value = max(min_value, value) + value = min(max_value, value) + return value + def maybe_unload_standard_models_for_zimage(): if not _truthy_env('FOOOCUS_ZIMAGE_UNLOAD_STANDARD_MODELS', '0'): return False @@ -966,8 +976,10 @@ def worker(): progressbar(async_task, current_progress, 'Running Z-Image POC generation ...') release_cache_after_run = _truthy_env('FOOOCUS_ZIMAGE_RELEASE_CACHE_AFTER_RUN', '0') + batch_size = _int_env('FOOOCUS_ZIMAGE_BATCH_SIZE', 1, min_value=1, max_value=8) try: + tasks = [] for i in range(async_task.image_number): if async_task.last_stop is not False: break @@ -987,48 +999,159 @@ def worker(): task_rng, ) - progressbar(async_task, current_progress, f'Generating Z-Image {i + 1}/{async_task.image_number} ...') - try: - pil_image = modules.zimage_poc.generate_zimage( - source_kind=source_kind, - source_path=source_path, - flavor=flavor, - checkpoint_folders=modules.config.paths_checkpoints, + tasks.append( + dict( + index=i, + task_seed=task_seed, prompt=task_prompt, negative_prompt=task_negative_prompt, - width=width, - height=height, - steps=async_task.steps, - guidance_scale=float(async_task.cfg_scale), - seed=task_seed, - shift=float(async_task.adaptive_cfg), - text_encoder_override=async_task.zit_text_encoder, - vae_override=async_task.zit_vae, ) - np_image = np.array(pil_image) + ) + + task_pos = 0 + while task_pos < len(tasks): + if async_task.last_stop is not False: + break + + batch = [tasks[task_pos]] + if batch_size > 1: + anchor_prompt = batch[0]['prompt'] + anchor_negative_prompt = batch[0]['negative_prompt'] + scan = task_pos + 1 + while scan < len(tasks) and len(batch) < batch_size: + cand = tasks[scan] + if cand['prompt'] != anchor_prompt or cand['negative_prompt'] != anchor_negative_prompt: + break + batch.append(cand) + scan += 1 + + def _save_task_result(task_item, np_image): + task = dict( + task_seed=task_item['task_seed'], + log_positive_prompt=task_item['prompt'], + log_negative_prompt=task_item['negative_prompt'], + expansion='', + styles=[], + positive=[task_item['prompt']], + negative=[task_item['negative_prompt']], + ) + progress = int((task_item['index'] + 1) * 100 / float(max(async_task.image_number, 1))) + img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True) + yield_result( + async_task, + img_paths, + progress, + async_task.black_out_nsfw, + False, + do_not_show_finished_images=async_task.disable_intermediate_results, + ) + + try: + if len(batch) > 1: + range_start = batch[0]['index'] + 1 + range_end = batch[-1]['index'] + 1 + progressbar( + async_task, + current_progress, + f'Generating Z-Image {range_start}-{range_end}/{async_task.image_number} (batched) ...', + ) + pil_images = modules.zimage_poc.generate_zimage( + source_kind=source_kind, + source_path=source_path, + flavor=flavor, + checkpoint_folders=modules.config.paths_checkpoints, + prompt=batch[0]['prompt'], + negative_prompt=batch[0]['negative_prompt'], + width=width, + height=height, + steps=async_task.steps, + guidance_scale=float(async_task.cfg_scale), + seed=batch[0]['task_seed'], + seeds=[item['task_seed'] for item in batch], + shift=float(async_task.adaptive_cfg), + text_encoder_override=async_task.zit_text_encoder, + vae_override=async_task.zit_vae, + return_images=True, + ) + if not isinstance(pil_images, list) or len(pil_images) != len(batch): + raise RuntimeError( + f'Batched Z-Image returned {len(pil_images) if isinstance(pil_images, list) else "invalid"} ' + f'images for expected batch size {len(batch)}.' + ) + for task_item, pil_image in zip(batch, pil_images): + _save_task_result(task_item, np.array(pil_image)) + else: + task_item = batch[0] + idx = task_item['index'] + progressbar( + async_task, + current_progress, + f'Generating Z-Image {idx + 1}/{async_task.image_number} ...', + ) + pil_image = modules.zimage_poc.generate_zimage( + source_kind=source_kind, + source_path=source_path, + flavor=flavor, + checkpoint_folders=modules.config.paths_checkpoints, + prompt=task_item['prompt'], + negative_prompt=task_item['negative_prompt'], + width=width, + height=height, + steps=async_task.steps, + guidance_scale=float(async_task.cfg_scale), + seed=task_item['task_seed'], + shift=float(async_task.adaptive_cfg), + text_encoder_override=async_task.zit_text_encoder, + vae_override=async_task.zit_vae, + ) + _save_task_result(task_item, np.array(pil_image)) except ModuleNotFoundError as e: progressbar(async_task, 100, f'Z-Image POC requires missing dependency: {e}') print('[Z-Image POC] Install required dependencies with:') print(' python -m pip install -r requirements_versions.txt') return except Exception as e: - progressbar(async_task, 100, f'Z-Image generation failed: {e}') - print(f'[Z-Image POC] Generation failed: {e}') - return + if len(batch) > 1: + print( + f"[Z-Image POC] Batched generation failed ({e}); retrying batch as single-image runs." + ) + for task_item in batch: + if async_task.last_stop is not False: + break + try: + idx = task_item['index'] + progressbar( + async_task, + current_progress, + f'Generating Z-Image {idx + 1}/{async_task.image_number} ...', + ) + pil_image = modules.zimage_poc.generate_zimage( + source_kind=source_kind, + source_path=source_path, + flavor=flavor, + checkpoint_folders=modules.config.paths_checkpoints, + prompt=task_item['prompt'], + negative_prompt=task_item['negative_prompt'], + width=width, + height=height, + steps=async_task.steps, + guidance_scale=float(async_task.cfg_scale), + seed=task_item['task_seed'], + shift=float(async_task.adaptive_cfg), + text_encoder_override=async_task.zit_text_encoder, + vae_override=async_task.zit_vae, + ) + _save_task_result(task_item, np.array(pil_image)) + except Exception as single_error: + progressbar(async_task, 100, f'Z-Image generation failed: {single_error}') + print(f'[Z-Image POC] Generation failed: {single_error}') + return + else: + progressbar(async_task, 100, f'Z-Image generation failed: {e}') + print(f'[Z-Image POC] Generation failed: {e}') + return - task = dict( - task_seed=task_seed, - log_positive_prompt=task_prompt, - log_negative_prompt=task_negative_prompt, - expansion='', - styles=[], - positive=[task_prompt], - negative=[task_negative_prompt], - ) - progress = int((i + 1) * 100 / float(max(async_task.image_number, 1))) - img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True) - yield_result(async_task, img_paths, progress, async_task.black_out_nsfw, False, - do_not_show_finished_images=async_task.disable_intermediate_results) + task_pos += len(batch) finally: if unloaded_standard_models: try: diff --git a/modules/zimage_poc.py b/modules/zimage_poc.py index 19801dd9..cbea0e4f 100644 --- a/modules/zimage_poc.py +++ b/modules/zimage_poc.py @@ -461,6 +461,17 @@ def _zimage_fp16_quant_accum_mode() -> str: return mode +def _zimage_quant_backend() -> str: + raw = os.environ.get("FOOOCUS_ZIMAGE_QUANT_BACKEND", "runtime").strip().lower() + if raw in ("runtime", "comfy_ops"): + return raw + _warn_once_env( + "FOOOCUS_ZIMAGE_QUANT_BACKEND", + f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_QUANT_BACKEND='{raw}'. Expected: runtime|comfy_ops.", + ) + return "runtime" + + def _zimage_prewarm_enabled() -> bool: return _truthy_env("FOOOCUS_ZIMAGE_PREWARM", "0") @@ -3247,6 +3258,158 @@ def _load_component_override_from_file( "skipped_type_counts": skipped_type_counts, } + def _import_comfy_ops_module(): + import sys + + comfy_root = os.environ.get("FOOOCUS_ZIMAGE_COMFY_ROOT", "").strip() + tried_root = None + initial_error = None + + try: + import comfy.ops as comfy_ops # type: ignore + return comfy_ops, None + except Exception as e: + initial_error = e + + if comfy_root: + tried_root = os.path.abspath(os.path.expanduser(comfy_root)) + if os.path.isdir(tried_root) and tried_root not in sys.path: + sys.path.insert(0, tried_root) + try: + import comfy.ops as comfy_ops # type: ignore + print(f"[Z-Image POC] Loaded Comfy ops from FOOOCUS_ZIMAGE_COMFY_ROOT={tried_root}") + return comfy_ops, None + except Exception as e: + return None, f"{initial_error}; retry_with_root={tried_root}: {e}" + + return None, str(initial_error) + + def _install_comfy_ops_quant_modules(component_module, remapped_sd: dict, compute_dtype) -> dict: + bases = sorted( + { + key[: -len(".comfy_quant")] + for key in remapped_sd.keys() + if key.endswith(".comfy_quant") and f"{key[: -len('.comfy_quant')]}.weight" in remapped_sd + } + ) + if not bases: + return { + "layers": 0, + "replaced": 0, + "skipped": 0, + "float8": 0, + "nvfp4": 0, + "backend": "comfy_ops", + } + + comfy_ops, import_error = _import_comfy_ops_module() + if comfy_ops is None: + return { + "layers": len(bases), + "replaced": 0, + "skipped": len(bases), + "float8": 0, + "nvfp4": 0, + "replaced_bases": set(), + "skipped_unresolved": len(bases), + "skipped_type_counts": {"import_error": len(bases)}, + "backend": "comfy_ops", + "error": f"Unable to import comfy.ops ({import_error})", + } + + try: + mixed_ops = comfy_ops.mixed_precision_ops( + quant_config={}, + compute_dtype=compute_dtype, + full_precision_mm=False, + disabled=[], + ) + linear_cls = getattr(mixed_ops, "Linear", None) + if linear_cls is None: + raise RuntimeError("mixed_precision_ops did not expose Linear class.") + except Exception as e: + return { + "layers": len(bases), + "replaced": 0, + "skipped": len(bases), + "float8": 0, + "nvfp4": 0, + "replaced_bases": set(), + "skipped_unresolved": len(bases), + "skipped_type_counts": {"init_error": len(bases)}, + "backend": "comfy_ops", + "error": f"Failed to initialize Comfy mixed_precision_ops ({e})", + } + + replaced = 0 + skipped = 0 + float8_layers = 0 + nvfp4_layers = 0 + replaced_bases = set() + skipped_type_counts = {} + skipped_unresolved = 0 + + def _mark_skipped_type(name: str): + skipped_type_counts[name] = skipped_type_counts.get(name, 0) + 1 + + for base in bases: + conf = _decode_comfy_quant_entry(remapped_sd.get(f"{base}.comfy_quant")) + fmt = str(conf.get("format", "")).lower() if isinstance(conf, dict) else "" + if fmt in ("float8_e4m3fn", "float8_e5m2"): + float8_layers += 1 + elif fmt == "nvfp4": + nvfp4_layers += 1 + + try: + target = _resolve_module(component_module, base) + except Exception: + skipped += 1 + skipped_unresolved += 1 + continue + + if isinstance(target, linear_cls): + replaced += 1 + replaced_bases.add(base) + continue + if not _is_linear_like_module(target): + skipped += 1 + _mark_skipped_type(type(target).__name__) + continue + + try: + weight = getattr(target, "weight", None) + bias = getattr(target, "bias", None) + in_features = int(getattr(target, "in_features", weight.shape[1])) + out_features = int(getattr(target, "out_features", weight.shape[0])) + target_device = getattr(weight, "device", None) + replacement = linear_cls( + in_features=in_features, + out_features=out_features, + bias=bias is not None, + device=target_device, + dtype=compute_dtype, + ) + except Exception: + skipped += 1 + _mark_skipped_type(type(target).__name__) + continue + + _set_module(component_module, base, replacement) + replaced += 1 + replaced_bases.add(base) + + return { + "layers": len(bases), + "replaced": replaced, + "skipped": skipped, + "float8": float8_layers, + "nvfp4": nvfp4_layers, + "replaced_bases": replaced_bases, + "skipped_unresolved": skipped_unresolved, + "skipped_type_counts": skipped_type_counts, + "backend": "comfy_ops", + } + def _normalize_legacy_scaled_fp8_weights(sd: dict) -> tuple[dict, dict]: converted = dict(sd) migrated = 0 @@ -3481,6 +3644,7 @@ def _load_component_override_from_file( component_name in ("text_encoder", "transformer") and _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_QUANT", "1") ) + runtime_quant_backend = _zimage_quant_backend() runtime_stats = {"layers": 0, "replaced": 0, "skipped": 0, "float8": 0, "nvfp4": 0} if runtime_quant_enabled: remapped_candidate = _remap_state_dict_to_model_keys( @@ -3503,7 +3667,26 @@ def _load_component_override_from_file( f"[Z-Image POC] Native FP8 synth skipped non-linear layers for {component_name}: " f"skipped_non_linear={synth_stats['skipped_non_linear']}." ) - runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate) + if runtime_quant_backend == "comfy_ops": + compute_dtype = getattr(component, "dtype", None) + if compute_dtype not in (torch.float16, torch.bfloat16, torch.float32): + compute_dtype = torch.bfloat16 + runtime_stats = _install_comfy_ops_quant_modules(component, remapped_candidate, compute_dtype) + if runtime_stats.get("error"): + print( + f"[Z-Image POC] Comfy quant backend unavailable for {component_name}: " + f"{runtime_stats['error']}. Falling back to runtime backend." + ) + runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate) + runtime_stats["backend"] = "runtime" + else: + print( + f"[Z-Image POC] Using experimental Comfy mixed_precision_ops backend for {component_name} " + f"(compute_dtype={compute_dtype})." + ) + else: + runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate) + runtime_stats["backend"] = "runtime" if runtime_stats["layers"] > 0 and runtime_stats["replaced"] > 0: replaced_bases = runtime_stats.get("replaced_bases", set()) quant_bases = { @@ -3523,7 +3706,7 @@ def _load_component_override_from_file( print( f"[Z-Image POC] Runtime Comfy quant enabled for {component_name}: " f"layers={runtime_stats['layers']}, fp8={runtime_stats['float8']}, " - f"nvfp4={runtime_stats['nvfp4']}." + f"nvfp4={runtime_stats['nvfp4']}, backend={runtime_stats.get('backend', 'runtime')}." ) else: skipped_types = runtime_stats.get("skipped_type_counts", {}) @@ -3536,7 +3719,7 @@ def _load_component_override_from_file( f"(replaced={runtime_stats['replaced']}/{runtime_stats['layers']}, " f"unmapped_dequantized={runtime_stats.get('unmapped_dequantized', 0)}, " f"skipped_unresolved={runtime_stats.get('skipped_unresolved', 0)}" - f"{skipped_type_summary}, " + f"{skipped_type_summary}, backend={runtime_stats.get('backend', 'runtime')}, " "unmapped layers keep eager load path)." ) @@ -3903,9 +4086,11 @@ def generate_zimage( steps: int, guidance_scale: float, seed: int, + seeds: Optional[list[int]] = None, shift: float = 3.0, text_encoder_override: Optional[str] = None, vae_override: Optional[str] = None, + return_images: bool = False, ): import torch @@ -4022,7 +4207,21 @@ def generate_zimage( _PIPELINE_CACHE.pop(cache_key, None) _clear_prompt_cache_for_pipeline(cache_key) raise - generator = torch.Generator(device=generator_device).manual_seed(seed) + seed_list = [int(seed)] + if seeds: + parsed = [] + for s in seeds: + try: + parsed.append(int(s)) + except Exception: + continue + if parsed: + seed_list = parsed + + if len(seed_list) <= 1: + generator = torch.Generator(device=generator_device).manual_seed(seed_list[0]) + else: + generator = [torch.Generator(device=generator_device).manual_seed(s) for s in seed_list] stage_times["runtime_prep"] = time.perf_counter() - stage_start neg_key = negative_prompt if use_cfg else "" @@ -4071,7 +4270,7 @@ def generate_zimage( num_inference_steps=steps, guidance_scale=guidance_scale, generator=generator, - num_images_per_prompt=1, + num_images_per_prompt=max(1, len(seed_list)), cfg_normalization=False, cfg_truncation=1.0, max_sequence_length=max_sequence_length, @@ -4081,7 +4280,7 @@ def generate_zimage( print( f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, " f"size={call_kwargs['width']}x{call_kwargs['height']}, max_seq={max_sequence_length}, offload={used_offload}, " - f"dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}" + f"batch={call_kwargs['num_images_per_prompt']}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}" ) output = None @@ -4112,12 +4311,21 @@ def generate_zimage( output = _run_pipeline_call(pipeline, call_kwargs) if _zimage_black_image_retry_enabled() and not black_retry_used: try: - candidate = output.images[0] + candidates = list(getattr(output, "images", []) or []) except Exception: - candidate = None - if candidate is not None: - is_black, black_info = _is_suspected_black_image(candidate) - if is_black: + candidates = [] + black_entries = [] + for idx, candidate in enumerate(candidates): + try: + is_black, black_info = _is_suspected_black_image(candidate) + except Exception: + is_black, black_info = False, None + if is_black and black_info is not None: + black_entries.append((idx, black_info)) + if black_entries: + first_black_idx, first_black_info = black_entries[0] + is_batch_black = len(candidates) > 1 and len(black_entries) == len(candidates) + if len(candidates) == 1 or is_batch_black: black_retry_used = True strategy = str(getattr(pipeline, "_zimage_xformers_strategy", "unknown")) transformer = getattr(pipeline, "transformer", None) @@ -4127,7 +4335,8 @@ def generate_zimage( if strict_fp16: print( f"[Z-Image POC] Suspected black output detected " - f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, " + f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, " + f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, " f"attn={strategy}, dtype={transformer_dtype}). Strict FP16 mode enabled; no fallback." ) raise RuntimeError( @@ -4135,7 +4344,8 @@ def generate_zimage( ) print( f"[Z-Image POC] Suspected black output detected " - f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, " + f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, " + f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, " f"attn={strategy}, dtype={transformer_dtype}). Retrying once with safer runtime." ) @@ -4177,7 +4387,12 @@ def generate_zimage( if changed: if generator_device == "cuda": _cleanup_memory(cuda=True, aggressive=True) - call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed) + if len(seed_list) <= 1: + call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed_list[0]) + else: + call_kwargs["generator"] = [ + torch.Generator(device=generator_device).manual_seed(s) for s in seed_list + ] original_output = output try: output = _run_pipeline_call(pipeline, call_kwargs) @@ -4187,9 +4402,15 @@ def generate_zimage( f"[Z-Image POC] Black-image retry failed ({retry_error}); keeping original output." ) try: - retry_image = output.images[0] - retry_black, retry_info = _is_suspected_black_image(retry_image) - if retry_black: + retry_candidates = list(getattr(output, "images", []) or []) + retry_black_any = False + retry_info = None + for retry_image in retry_candidates: + retry_black, retry_info = _is_suspected_black_image(retry_image) + if retry_black: + retry_black_any = True + break + if retry_black_any and retry_info is not None: print( f"[Z-Image POC] Black-image retry remained near-black " f"(mean={retry_info['mean']:.2f}, max={retry_info['max']:.0f})." @@ -4202,6 +4423,12 @@ def generate_zimage( pass else: print("[Z-Image POC] No safe retry knobs available; keeping original output.") + elif black_entries: + # For batches, retry only if every output is black. Mixed batches are preserved. + print( + f"[Z-Image POC] Batch output has {len(black_entries)}/{len(candidates)} near-black images; " + "keeping batch output." + ) break except RuntimeError as e: msg = str(e).lower() @@ -4272,10 +4499,14 @@ def generate_zimage( stage_times["pipeline_call"] = time.perf_counter() - call_start stage_start = time.perf_counter() - image = output.images[0] + images = list(getattr(output, "images", []) or []) del output stage_times["extract_image"] = time.perf_counter() - stage_start - return image + if return_images: + return images + if not images: + raise RuntimeError("Z-Image pipeline returned no images.") + return images[0] except Exception as e: error_name = type(e).__name__ if "pipeline_call" not in stage_times: