pull/4166/head
Developer 2026-02-20 14:26:07 +02:00
parent aa869717a7
commit eaad4c7ed9
1 changed file with 68 additions and 1 deletion

View File

@@ -544,6 +544,7 @@ def _set_generation_random_source(
pipeline,
generator_device: str,
use_alt_path: bool,
latents_device: Optional[str] = None,
):
import torch
@@ -554,7 +555,7 @@ def _set_generation_random_source(
seed_list=seed_list,
width=int(call_kwargs.get("width", 0)),
height=int(call_kwargs.get("height", 0)),
generator_device=generator_device,
generator_device=(latents_device or generator_device),
)
return None
@@ -4935,12 +4936,17 @@ def _generate_zimage_impl(
seed_list = parsed
generator = None
alt_latents_device_candidates: list[str] = []
alt_latents_device_index = 0
if _use_alt_path:
_ensure_alt_path_prerequisites(
pipeline=pipeline,
width=width,
height=height,
)
for candidate in (generator_device, "cpu"):
if candidate not in alt_latents_device_candidates:
alt_latents_device_candidates.append(candidate)
else:
if len(seed_list) <= 1:
generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
@@ -5011,6 +5017,7 @@ def _generate_zimage_impl(
pipeline=pipeline,
generator_device=generator_device,
use_alt_path=True,
latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
)
else:
call_kwargs["generator"] = generator
@@ -5052,6 +5059,7 @@ def _generate_zimage_impl(
pipeline=pipeline,
generator_device=generator_device,
use_alt_path=True,
latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
)
_prepare_granular_pipeline_call(
pipeline,
@@ -5143,6 +5151,7 @@ def _generate_zimage_impl(
pipeline=pipeline,
generator_device=generator_device,
use_alt_path=_use_alt_path,
latents_device=alt_latents_device_candidates[alt_latents_device_index] if _use_alt_path and alt_latents_device_candidates else None,
)
original_output = output
try:
@@ -5188,6 +5197,64 @@ def _generate_zimage_impl(
break
except Exception as e:
msg = str(e).lower()
device_mismatch = (
"expected all tensors to be on the same device" in msg
and "cuda" in msg
and "cpu" in msg
)
if _use_alt_path and device_mismatch and attempt < (max_attempts - 1):
if alt_latents_device_candidates and (alt_latents_device_index + 1) < len(alt_latents_device_candidates):
alt_latents_device_index += 1
retry_latent_device = alt_latents_device_candidates[alt_latents_device_index]
print(
"[Z-Image POC] Alternate path device mismatch detected; "
f"retrying with latents on {retry_latent_device}."
)
continue
deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
if isinstance(deep_state, dict):
print(
"[Z-Image POC] Alternate path device mismatch persisted; "
"disabling deep patcher and retrying with non-deep offload."
)
pipeline._zimage_deep_patcher_blocked = True
_disable_deep_patcher_offload(pipeline, target_device="cpu")
try:
fallback_free_gb, fallback_total_gb = _cuda_mem_info_gb()
fallback_pressure = (fallback_free_gb / fallback_total_gb) if fallback_total_gb > 0 else 0.0
generator_device, used_offload = _apply_memory_mode(
pipeline=pipeline,
device="cuda",
target_mode="model_offload",
total_vram_gb=fallback_total_gb,
free_vram_gb=fallback_free_gb,
pressure=fallback_pressure,
profile=profile,
reason="alternate path device mismatch",
allow_relax=True,
)
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
except Exception as fallback_error:
print(
"[Z-Image POC] Failed to switch from deep patcher after alternate-path device mismatch: "
f"{fallback_error}"
)
alt_latents_device_candidates = []
for candidate in (generator_device, "cpu"):
if candidate not in alt_latents_device_candidates:
alt_latents_device_candidates.append(candidate)
alt_latents_device_index = 0
call_kwargs["prompt_embeds"] = [
x.to(device=generator_device, dtype=pipeline.transformer.dtype)
for x in call_kwargs.get("prompt_embeds", [])
]
call_kwargs["negative_prompt_embeds"] = [
x.to(device=generator_device, dtype=pipeline.transformer.dtype)
for x in call_kwargs.get("negative_prompt_embeds", [])
]
continue
deep_generator_mismatch = (
"cannot generate a cpu tensor from a generator of type cuda" in msg
or ("generator of type cuda" in msg and "cpu tensor" in msg)