mirror of https://github.com/vladmandic/automatic
feat(taesd): add FLUX.2 preview support
Enable live preview during FLUX.2 and FLUX.2 Klein image generation using the TAE FLUX.2 decoder from madebyollin/taesd. - Add dedicated TAE entries (FLUX.1, FLUX.2, SD3) that auto-select based on model type, making the dropdown only affect SD/SDXL models - Add FLUX.2 latent unpacking in callback to convert packed [B, seq_len, 128] format to spatial [B, 32, H, W] for preview - Support FLUX.2's 32 latent channels (vs 16 for FLUX.1/SD3) — pull/4553/head
parent
fe99d3fe5d
commit
c605a1bb62
|
|
@ -116,7 +116,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
|
|||
if current_noise_pred is None:
|
||||
current_noise_pred = kwargs.get("predicted_image_embedding", None)
|
||||
|
||||
if hasattr(pipe, "_unpack_latents") and hasattr(pipe, "vae_scale_factor"): # FLUX
|
||||
if hasattr(pipe, "_unpack_latents") and hasattr(pipe, "vae_scale_factor"): # FLUX.1
|
||||
if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5) and p.is_hr_pass:
|
||||
width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
|
||||
height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
|
||||
|
|
@ -128,6 +128,37 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
|
|||
shared.state.current_noise_pred = pipe._unpack_latents(current_noise_pred, height, width, pipe.vae_scale_factor) # pylint: disable=protected-access
|
||||
else:
|
||||
shared.state.current_noise_pred = current_noise_pred
|
||||
elif hasattr(pipe, "_unpatchify_latents"): # FLUX.2 - unpack [B, seq, patch_ch] to [B, ch, H, W]
|
||||
# Get dimensions for unpacking, same logic as FLUX.1
|
||||
vae_scale = getattr(pipe, 'vae_scale_factor', 8)
|
||||
if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5) and p.is_hr_pass:
|
||||
width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
|
||||
height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
|
||||
else:
|
||||
width = getattr(p, 'width', 1024)
|
||||
height = getattr(p, 'height', 1024)
|
||||
latents = kwargs['latents']
|
||||
if len(latents.shape) == 3: # packed format [B, seq_len, patch_channels]
|
||||
b, seq_len, patch_ch = latents.shape
|
||||
channels = patch_ch // 4 # 4 = 2x2 patch
|
||||
h_patches = height // vae_scale // 2
|
||||
w_patches = width // vae_scale // 2
|
||||
if h_patches * w_patches != seq_len: # fallback to square assumption
|
||||
h_patches = w_patches = int(seq_len ** 0.5)
|
||||
# [B, h*w, C*4] -> [B, h, w, C, 2, 2] -> [B, C, h, 2, w, 2] -> [B, C, H, W]
|
||||
latents = latents.view(b, h_patches, w_patches, channels, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5).reshape(b, channels, h_patches * 2, w_patches * 2)
|
||||
shared.state.current_latent = latents
|
||||
if current_noise_pred is not None and len(current_noise_pred.shape) == 3:
|
||||
b, seq_len, patch_ch = current_noise_pred.shape
|
||||
channels = patch_ch // 4
|
||||
h_patches = height // vae_scale // 2
|
||||
w_patches = width // vae_scale // 2
|
||||
if h_patches * w_patches != seq_len:
|
||||
h_patches = w_patches = int(seq_len ** 0.5)
|
||||
current_noise_pred = current_noise_pred.view(b, h_patches, w_patches, channels, 2, 2)
|
||||
current_noise_pred = current_noise_pred.permute(0, 3, 1, 4, 2, 5).reshape(b, channels, h_patches * 2, w_patches * 2)
|
||||
shared.state.current_noise_pred = current_noise_pred
|
||||
else:
|
||||
shared.state.current_latent = kwargs['latents']
|
||||
shared.state.current_noise_pred = current_noise_pred
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ TAESD_MODELS = {
|
|||
'TAESD 1.2 Chocolate-Dipped Shortbread': { 'fn': 'taesd_12_', 'uri': 'https://github.com/madebyollin/taesd/raw/8909b44e3befaa0efa79c5791e4fe1c4d4f7884e', 'model': None },
|
||||
'TAESD 1.1 Fruit Loops': { 'fn': 'taesd_11_', 'uri': 'https://github.com/madebyollin/taesd/raw/3e8a8a2ab4ad4079db60c1c7dc1379b4cc0c6b31', 'model': None },
|
||||
'TAESD 1.0': { 'fn': 'taesd_10_', 'uri': 'https://github.com/madebyollin/taesd/raw/88012e67cf0454e6d90f98911fe9d4aef62add86', 'model': None },
|
||||
'TAE FLUX.1': { 'fn': 'taef1.pth', 'uri': 'https://github.com/madebyollin/taesd/raw/main/taef1_decoder.pth', 'model': None },
|
||||
'TAE FLUX.2': { 'fn': 'taef2.pth', 'uri': 'https://github.com/madebyollin/taesd/raw/main/taef2_decoder.pth', 'model': None },
|
||||
'TAE SD3': { 'fn': 'taesd3.pth', 'uri': 'https://github.com/madebyollin/taesd/raw/main/taesd3_decoder.pth', 'model': None },
|
||||
'TAE HunyuanVideo': { 'fn': 'taehv.pth', 'uri': 'https://github.com/madebyollin/taehv/raw/refs/heads/main/taehv.pth', 'model': None },
|
||||
'TAE WanVideo': { 'fn': 'taew1.pth', 'uri': 'https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth', 'model': None },
|
||||
'TAE MochiVideo': { 'fn': 'taem1.pth', 'uri': 'https://github.com/madebyollin/taem1/raw/refs/heads/main/taem1.pth', 'model': None },
|
||||
|
|
@ -38,7 +41,7 @@ prev_cls = ''
|
|||
prev_type = ''
|
||||
prev_model = ''
|
||||
lock = threading.Lock()
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'zimage', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'cosmos', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat', 'omnigen2', 'flite', 'ovis', 'kandinsky5', 'glmimage', 'cogview3', 'cogview4']
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'f2', 'h1', 'zimage', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'cosmos', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat', 'omnigen2', 'flite', 'ovis', 'kandinsky5', 'glmimage', 'cogview3', 'cogview4']
|
||||
|
||||
|
||||
def warn_once(msg, variant=None):
|
||||
|
|
@ -59,8 +62,14 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
model_cls = 'sd'
|
||||
elif model_cls in {'pixartsigma', 'hunyuandit', 'omnigen', 'auraflow'}:
|
||||
model_cls = 'sdxl'
|
||||
elif model_cls in {'h1', 'zimage', 'lumina2', 'chroma', 'longcat', 'omnigen2', 'flite', 'ovis', 'kandinsky5', 'glmimage', 'cogview3', 'cogview4'}:
|
||||
elif model_cls in {'f1', 'h1', 'zimage', 'lumina2', 'chroma', 'longcat', 'omnigen2', 'flite', 'ovis', 'kandinsky5', 'glmimage', 'cogview3', 'cogview4'}:
|
||||
model_cls = 'f1'
|
||||
variant = 'TAE FLUX.1'
|
||||
elif model_cls == 'f2':
|
||||
model_cls = 'f2'
|
||||
variant = 'TAE FLUX.2'
|
||||
elif model_cls == 'sd3':
|
||||
variant = 'TAE SD3'
|
||||
elif model_cls in {'wanai', 'qwen', 'chrono', 'cosmos'}:
|
||||
variant = variant or 'TAE WanVideo'
|
||||
elif model_cls not in supported:
|
||||
|
|
@ -149,7 +158,12 @@ def decode(latents):
|
|||
dtype = devices.dtype_vae if devices.dtype_vae != torch.bfloat16 else torch.float16 # taesd does not support bf16
|
||||
tensor = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
|
||||
tensor = tensor.detach().clone().to(devices.device, dtype=dtype)
|
||||
if variant.startswith('TAESD'):
|
||||
shared.log.debug(f'Decode: type="taesd" variant="{variant}" input={latents.shape} tensor={tensor.shape}')
|
||||
# FLUX.2 has 128 latent channels that need reshaping to 32 channels for TAESD
|
||||
if variant == 'TAE FLUX.2' and len(tensor.shape) == 4 and tensor.shape[1] == 128:
|
||||
b, c, h, w = tensor.shape
|
||||
tensor = tensor.reshape(b, 32, h * 2, w * 2)
|
||||
if variant.startswith('TAESD') or variant in {'TAE FLUX.1', 'TAE FLUX.2', 'TAE SD3'}:
|
||||
image = vae.decoder(tensor).clamp(0, 1).detach()
|
||||
image = image[0]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -77,7 +77,11 @@ class TAESD(nn.Module): # pylint: disable=abstract-method
|
|||
self.decoder = self.decoder.to(devices.device, dtype=self.dtype)
|
||||
|
||||
def guess_latent_channels(self, decoder_path, encoder_path):
    """Infer the number of latent channels from the TAE checkpoint filenames.

    Args:
        decoder_path: filename/path of the decoder weights (e.g. 'taef2_decoder.pth').
        encoder_path: filename/path of the encoder weights.

    Returns:
        int: 32 for FLUX.2 ('f2'), 16 for FLUX.1 ('f1') or SD3 ('sd3'),
        otherwise 4 (classic SD/SDXL latent space).
    """
    # NOTE: the stale single-line `return 16 if ... else 4` from before the
    # FLUX.2 change was removed here — it preceded the new branches and made
    # the `return 32` path unreachable.
    if "f2" in encoder_path or "f2" in decoder_path:
        return 32  # FLUX.2 uses 32 latent channels
    if ("f1" in encoder_path or "f1" in decoder_path) or ("sd3" in encoder_path or "sd3" in decoder_path):
        return 16  # FLUX.1 and SD3 use 16 latent channels
    return 4  # SD 1.x / SDXL use 4 latent channels
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
|
|
|||
Loading…
Reference in New Issue