From eaad4c7ed9727789aa87b933ccb56a4bea3e464a Mon Sep 17 00:00:00 2001 From: Developer Date: Fri, 20 Feb 2026 14:26:07 +0200 Subject: [PATCH] fixes --- modules/zimage_poc.py | 69 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/modules/zimage_poc.py b/modules/zimage_poc.py index 6cbf06e7..4d4aa63b 100644 --- a/modules/zimage_poc.py +++ b/modules/zimage_poc.py @@ -544,6 +544,7 @@ def _set_generation_random_source( pipeline, generator_device: str, use_alt_path: bool, + latents_device: Optional[str] = None, ): import torch @@ -554,7 +555,7 @@ def _set_generation_random_source( seed_list=seed_list, width=int(call_kwargs.get("width", 0)), height=int(call_kwargs.get("height", 0)), - generator_device=generator_device, + generator_device=(latents_device or generator_device), ) return None @@ -4935,12 +4936,17 @@ def _generate_zimage_impl( seed_list = parsed generator = None + alt_latents_device_candidates: list[str] = [] + alt_latents_device_index = 0 if _use_alt_path: _ensure_alt_path_prerequisites( pipeline=pipeline, width=width, height=height, ) + for candidate in (generator_device, "cpu"): + if candidate not in alt_latents_device_candidates: + alt_latents_device_candidates.append(candidate) else: if len(seed_list) <= 1: generator = torch.Generator(device=generator_device).manual_seed(seed_list[0]) @@ -5011,6 +5017,7 @@ def _generate_zimage_impl( pipeline=pipeline, generator_device=generator_device, use_alt_path=True, + latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None, ) else: call_kwargs["generator"] = generator @@ -5052,6 +5059,7 @@ def _generate_zimage_impl( pipeline=pipeline, generator_device=generator_device, use_alt_path=True, + latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None, ) _prepare_granular_pipeline_call( pipeline, @@ -5143,6 +5151,7 @@ def _generate_zimage_impl( pipeline=pipeline, generator_device=generator_device, use_alt_path=_use_alt_path, + latents_device=alt_latents_device_candidates[alt_latents_device_index] if _use_alt_path and alt_latents_device_candidates else None, ) original_output = output try: @@ -5188,6 +5197,64 @@ def _generate_zimage_impl( break except Exception as e: msg = str(e).lower() + device_mismatch = ( + "expected all tensors to be on the same device" in msg + and "cuda" in msg + and "cpu" in msg + ) + if _use_alt_path and device_mismatch and attempt < (max_attempts - 1): + if alt_latents_device_candidates and (alt_latents_device_index + 1) < len(alt_latents_device_candidates): + alt_latents_device_index += 1 + retry_latent_device = alt_latents_device_candidates[alt_latents_device_index] + print( + "[Z-Image POC] Alternate path device mismatch detected; " + f"retrying with latents on {retry_latent_device}." + ) + continue + + deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None) + if isinstance(deep_state, dict): + print( + "[Z-Image POC] Alternate path device mismatch persisted; " + "disabling deep patcher and retrying with non-deep offload." + ) + pipeline._zimage_deep_patcher_blocked = True + _disable_deep_patcher_offload(pipeline, target_device="cpu") + try: + fallback_free_gb, fallback_total_gb = _cuda_mem_info_gb() + fallback_pressure = (fallback_free_gb / fallback_total_gb) if fallback_total_gb > 0 else 0.0 + generator_device, used_offload = _apply_memory_mode( + pipeline=pipeline, + device="cuda", + target_mode="model_offload", + total_vram_gb=fallback_total_gb, + free_vram_gb=fallback_free_gb, + pressure=fallback_pressure, + profile=profile, + reason="alternate path device mismatch", + allow_relax=True, + ) + _PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload) + except Exception as fallback_error: + print( + "[Z-Image POC] Failed to switch from deep patcher after alternate-path device mismatch: " + f"{fallback_error}" + ) + alt_latents_device_candidates = [] + for candidate in (generator_device, "cpu"): + if candidate not in alt_latents_device_candidates: + alt_latents_device_candidates.append(candidate) + alt_latents_device_index = 0 + call_kwargs["prompt_embeds"] = [ + x.to(device=generator_device, dtype=pipeline.transformer.dtype) + for x in call_kwargs.get("prompt_embeds", []) + ] + call_kwargs["negative_prompt_embeds"] = [ + x.to(device=generator_device, dtype=pipeline.transformer.dtype) + for x in call_kwargs.get("negative_prompt_embeds", []) + ] + continue + deep_generator_mismatch = ( "cannot generate a cpu tensor from a generator of type cuda" in msg or ("generator of type cuda" in msg and "cpu tensor" in msg)