Allow using Flux2 TAE is a normal VAE

pull/13496/head
blepping 2026-04-22 22:06:10 -06:00
parent f49398110a
commit 6e3c6087d4
2 changed files with 21 additions and 37 deletions

View File

@ -477,7 +477,10 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd: elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
self.latent_channels = metadata["tae_latent_channels"]
else:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA() self.first_stage_model = StageA()

View File

@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader):
class VAELoader: class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
@staticmethod @staticmethod
def vae_list(s): def vae_list(s):
vaes = folder_paths.get_filename_list("vae") vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx") approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False have_img_encoder, have_img_decoder = set(), set()
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
sd3_taesd_enc = False
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes: for v in approx_vaes:
if v.startswith("taesd_decoder."): parts = v.split("_", 1)
sd1_taesd_dec = True if len(parts) != 2 or parts[0] not in s.image_taes:
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
elif v.startswith("taesd3_decoder."):
sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
else:
for tae in s.video_taes: for tae in s.video_taes:
if v.startswith(tae): if v.startswith(tae):
vaes.append(v) vaes.append(v)
break
if sd1_taesd_dec and sd1_taesd_enc: continue
vaes.append("taesd") if parts[1].startswith("encoder."):
if sdxl_taesd_dec and sdxl_taesd_enc: have_img_encoder.add(parts[0])
vaes.append("taesdxl") elif parts[1].startswith("decoder."):
if sd3_taesd_dec and sd3_taesd_enc: have_img_decoder.add(parts[0])
vaes.append("taesd3") vaes += [k for k in have_img_decoder if k in have_img_encoder]
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
vaes.append("pixel_space") vaes.append("pixel_space")
return vaes return vaes
@ -827,6 +803,11 @@ class VAELoader:
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
if vae_name == "taef2":
if metadata is None:
metadata = {"tae_latent_channels": 128}
else:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
return (vae,) return (vae,)