Allow using Flux2 TAE is a normal VAE
parent
f49398110a
commit
6e3c6087d4
|
|
@ -477,7 +477,10 @@ class VAE:
|
||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||||
elif "taesd_decoder.1.weight" in sd:
|
elif "taesd_decoder.1.weight" in sd:
|
||||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
|
||||||
|
self.latent_channels = metadata["tae_latent_channels"]
|
||||||
|
else:
|
||||||
|
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||||
self.first_stage_model = StageA()
|
self.first_stage_model = StageA()
|
||||||
|
|
|
||||||
53
nodes.py
53
nodes.py
|
|
@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader):
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def vae_list(s):
|
def vae_list(s):
|
||||||
vaes = folder_paths.get_filename_list("vae")
|
vaes = folder_paths.get_filename_list("vae")
|
||||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||||
sdxl_taesd_enc = False
|
have_img_encoder, have_img_decoder = set(), set()
|
||||||
sdxl_taesd_dec = False
|
|
||||||
sd1_taesd_enc = False
|
|
||||||
sd1_taesd_dec = False
|
|
||||||
sd3_taesd_enc = False
|
|
||||||
sd3_taesd_dec = False
|
|
||||||
f1_taesd_enc = False
|
|
||||||
f1_taesd_dec = False
|
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
parts = v.split("_", 1)
|
||||||
sd1_taesd_dec = True
|
if len(parts) != 2 or parts[0] not in s.image_taes:
|
||||||
elif v.startswith("taesd_encoder."):
|
|
||||||
sd1_taesd_enc = True
|
|
||||||
elif v.startswith("taesdxl_decoder."):
|
|
||||||
sdxl_taesd_dec = True
|
|
||||||
elif v.startswith("taesdxl_encoder."):
|
|
||||||
sdxl_taesd_enc = True
|
|
||||||
elif v.startswith("taesd3_decoder."):
|
|
||||||
sd3_taesd_dec = True
|
|
||||||
elif v.startswith("taesd3_encoder."):
|
|
||||||
sd3_taesd_enc = True
|
|
||||||
elif v.startswith("taef1_encoder."):
|
|
||||||
f1_taesd_dec = True
|
|
||||||
elif v.startswith("taef1_decoder."):
|
|
||||||
f1_taesd_enc = True
|
|
||||||
else:
|
|
||||||
for tae in s.video_taes:
|
for tae in s.video_taes:
|
||||||
if v.startswith(tae):
|
if v.startswith(tae):
|
||||||
vaes.append(v)
|
vaes.append(v)
|
||||||
|
break
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
continue
|
||||||
vaes.append("taesd")
|
if parts[1].startswith("encoder."):
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
have_img_encoder.add(parts[0])
|
||||||
vaes.append("taesdxl")
|
elif parts[1].startswith("decoder."):
|
||||||
if sd3_taesd_dec and sd3_taesd_enc:
|
have_img_decoder.add(parts[0])
|
||||||
vaes.append("taesd3")
|
vaes += [k for k in have_img_decoder if k in have_img_encoder]
|
||||||
if f1_taesd_dec and f1_taesd_enc:
|
|
||||||
vaes.append("taef1")
|
|
||||||
vaes.append("pixel_space")
|
vaes.append("pixel_space")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
|
|
@ -827,6 +803,11 @@ class VAELoader:
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
|
if vae_name == "taef2":
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {"tae_latent_channels": 128}
|
||||||
|
else:
|
||||||
|
metadata["tae_latent_channels"] = 128
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue