fix lora unload and improve preview error handler

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4631/head
Vladimir Mandic 2026-02-07 09:07:34 +00:00
parent 5e2ab3057f
commit e8ff09a2d2
5 changed files with 15 additions and 13 deletions

View File

@ -1,6 +1,6 @@
# Change Log for SD.Next # Change Log for SD.Next
## Update for 2026-02-06 ## Update for 2026-02-07
- **Upscalers** - **Upscalers**
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel) - add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
@ -19,6 +19,7 @@
- fix: improve wildcard weights parsing, thanks @Tillerz - fix: improve wildcard weights parsing, thanks @Tillerz
- fix: ui gallery cace recursive cleanup, thanks @awsr - fix: ui gallery cace recursive cleanup, thanks @awsr
- fix: `anima` model detection - fix: `anima` model detection
- fix: lora unwanted unload
## Update for 2026-02-04 ## Update for 2026-02-04

@ -1 +1 @@
Subproject commit ead16e14410ff177e2e4e105bcbec3eaa737de7d Subproject commit 2b9f7c293f3a2146b991a943397e7567059e73c8

View File

@ -175,6 +175,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool: def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool:
if shared.opts.lora_force_reload: if shared.opts.lora_force_reload:
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=forced')
return True return True
sd_model = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model sd_model = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model
if not hasattr(sd_model, 'loaded_loras'): if not hasattr(sd_model, 'loaded_loras'):
@ -185,14 +186,16 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
exclude = ['none'] exclude = ['none']
key = f'include={",".join(include)}:exclude={",".join(exclude)}' key = f'include={",".join(include)}:exclude={",".join(exclude)}'
loaded = sd_model.loaded_loras.get(key, []) loaded = sd_model.loaded_loras.get(key, [])
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded}')
if len(requested) != len(loaded): if len(requested) != len(loaded):
sd_model.loaded_loras[key] = requested sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
return True return True
for req, load in zip(requested, loaded): for req, load in zip(requested, loaded):
if req != load: if req != load:
sd_model.loaded_loras[key] = requested sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
return True return True
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=same')
return False return False
def activate(self, p, params_list, step=0, include=[], exclude=[]): def activate(self, p, params_list, step=0, include=[], exclude=[]):
@ -245,9 +248,8 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
shared.log.info(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} method={load_method} mode={"fuse" if shared.opts.lora_fuse_native else "backup"} te={te_multipliers} unet={unet_multipliers} time={l.timer.summary}') shared.log.info(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} method={load_method} mode={"fuse" if shared.opts.lora_fuse_native else "backup"} te={te_multipliers} unet={unet_multipliers} time={l.timer.summary}')
def deactivate(self, p, force=False): def deactivate(self, p, force=False):
if len(lora_diffusers.diffuser_loaded) > 0: if len(lora_diffusers.diffuser_loaded) > 0 and (shared.opts.lora_force_reload or force):
if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True): unload_diffusers()
unload_diffusers()
if force: if force:
networks.network_deactivate() networks.network_deactivate()
if self.active and l.debug: if self.active and l.debug:

View File

@ -2,7 +2,7 @@ import time
import threading import threading
from collections import namedtuple from collections import namedtuple
import torch import torch
import torchvision.transforms as T import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from modules import shared, devices, processing, images, sd_samplers, timer from modules import shared, devices, processing, images, sd_samplers, timer
from modules.vae import sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade from modules.vae import sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade
@ -80,12 +80,11 @@ def single_sample_to_image(sample, approximation=None):
else: else:
if x_sample.shape[0] > 4 or x_sample.shape[0] == 4: if x_sample.shape[0] > 4 or x_sample.shape[0] == 4:
return Image.new(mode="RGB", size=(512, 512)) return Image.new(mode="RGB", size=(512, 512))
if x_sample.dtype == torch.bfloat16: x_sample = torch.nan_to_num(x_sample, nan=0.0, posinf=1, neginf=0)
x_sample = x_sample.to(torch.float16) x_sample = (255.0 * x_sample).to(torch.uint8)
if len(x_sample.shape) == 4: if len(x_sample.shape) == 4:
x_sample = x_sample[0] x_sample = x_sample[0]
transform = T.ToPILImage() image = TF.to_pil_image(x_sample)
image = transform(x_sample)
except Exception as e: except Exception as e:
warn_once(f'Preview: {e}') warn_once(f'Preview: {e}')
image = Image.new(mode="RGB", size=(512, 512)) image = Image.new(mode="RGB", size=(512, 512))

View File

@ -158,13 +158,13 @@ def decode(latents):
try: try:
with devices.inference_context(): with devices.inference_context():
t0 = time.time() t0 = time.time()
dtype = devices.dtype_vae if devices.dtype_vae != torch.bfloat16 else torch.float16 # taesd does not support bf16 dtype = devices.dtype_vae if (devices.dtype_vae != torch.bfloat16) else torch.float16 # taesd does not support bf16
tensor = latents.unsqueeze(0) if len(latents.shape) == 3 else latents tensor = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
tensor = tensor.detach().clone().to(devices.device, dtype=dtype) tensor = tensor.detach().clone().to(devices.device, dtype=dtype)
if debug: if debug:
shared.log.debug(f'Decode: type="taesd" variant="{variant}" input={latents.shape} tensor={tensor.shape}') shared.log.debug(f'Decode: type="taesd" variant="{variant}" input={latents.shape} tensor={tensor.shape}')
# Fallback: reshape packed 128-channel latents to 32 channels if not already unpacked # Fallback: reshape packed 128-channel latents to 32 channels if not already unpacked
if variant == 'TAE FLUX.2' and len(tensor.shape) == 4 and tensor.shape[1] == 128: if (variant == 'TAE FLUX.2') and (len(tensor.shape) == 4) and (tensor.shape[1] == 128):
b, _c, h, w = tensor.shape b, _c, h, w = tensor.shape
tensor = tensor.reshape(b, 32, h * 2, w * 2) tensor = tensor.reshape(b, 32, h * 2, w * 2)
if variant.startswith('TAESD') or variant in {'TAE FLUX.1', 'TAE FLUX.2', 'TAE SD3'}: if variant.startswith('TAESD') or variant in {'TAE FLUX.1', 'TAE FLUX.2', 'TAE SD3'}: