mirror of https://github.com/lllyasviel/Fooocus
fixes
parent
aa869717a7
commit
eaad4c7ed9
|
|
@ -544,6 +544,7 @@ def _set_generation_random_source(
|
|||
pipeline,
|
||||
generator_device: str,
|
||||
use_alt_path: bool,
|
||||
latents_device: Optional[str] = None,
|
||||
):
|
||||
import torch
|
||||
|
||||
|
|
@ -554,7 +555,7 @@ def _set_generation_random_source(
|
|||
seed_list=seed_list,
|
||||
width=int(call_kwargs.get("width", 0)),
|
||||
height=int(call_kwargs.get("height", 0)),
|
||||
generator_device=generator_device,
|
||||
generator_device=(latents_device or generator_device),
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -4935,12 +4936,17 @@ def _generate_zimage_impl(
|
|||
seed_list = parsed
|
||||
|
||||
generator = None
|
||||
alt_latents_device_candidates: list[str] = []
|
||||
alt_latents_device_index = 0
|
||||
if _use_alt_path:
|
||||
_ensure_alt_path_prerequisites(
|
||||
pipeline=pipeline,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
for candidate in (generator_device, "cpu"):
|
||||
if candidate not in alt_latents_device_candidates:
|
||||
alt_latents_device_candidates.append(candidate)
|
||||
else:
|
||||
if len(seed_list) <= 1:
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed_list[0])
|
||||
|
|
@ -5011,6 +5017,7 @@ def _generate_zimage_impl(
|
|||
pipeline=pipeline,
|
||||
generator_device=generator_device,
|
||||
use_alt_path=True,
|
||||
latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
|
||||
)
|
||||
else:
|
||||
call_kwargs["generator"] = generator
|
||||
|
|
@ -5052,6 +5059,7 @@ def _generate_zimage_impl(
|
|||
pipeline=pipeline,
|
||||
generator_device=generator_device,
|
||||
use_alt_path=True,
|
||||
latents_device=alt_latents_device_candidates[alt_latents_device_index] if alt_latents_device_candidates else None,
|
||||
)
|
||||
_prepare_granular_pipeline_call(
|
||||
pipeline,
|
||||
|
|
@ -5143,6 +5151,7 @@ def _generate_zimage_impl(
|
|||
pipeline=pipeline,
|
||||
generator_device=generator_device,
|
||||
use_alt_path=_use_alt_path,
|
||||
latents_device=alt_latents_device_candidates[alt_latents_device_index] if _use_alt_path and alt_latents_device_candidates else None,
|
||||
)
|
||||
original_output = output
|
||||
try:
|
||||
|
|
@ -5188,6 +5197,64 @@ def _generate_zimage_impl(
|
|||
break
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
device_mismatch = (
|
||||
"expected all tensors to be on the same device" in msg
|
||||
and "cuda" in msg
|
||||
and "cpu" in msg
|
||||
)
|
||||
if _use_alt_path and device_mismatch and attempt < (max_attempts - 1):
|
||||
if alt_latents_device_candidates and (alt_latents_device_index + 1) < len(alt_latents_device_candidates):
|
||||
alt_latents_device_index += 1
|
||||
retry_latent_device = alt_latents_device_candidates[alt_latents_device_index]
|
||||
print(
|
||||
"[Z-Image POC] Alternate path device mismatch detected; "
|
||||
f"retrying with latents on {retry_latent_device}."
|
||||
)
|
||||
continue
|
||||
|
||||
deep_state = getattr(pipeline, "_zimage_deep_patcher_state", None)
|
||||
if isinstance(deep_state, dict):
|
||||
print(
|
||||
"[Z-Image POC] Alternate path device mismatch persisted; "
|
||||
"disabling deep patcher and retrying with non-deep offload."
|
||||
)
|
||||
pipeline._zimage_deep_patcher_blocked = True
|
||||
_disable_deep_patcher_offload(pipeline, target_device="cpu")
|
||||
try:
|
||||
fallback_free_gb, fallback_total_gb = _cuda_mem_info_gb()
|
||||
fallback_pressure = (fallback_free_gb / fallback_total_gb) if fallback_total_gb > 0 else 0.0
|
||||
generator_device, used_offload = _apply_memory_mode(
|
||||
pipeline=pipeline,
|
||||
device="cuda",
|
||||
target_mode="model_offload",
|
||||
total_vram_gb=fallback_total_gb,
|
||||
free_vram_gb=fallback_free_gb,
|
||||
pressure=fallback_pressure,
|
||||
profile=profile,
|
||||
reason="alternate path device mismatch",
|
||||
allow_relax=True,
|
||||
)
|
||||
_PIPELINE_CACHE[cache_key] = (pipeline, generator_device, used_offload)
|
||||
except Exception as fallback_error:
|
||||
print(
|
||||
"[Z-Image POC] Failed to switch from deep patcher after alternate-path device mismatch: "
|
||||
f"{fallback_error}"
|
||||
)
|
||||
alt_latents_device_candidates = []
|
||||
for candidate in (generator_device, "cpu"):
|
||||
if candidate not in alt_latents_device_candidates:
|
||||
alt_latents_device_candidates.append(candidate)
|
||||
alt_latents_device_index = 0
|
||||
call_kwargs["prompt_embeds"] = [
|
||||
x.to(device=generator_device, dtype=pipeline.transformer.dtype)
|
||||
for x in call_kwargs.get("prompt_embeds", [])
|
||||
]
|
||||
call_kwargs["negative_prompt_embeds"] = [
|
||||
x.to(device=generator_device, dtype=pipeline.transformer.dtype)
|
||||
for x in call_kwargs.get("negative_prompt_embeds", [])
|
||||
]
|
||||
continue
|
||||
|
||||
deep_generator_mismatch = (
|
||||
"cannot generate a cpu tensor from a generator of type cuda" in msg
|
||||
or ("generator of type cuda" in msg and "cpu tensor" in msg)
|
||||
|
|
|
|||
Loading…
Reference in New Issue