mirror of https://github.com/lllyasviel/Fooocus
alternative path
parent
3875e5ad4f
commit
0e1bd3e1de
|
|
@ -912,6 +912,16 @@ def worker():
|
|||
def _truthy_env(name: str, default: str = '0') -> bool:
    """Report whether environment variable ``name`` holds a truthy flag.

    The variable's value (or ``default`` when the variable is unset) is
    stripped of surrounding whitespace and lower-cased; the result is truthy
    only when it is one of '1', 'true', 'yes' or 'on'.
    """
    accepted = ('1', 'true', 'yes', 'on')
    normalized = os.environ.get(name, default).strip().lower()
    return normalized in accepted
|
||||
|
||||
def _int_env(name: str, default: int, min_value: int = 1, max_value: int = 64) -> int:
    """Read an integer from environment variable ``name``, clamped to a range.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or is not a valid
            integer literal.
        min_value: Lower bound applied to the result.
        max_value: Upper bound applied to the result.

    Returns:
        The parsed value (or ``default``), raised to at least ``min_value``
        and then lowered to at most ``max_value``.
    """
    raw = os.environ.get(name, str(default)).strip()
    try:
        value = int(raw)
    except ValueError:
        # Fall through to the clamp below instead of returning early, so an
        # unparsable value yields the same result as an unset variable.
        value = default
    return min(max_value, max(min_value, value))
|
||||
|
||||
def maybe_unload_standard_models_for_zimage():
|
||||
if not _truthy_env('FOOOCUS_ZIMAGE_UNLOAD_STANDARD_MODELS', '0'):
|
||||
return False
|
||||
|
|
@ -966,8 +976,10 @@ def worker():
|
|||
progressbar(async_task, current_progress, 'Running Z-Image POC generation ...')
|
||||
|
||||
release_cache_after_run = _truthy_env('FOOOCUS_ZIMAGE_RELEASE_CACHE_AFTER_RUN', '0')
|
||||
batch_size = _int_env('FOOOCUS_ZIMAGE_BATCH_SIZE', 1, min_value=1, max_value=8)
|
||||
|
||||
try:
|
||||
tasks = []
|
||||
for i in range(async_task.image_number):
|
||||
if async_task.last_stop is not False:
|
||||
break
|
||||
|
|
@ -987,48 +999,159 @@ def worker():
|
|||
task_rng,
|
||||
)
|
||||
|
||||
progressbar(async_task, current_progress, f'Generating Z-Image {i + 1}/{async_task.image_number} ...')
|
||||
try:
|
||||
pil_image = modules.zimage_poc.generate_zimage(
|
||||
source_kind=source_kind,
|
||||
source_path=source_path,
|
||||
flavor=flavor,
|
||||
checkpoint_folders=modules.config.paths_checkpoints,
|
||||
tasks.append(
|
||||
dict(
|
||||
index=i,
|
||||
task_seed=task_seed,
|
||||
prompt=task_prompt,
|
||||
negative_prompt=task_negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=async_task.steps,
|
||||
guidance_scale=float(async_task.cfg_scale),
|
||||
seed=task_seed,
|
||||
shift=float(async_task.adaptive_cfg),
|
||||
text_encoder_override=async_task.zit_text_encoder,
|
||||
vae_override=async_task.zit_vae,
|
||||
)
|
||||
np_image = np.array(pil_image)
|
||||
)
|
||||
|
||||
task_pos = 0
|
||||
while task_pos < len(tasks):
|
||||
if async_task.last_stop is not False:
|
||||
break
|
||||
|
||||
batch = [tasks[task_pos]]
|
||||
if batch_size > 1:
|
||||
anchor_prompt = batch[0]['prompt']
|
||||
anchor_negative_prompt = batch[0]['negative_prompt']
|
||||
scan = task_pos + 1
|
||||
while scan < len(tasks) and len(batch) < batch_size:
|
||||
cand = tasks[scan]
|
||||
if cand['prompt'] != anchor_prompt or cand['negative_prompt'] != anchor_negative_prompt:
|
||||
break
|
||||
batch.append(cand)
|
||||
scan += 1
|
||||
|
||||
def _save_task_result(task_item, np_image):
    """Save one generated image and publish it as a result of ``async_task``.

    ``task_item`` is one of the dicts queued in ``tasks`` (keys read here:
    'task_seed', 'prompt', 'negative_prompt', 'index'); ``np_image`` is the
    image array to save. ``async_task``, ``width`` and ``height`` come from
    the enclosing scope.
    """
    seed = task_item['task_seed']
    positive_text = task_item['prompt']
    negative_text = task_item['negative_prompt']
    task = {
        'task_seed': seed,
        'log_positive_prompt': positive_text,
        'log_negative_prompt': negative_text,
        'expansion': '',
        'styles': [],
        'positive': [positive_text],
        'negative': [negative_text],
    }

    # Progress is the share of requested images finished once this one is
    # stored; the max() guards the division when image_number is 0.
    total_images = float(max(async_task.image_number, 1))
    progress = int((task_item['index'] + 1) * 100 / total_images)

    img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True)
    yield_result(
        async_task,
        img_paths,
        progress,
        async_task.black_out_nsfw,
        False,
        do_not_show_finished_images=async_task.disable_intermediate_results,
    )
|
||||
|
||||
try:
|
||||
if len(batch) > 1:
|
||||
range_start = batch[0]['index'] + 1
|
||||
range_end = batch[-1]['index'] + 1
|
||||
progressbar(
|
||||
async_task,
|
||||
current_progress,
|
||||
f'Generating Z-Image {range_start}-{range_end}/{async_task.image_number} (batched) ...',
|
||||
)
|
||||
pil_images = modules.zimage_poc.generate_zimage(
|
||||
source_kind=source_kind,
|
||||
source_path=source_path,
|
||||
flavor=flavor,
|
||||
checkpoint_folders=modules.config.paths_checkpoints,
|
||||
prompt=batch[0]['prompt'],
|
||||
negative_prompt=batch[0]['negative_prompt'],
|
||||
width=width,
|
||||
height=height,
|
||||
steps=async_task.steps,
|
||||
guidance_scale=float(async_task.cfg_scale),
|
||||
seed=batch[0]['task_seed'],
|
||||
seeds=[item['task_seed'] for item in batch],
|
||||
shift=float(async_task.adaptive_cfg),
|
||||
text_encoder_override=async_task.zit_text_encoder,
|
||||
vae_override=async_task.zit_vae,
|
||||
return_images=True,
|
||||
)
|
||||
if not isinstance(pil_images, list) or len(pil_images) != len(batch):
|
||||
raise RuntimeError(
|
||||
f'Batched Z-Image returned {len(pil_images) if isinstance(pil_images, list) else "invalid"} '
|
||||
f'images for expected batch size {len(batch)}.'
|
||||
)
|
||||
for task_item, pil_image in zip(batch, pil_images):
|
||||
_save_task_result(task_item, np.array(pil_image))
|
||||
else:
|
||||
task_item = batch[0]
|
||||
idx = task_item['index']
|
||||
progressbar(
|
||||
async_task,
|
||||
current_progress,
|
||||
f'Generating Z-Image {idx + 1}/{async_task.image_number} ...',
|
||||
)
|
||||
pil_image = modules.zimage_poc.generate_zimage(
|
||||
source_kind=source_kind,
|
||||
source_path=source_path,
|
||||
flavor=flavor,
|
||||
checkpoint_folders=modules.config.paths_checkpoints,
|
||||
prompt=task_item['prompt'],
|
||||
negative_prompt=task_item['negative_prompt'],
|
||||
width=width,
|
||||
height=height,
|
||||
steps=async_task.steps,
|
||||
guidance_scale=float(async_task.cfg_scale),
|
||||
seed=task_item['task_seed'],
|
||||
shift=float(async_task.adaptive_cfg),
|
||||
text_encoder_override=async_task.zit_text_encoder,
|
||||
vae_override=async_task.zit_vae,
|
||||
)
|
||||
_save_task_result(task_item, np.array(pil_image))
|
||||
except ModuleNotFoundError as e:
|
||||
progressbar(async_task, 100, f'Z-Image POC requires missing dependency: {e}')
|
||||
print('[Z-Image POC] Install required dependencies with:')
|
||||
print(' python -m pip install -r requirements_versions.txt')
|
||||
return
|
||||
except Exception as e:
|
||||
progressbar(async_task, 100, f'Z-Image generation failed: {e}')
|
||||
print(f'[Z-Image POC] Generation failed: {e}')
|
||||
return
|
||||
if len(batch) > 1:
|
||||
print(
|
||||
f"[Z-Image POC] Batched generation failed ({e}); retrying batch as single-image runs."
|
||||
)
|
||||
for task_item in batch:
|
||||
if async_task.last_stop is not False:
|
||||
break
|
||||
try:
|
||||
idx = task_item['index']
|
||||
progressbar(
|
||||
async_task,
|
||||
current_progress,
|
||||
f'Generating Z-Image {idx + 1}/{async_task.image_number} ...',
|
||||
)
|
||||
pil_image = modules.zimage_poc.generate_zimage(
|
||||
source_kind=source_kind,
|
||||
source_path=source_path,
|
||||
flavor=flavor,
|
||||
checkpoint_folders=modules.config.paths_checkpoints,
|
||||
prompt=task_item['prompt'],
|
||||
negative_prompt=task_item['negative_prompt'],
|
||||
width=width,
|
||||
height=height,
|
||||
steps=async_task.steps,
|
||||
guidance_scale=float(async_task.cfg_scale),
|
||||
seed=task_item['task_seed'],
|
||||
shift=float(async_task.adaptive_cfg),
|
||||
text_encoder_override=async_task.zit_text_encoder,
|
||||
vae_override=async_task.zit_vae,
|
||||
)
|
||||
_save_task_result(task_item, np.array(pil_image))
|
||||
except Exception as single_error:
|
||||
progressbar(async_task, 100, f'Z-Image generation failed: {single_error}')
|
||||
print(f'[Z-Image POC] Generation failed: {single_error}')
|
||||
return
|
||||
else:
|
||||
progressbar(async_task, 100, f'Z-Image generation failed: {e}')
|
||||
print(f'[Z-Image POC] Generation failed: {e}')
|
||||
return
|
||||
|
||||
task = dict(
|
||||
task_seed=task_seed,
|
||||
log_positive_prompt=task_prompt,
|
||||
log_negative_prompt=task_negative_prompt,
|
||||
expansion='',
|
||||
styles=[],
|
||||
positive=[task_prompt],
|
||||
negative=[task_negative_prompt],
|
||||
)
|
||||
progress = int((i + 1) * 100 / float(max(async_task.image_number, 1)))
|
||||
img_paths = save_and_log(async_task, height, [np_image], task, False, width, [], persist_image=True)
|
||||
yield_result(async_task, img_paths, progress, async_task.black_out_nsfw, False,
|
||||
do_not_show_finished_images=async_task.disable_intermediate_results)
|
||||
task_pos += len(batch)
|
||||
finally:
|
||||
if unloaded_standard_models:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -461,6 +461,17 @@ def _zimage_fp16_quant_accum_mode() -> str:
|
|||
return mode
|
||||
|
||||
|
||||
def _zimage_quant_backend() -> str:
    """Return the quantization backend named by FOOOCUS_ZIMAGE_QUANT_BACKEND.

    The value is stripped and lower-cased. "runtime" (also used when the
    variable is unset) and "comfy_ops" are returned as-is; anything else is
    reported through ``_warn_once_env`` and replaced by "runtime".
    """
    choice = os.environ.get("FOOOCUS_ZIMAGE_QUANT_BACKEND", "runtime").strip().lower()
    if choice not in ("runtime", "comfy_ops"):
        _warn_once_env(
            "FOOOCUS_ZIMAGE_QUANT_BACKEND",
            f"[Z-Image POC] Ignoring invalid FOOOCUS_ZIMAGE_QUANT_BACKEND='{choice}'. Expected: runtime|comfy_ops.",
        )
        return "runtime"
    return choice
|
||||
|
||||
|
||||
def _zimage_prewarm_enabled() -> bool:
    """Tell whether FOOOCUS_ZIMAGE_PREWARM is set to a truthy value (default "0")."""
    prewarm_flag = _truthy_env("FOOOCUS_ZIMAGE_PREWARM", "0")
    return prewarm_flag
|
||||
|
||||
|
|
@ -3247,6 +3258,158 @@ def _load_component_override_from_file(
|
|||
"skipped_type_counts": skipped_type_counts,
|
||||
}
|
||||
|
||||
def _import_comfy_ops_module():
    """Import ``comfy.ops``, optionally retrying with a user-supplied root.

    A plain ``import comfy.ops`` is attempted first. If it fails and the
    FOOOCUS_ZIMAGE_COMFY_ROOT environment variable is non-empty, that path is
    expanded, made absolute, put at the front of ``sys.path`` (when it is an
    existing directory not already listed) and the import is retried once.

    Returns:
        ``(module, None)`` on success, or ``(None, error_text)`` when the
        module could not be imported. ``error_text`` carries the first error
        and, if a retry happened, the root that was tried and the retry error.
    """
    import sys

    try:
        import comfy.ops as comfy_ops  # type: ignore
        return comfy_ops, None
    except Exception as e:
        # Broad catch kept from the original: any failure while importing is
        # treated as "module unavailable" and reported to the caller as text.
        initial_error = e

    comfy_root = os.environ.get("FOOOCUS_ZIMAGE_COMFY_ROOT", "").strip()
    if not comfy_root:
        return None, str(initial_error)

    tried_root = os.path.abspath(os.path.expanduser(comfy_root))
    path_added = os.path.isdir(tried_root) and tried_root not in sys.path
    if path_added:
        sys.path.insert(0, tried_root)

    try:
        import comfy.ops as comfy_ops  # type: ignore
    except Exception as e:
        if path_added:
            # Undo our own insertion: a root that did not provide comfy.ops
            # must not stay at the front of sys.path, where it could shadow
            # unrelated imports for the rest of the process.
            try:
                sys.path.remove(tried_root)
            except ValueError:
                pass
        return None, f"{initial_error}; retry_with_root={tried_root}: {e}"

    print(f"[Z-Image POC] Loaded Comfy ops from FOOOCUS_ZIMAGE_COMFY_ROOT={tried_root}")
    return comfy_ops, None
|
||||
|
||||
def _install_comfy_ops_quant_modules(component_module, remapped_sd: dict, compute_dtype) -> dict:
    """Swap quantized linear layers for the ``Linear`` class of Comfy's mixed_precision_ops.

    A layer is a candidate when ``remapped_sd`` holds both ``<base>.comfy_quant``
    and ``<base>.weight``. Each candidate is resolved inside ``component_module``
    and, if it is linear-like, replaced by a newly constructed
    ``mixed_precision_ops(...).Linear`` with the same feature sizes, bias flag
    and device. No weights are copied into the replacement here.
    NOTE(review): presumably the caller loads the state dict afterwards - confirm.

    Args:
        component_module: Root module whose submodules are resolved/replaced.
        remapped_sd: State dict keyed by module path.
        compute_dtype: Passed to ``mixed_precision_ops`` and to each
            replacement layer as its dtype.

    Returns:
        A stats dict with the same keys on every path: "layers", "replaced",
        "skipped", "float8", "nvfp4", "replaced_bases", "skipped_unresolved",
        "skipped_type_counts" and "backend" (always "comfy_ops"). An "error"
        key is added when comfy.ops cannot be imported or initialised; in that
        case nothing is replaced.
    """
    suffix = ".comfy_quant"
    bases = sorted(
        {
            key[: -len(suffix)]
            for key in remapped_sd
            if key.endswith(suffix) and f"{key[: -len(suffix)]}.weight" in remapped_sd
        }
    )

    def _stats(**overrides) -> dict:
        # Single source for the result schema so every return path exposes the
        # same keys (fresh containers per call; overrides replace defaults).
        stats = {
            "layers": len(bases),
            "replaced": 0,
            "skipped": 0,
            "float8": 0,
            "nvfp4": 0,
            "replaced_bases": set(),
            "skipped_unresolved": 0,
            "skipped_type_counts": {},
            "backend": "comfy_ops",
        }
        stats.update(overrides)
        return stats

    if not bases:
        return _stats()

    comfy_ops, import_error = _import_comfy_ops_module()
    if comfy_ops is None:
        return _stats(
            skipped=len(bases),
            skipped_unresolved=len(bases),
            skipped_type_counts={"import_error": len(bases)},
            error=f"Unable to import comfy.ops ({import_error})",
        )

    try:
        mixed_ops = comfy_ops.mixed_precision_ops(
            quant_config={},
            compute_dtype=compute_dtype,
            full_precision_mm=False,
            disabled=[],
        )
        linear_cls = getattr(mixed_ops, "Linear", None)
        if linear_cls is None:
            raise RuntimeError("mixed_precision_ops did not expose Linear class.")
    except Exception as e:
        return _stats(
            skipped=len(bases),
            skipped_unresolved=len(bases),
            skipped_type_counts={"init_error": len(bases)},
            error=f"Failed to initialize Comfy mixed_precision_ops ({e})",
        )

    skipped = 0
    skipped_unresolved = 0
    float8_layers = 0
    nvfp4_layers = 0
    replaced_bases = set()
    skipped_type_counts = {}

    def _mark_skipped_type(name: str):
        skipped_type_counts[name] = skipped_type_counts.get(name, 0) + 1

    for base in bases:
        # Format counters are updated before resolution, so they describe the
        # checkpoint's entries rather than the layers actually replaced.
        conf = _decode_comfy_quant_entry(remapped_sd.get(f"{base}.comfy_quant"))
        fmt = str(conf.get("format", "")).lower() if isinstance(conf, dict) else ""
        if fmt in ("float8_e4m3fn", "float8_e5m2"):
            float8_layers += 1
        elif fmt == "nvfp4":
            nvfp4_layers += 1

        try:
            target = _resolve_module(component_module, base)
        except Exception:
            skipped += 1
            skipped_unresolved += 1
            continue

        if isinstance(target, linear_cls):
            # Already the Comfy Linear class: count it as replaced as-is.
            replaced_bases.add(base)
            continue
        if not _is_linear_like_module(target):
            skipped += 1
            _mark_skipped_type(type(target).__name__)
            continue

        try:
            weight = getattr(target, "weight", None)
            bias = getattr(target, "bias", None)
            # Consult weight.shape only when the attribute is missing, so a
            # module that declares its feature sizes but has no weight tensor
            # is not rejected by an eagerly evaluated fallback.
            in_features = getattr(target, "in_features", None)
            if in_features is None:
                in_features = weight.shape[1]
            out_features = getattr(target, "out_features", None)
            if out_features is None:
                out_features = weight.shape[0]
            replacement = linear_cls(
                in_features=int(in_features),
                out_features=int(out_features),
                bias=bias is not None,
                device=getattr(weight, "device", None),
                dtype=compute_dtype,
            )
        except Exception:
            skipped += 1
            _mark_skipped_type(type(target).__name__)
            continue

        _set_module(component_module, base, replacement)
        replaced_bases.add(base)

    return _stats(
        # bases are unique, so each replaced layer is in the set exactly once.
        replaced=len(replaced_bases),
        skipped=skipped,
        float8=float8_layers,
        nvfp4=nvfp4_layers,
        replaced_bases=replaced_bases,
        skipped_unresolved=skipped_unresolved,
        skipped_type_counts=skipped_type_counts,
    )
|
||||
|
||||
def _normalize_legacy_scaled_fp8_weights(sd: dict) -> tuple[dict, dict]:
|
||||
converted = dict(sd)
|
||||
migrated = 0
|
||||
|
|
@ -3481,6 +3644,7 @@ def _load_component_override_from_file(
|
|||
component_name in ("text_encoder", "transformer")
|
||||
and _truthy_env("FOOOCUS_ZIMAGE_COMFY_RUNTIME_QUANT", "1")
|
||||
)
|
||||
runtime_quant_backend = _zimage_quant_backend()
|
||||
runtime_stats = {"layers": 0, "replaced": 0, "skipped": 0, "float8": 0, "nvfp4": 0}
|
||||
if runtime_quant_enabled:
|
||||
remapped_candidate = _remap_state_dict_to_model_keys(
|
||||
|
|
@ -3503,7 +3667,26 @@ def _load_component_override_from_file(
|
|||
f"[Z-Image POC] Native FP8 synth skipped non-linear layers for {component_name}: "
|
||||
f"skipped_non_linear={synth_stats['skipped_non_linear']}."
|
||||
)
|
||||
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
|
||||
if runtime_quant_backend == "comfy_ops":
|
||||
compute_dtype = getattr(component, "dtype", None)
|
||||
if compute_dtype not in (torch.float16, torch.bfloat16, torch.float32):
|
||||
compute_dtype = torch.bfloat16
|
||||
runtime_stats = _install_comfy_ops_quant_modules(component, remapped_candidate, compute_dtype)
|
||||
if runtime_stats.get("error"):
|
||||
print(
|
||||
f"[Z-Image POC] Comfy quant backend unavailable for {component_name}: "
|
||||
f"{runtime_stats['error']}. Falling back to runtime backend."
|
||||
)
|
||||
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
|
||||
runtime_stats["backend"] = "runtime"
|
||||
else:
|
||||
print(
|
||||
f"[Z-Image POC] Using experimental Comfy mixed_precision_ops backend for {component_name} "
|
||||
f"(compute_dtype={compute_dtype})."
|
||||
)
|
||||
else:
|
||||
runtime_stats = _install_comfy_runtime_quant_modules(component, remapped_candidate)
|
||||
runtime_stats["backend"] = "runtime"
|
||||
if runtime_stats["layers"] > 0 and runtime_stats["replaced"] > 0:
|
||||
replaced_bases = runtime_stats.get("replaced_bases", set())
|
||||
quant_bases = {
|
||||
|
|
@ -3523,7 +3706,7 @@ def _load_component_override_from_file(
|
|||
print(
|
||||
f"[Z-Image POC] Runtime Comfy quant enabled for {component_name}: "
|
||||
f"layers={runtime_stats['layers']}, fp8={runtime_stats['float8']}, "
|
||||
f"nvfp4={runtime_stats['nvfp4']}."
|
||||
f"nvfp4={runtime_stats['nvfp4']}, backend={runtime_stats.get('backend', 'runtime')}."
|
||||
)
|
||||
else:
|
||||
skipped_types = runtime_stats.get("skipped_type_counts", {})
|
||||
|
|
@ -3536,7 +3719,7 @@ def _load_component_override_from_file(
|
|||
f"(replaced={runtime_stats['replaced']}/{runtime_stats['layers']}, "
|
||||
f"unmapped_dequantized={runtime_stats.get('unmapped_dequantized', 0)}, "
|
||||
f"skipped_unresolved={runtime_stats.get('skipped_unresolved', 0)}"
|
||||
f"{skipped_type_summary}, "
|
||||
f"{skipped_type_summary}, backend={runtime_stats.get('backend', 'runtime')}, "
|
||||
"unmapped layers keep eager load path)."
|
||||
)
|
||||
|
||||
|
|
@ -3903,9 +4086,11 @@ def generate_zimage(
|
|||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seeds: Optional[list[int]] = None,
|
||||
shift: float = 3.0,
|
||||
text_encoder_override: Optional[str] = None,
|
||||
vae_override: Optional[str] = None,
|
||||
return_images: bool = False,
|
||||
):
|
||||
import torch
|
||||
|
||||
|
|
@ -4022,7 +4207,21 @@ def generate_zimage(
|
|||
_PIPELINE_CACHE.pop(cache_key, None)
|
||||
_clear_prompt_cache_for_pipeline(cache_key)
|
||||
raise
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
seed_list = [int(seed)]
|
||||
if seeds:
|
||||
parsed = []
|
||||
for s in seeds:
|
||||
try:
|
||||
parsed.append(int(s))
|
||||
except Exception:
|
||||
continue
|
||||
if parsed:
|
||||
seed_list = parsed
|
||||
|
||||
if len(seed_list) <= 1:
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
|
||||
else:
|
||||
generator = [torch.Generator(device=generator_device).manual_seed(s) for s in seed_list]
|
||||
stage_times["runtime_prep"] = time.perf_counter() - stage_start
|
||||
|
||||
neg_key = negative_prompt if use_cfg else ""
|
||||
|
|
@ -4071,7 +4270,7 @@ def generate_zimage(
|
|||
num_inference_steps=steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=generator,
|
||||
num_images_per_prompt=1,
|
||||
num_images_per_prompt=max(1, len(seed_list)),
|
||||
cfg_normalization=False,
|
||||
cfg_truncation=1.0,
|
||||
max_sequence_length=max_sequence_length,
|
||||
|
|
@ -4081,7 +4280,7 @@ def generate_zimage(
|
|||
print(
|
||||
f"[Z-Image POC] Runtime params: steps={steps}, guidance={guidance_scale}, shift={shift}, "
|
||||
f"size={call_kwargs['width']}x{call_kwargs['height']}, max_seq={max_sequence_length}, offload={used_offload}, "
|
||||
f"dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}"
|
||||
f"batch={call_kwargs['num_images_per_prompt']}, dtype={getattr(pipeline.transformer, 'dtype', 'n/a')}, profile={profile}"
|
||||
)
|
||||
|
||||
output = None
|
||||
|
|
@ -4112,12 +4311,21 @@ def generate_zimage(
|
|||
output = _run_pipeline_call(pipeline, call_kwargs)
|
||||
if _zimage_black_image_retry_enabled() and not black_retry_used:
|
||||
try:
|
||||
candidate = output.images[0]
|
||||
candidates = list(getattr(output, "images", []) or [])
|
||||
except Exception:
|
||||
candidate = None
|
||||
if candidate is not None:
|
||||
is_black, black_info = _is_suspected_black_image(candidate)
|
||||
if is_black:
|
||||
candidates = []
|
||||
black_entries = []
|
||||
for idx, candidate in enumerate(candidates):
|
||||
try:
|
||||
is_black, black_info = _is_suspected_black_image(candidate)
|
||||
except Exception:
|
||||
is_black, black_info = False, None
|
||||
if is_black and black_info is not None:
|
||||
black_entries.append((idx, black_info))
|
||||
if black_entries:
|
||||
first_black_idx, first_black_info = black_entries[0]
|
||||
is_batch_black = len(candidates) > 1 and len(black_entries) == len(candidates)
|
||||
if len(candidates) == 1 or is_batch_black:
|
||||
black_retry_used = True
|
||||
strategy = str(getattr(pipeline, "_zimage_xformers_strategy", "unknown"))
|
||||
transformer = getattr(pipeline, "transformer", None)
|
||||
|
|
@ -4127,7 +4335,8 @@ def generate_zimage(
|
|||
if strict_fp16:
|
||||
print(
|
||||
f"[Z-Image POC] Suspected black output detected "
|
||||
f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, "
|
||||
f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
|
||||
f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
|
||||
f"attn={strategy}, dtype={transformer_dtype}). Strict FP16 mode enabled; no fallback."
|
||||
)
|
||||
raise RuntimeError(
|
||||
|
|
@ -4135,7 +4344,8 @@ def generate_zimage(
|
|||
)
|
||||
print(
|
||||
f"[Z-Image POC] Suspected black output detected "
|
||||
f"(mean={black_info['mean']:.2f}, max={black_info['max']:.0f}, std={black_info['std']:.2f}, "
|
||||
f"(index={first_black_idx}, mean={first_black_info['mean']:.2f}, "
|
||||
f"max={first_black_info['max']:.0f}, std={first_black_info['std']:.2f}, "
|
||||
f"attn={strategy}, dtype={transformer_dtype}). Retrying once with safer runtime."
|
||||
)
|
||||
|
||||
|
|
@ -4177,7 +4387,12 @@ def generate_zimage(
|
|||
if changed:
|
||||
if generator_device == "cuda":
|
||||
_cleanup_memory(cuda=True, aggressive=True)
|
||||
call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
if len(seed_list) <= 1:
|
||||
call_kwargs["generator"] = torch.Generator(device=generator_device).manual_seed(seed_list[0])
|
||||
else:
|
||||
call_kwargs["generator"] = [
|
||||
torch.Generator(device=generator_device).manual_seed(s) for s in seed_list
|
||||
]
|
||||
original_output = output
|
||||
try:
|
||||
output = _run_pipeline_call(pipeline, call_kwargs)
|
||||
|
|
@ -4187,9 +4402,15 @@ def generate_zimage(
|
|||
f"[Z-Image POC] Black-image retry failed ({retry_error}); keeping original output."
|
||||
)
|
||||
try:
|
||||
retry_image = output.images[0]
|
||||
retry_black, retry_info = _is_suspected_black_image(retry_image)
|
||||
if retry_black:
|
||||
retry_candidates = list(getattr(output, "images", []) or [])
|
||||
retry_black_any = False
|
||||
retry_info = None
|
||||
for retry_image in retry_candidates:
|
||||
retry_black, retry_info = _is_suspected_black_image(retry_image)
|
||||
if retry_black:
|
||||
retry_black_any = True
|
||||
break
|
||||
if retry_black_any and retry_info is not None:
|
||||
print(
|
||||
f"[Z-Image POC] Black-image retry remained near-black "
|
||||
f"(mean={retry_info['mean']:.2f}, max={retry_info['max']:.0f})."
|
||||
|
|
@ -4202,6 +4423,12 @@ def generate_zimage(
|
|||
pass
|
||||
else:
|
||||
print("[Z-Image POC] No safe retry knobs available; keeping original output.")
|
||||
elif black_entries:
|
||||
# For batches, retry only if every output is black. Mixed batches are preserved.
|
||||
print(
|
||||
f"[Z-Image POC] Batch output has {len(black_entries)}/{len(candidates)} near-black images; "
|
||||
"keeping batch output."
|
||||
)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
msg = str(e).lower()
|
||||
|
|
@ -4272,10 +4499,14 @@ def generate_zimage(
|
|||
stage_times["pipeline_call"] = time.perf_counter() - call_start
|
||||
|
||||
stage_start = time.perf_counter()
|
||||
image = output.images[0]
|
||||
images = list(getattr(output, "images", []) or [])
|
||||
del output
|
||||
stage_times["extract_image"] = time.perf_counter() - stage_start
|
||||
return image
|
||||
if return_images:
|
||||
return images
|
||||
if not images:
|
||||
raise RuntimeError("Z-Image pipeline returned no images.")
|
||||
return images[0]
|
||||
except Exception as e:
|
||||
error_name = type(e).__name__
|
||||
if "pipeline_call" not in stage_times:
|
||||
|
|
|
|||
Loading…
Reference in New Issue