From 6e3c6087d4f4e24d8eedd8d5c6e17c9e88eb91ff Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 22 Apr 2026 22:06:10 -0600 Subject: [PATCH] Allow using Flux2 TAE as a normal VAE --- comfy/sd.py | 5 ++++- nodes.py | 53 +++++++++++++++++------------------------------------ 2 files changed, 21 insertions(+), 37 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..6a4eeb001 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -477,7 +477,10 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: - self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] + if isinstance(metadata, dict) and "tae_latent_channels" in metadata: + self.latent_channels = metadata["tae_latent_channels"] + else: + self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade self.first_stage_model = StageA() diff --git a/nodes.py b/nodes.py index 299b3d758..8817e78e2 100644 --- a/nodes.py +++ b/nodes.py @@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader): class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] - image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"] + @staticmethod def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") - sdxl_taesd_enc = False - sdxl_taesd_dec = False - sd1_taesd_enc = False - sd1_taesd_dec = False - sd3_taesd_enc = False - sd3_taesd_dec = False - f1_taesd_enc = False - f1_taesd_dec = False - + have_img_encoder, have_img_decoder = set(), set() for v in approx_vaes: - if v.startswith("taesd_decoder."): - sd1_taesd_dec 
= True - elif v.startswith("taesd_encoder."): - sd1_taesd_enc = True - elif v.startswith("taesdxl_decoder."): - sdxl_taesd_dec = True - elif v.startswith("taesdxl_encoder."): - sdxl_taesd_enc = True - elif v.startswith("taesd3_decoder."): - sd3_taesd_dec = True - elif v.startswith("taesd3_encoder."): - sd3_taesd_enc = True - elif v.startswith("taef1_encoder."): - f1_taesd_dec = True - elif v.startswith("taef1_decoder."): - f1_taesd_enc = True - else: + parts = v.split("_", 1) + if len(parts) != 2 or parts[0] not in s.image_taes: for tae in s.video_taes: if v.startswith(tae): vaes.append(v) - - if sd1_taesd_dec and sd1_taesd_enc: - vaes.append("taesd") - if sdxl_taesd_dec and sdxl_taesd_enc: - vaes.append("taesdxl") - if sd3_taesd_dec and sd3_taesd_enc: - vaes.append("taesd3") - if f1_taesd_dec and f1_taesd_enc: - vaes.append("taef1") + break + continue + if parts[1].startswith("encoder."): + have_img_encoder.add(parts[0]) + elif parts[1].startswith("decoder."): + have_img_decoder.add(parts[0]) + vaes += [k for k in have_img_decoder if k in have_img_encoder] vaes.append("pixel_space") return vaes @@ -827,6 +803,11 @@ class VAELoader: else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + if vae_name == "taef2": + if metadata is None: + metadata = {"tae_latent_channels": 128} + else: + metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,)