From cfcfa4e4d28dcce91eb029c8a70d01042161403f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 25 Apr 2025 21:05:56 -0400 Subject: [PATCH] fix video vae selector Signed-off-by: Vladimir Mandic --- CHANGELOG.md | 2 ++ modules/loader.py | 4 ++++ modules/sd_vae_remote.py | 2 +- modules/sd_vae_taesd.py | 26 +++++++++++++------------- modules/video_models/video_vae.py | 4 ++-- wiki | 2 +- 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dc920895..9ed0ba225 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md) - I2V & FLF2V support with explicit strength controls - complex actions: modify prompts for each section of the video + - decode: use local, tiny or remote VAE - custom models: e.g. replace llama with one of your choice - video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation - compute: quantization support, new offloading, more configuration options, cross-platform, etc. @@ -94,6 +95,7 @@ - extension installer handling of PYTHONPATH - trace logging - api logging + - video vae selection load correct vae ## Update for 2025-04-12 diff --git a/modules/loader.py b/modules/loader.py index 4fa25a44c..2574c90b4 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -42,6 +42,8 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) timer.startup.record("torch") + + try: import bitsandbytes # pylint: disable=W0611,C0411 except Exception: @@ -73,9 +75,11 @@ timer.startup.record("pydantic") import diffusers.utils.import_utils # pylint: disable=W0611,C0411 diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access + import diffusers # pylint: disable=W0611,C0411 import diffusers.loaders.single_file # pylint: disable=W0611,C0411 import huggingface_hub # pylint: disable=W0611,C0411 + logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR) timer.startup.record("diffusers") diff --git a/modules/sd_vae_remote.py b/modules/sd_vae_remote.py index 9153d5ebb..c751bc083 100644 --- a/modules/sd_vae_remote.py +++ b/modules/sd_vae_remote.py @@ -49,7 +49,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_ latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype) latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents if model_type == 'hunyuanvideo': - latent_copy = latent_copy.unsqueeze(0) + latent_copy = latent_copy.unsqueeze(0) if len(latents.shape) == 4 else latents for i in range(latent_copy.shape[0]): params = {} diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index a35c85a83..da12422f1 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -36,6 +36,7 @@ prev_cls = '' prev_type = '' prev_model = '' lock = threading.Lock() +supported = ['sd', 'sdxl', 'f1', 'h1', 'hunyuanvideo', 'wanvideo', 'mochivideo'] def warn_once(msg, variant=None): @@ -56,38 +57,37 @@ def get_model(model_type = 'decoder', variant = None): cls = 'sd' if cls == 'h1': # hidream uses flux vae cls = 'f1' + if cls not in supported: + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant) variant = variant or shared.opts.taesd_variant folder = os.path.join(paths.models_path, "TAESD") os.makedirs(folder, exist_ok=True) - if 'video' in cls: - return None - if ('sd' not in cls) and ('f1' not in cls): - warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') - return None - if variant.startswith('TAESD'): + if variant.startswith('TAE'): cfg = TAESD_MODELS[variant] if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None): return cfg['model'] fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth') if not os.path.exists(fn): - uri = cfg['uri'] + '/tae' + cls + '_' + model_type + '.pth' + uri = cfg['uri'] + if not uri.endswith('.pth'): + uri += '/tae' + cls + '_' + model_type + '.pth' try: shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download') torch.hub.download_url_to_file(uri, fn) except Exception as e: - warn_once(f'download uri={uri} {e}') + warn_once(f'download uri={uri} {e}', variant=variant) if os.path.exists(fn): prev_cls = cls prev_type = model_type prev_model = variant shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load') - if 'TAEHV' in variant: + if 'TAE HunyuanVideo' in variant: from modules.taesd.taehv import TAEHV TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn) - if 'TAEW2' in variant: + elif 'TAE WanVideo' in variant: from modules.taesd.taehv import TAEHV TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn) - elif 'TAEM1' in variant: + elif 'TAE MochiVideo' in variant: from modules.taesd.taem1 import TAEM1 TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn) else: @@ -99,7 +99,7 @@ def get_model(model_type = 'decoder', variant = None): if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None): return cfg['model'] if cfg is None: - warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant) return None repo = cfg['repo'] prev_cls = cls @@ -117,7 +117,7 @@ def get_model(model_type = 'decoder', variant = None): CQYAN_MODELS[variant][cls]['model'] = vae return vae else: - warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') + warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant) return None diff --git a/modules/video_models/video_vae.py b/modules/video_models/video_vae.py index 0dab30c72..fa5957664 100644 --- a/modules/video_models/video_vae.py +++ b/modules/video_models/video_vae.py @@ -36,10 +36,10 @@ def vae_decode_tiny(latents): shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported') return None from modules import sd_vae_taesd - vae = sd_vae_taesd.get_model(variant) + vae = sd_vae_taesd.get_model(variant=variant) if vae is None: return None - debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}') + shared.log.debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}') vae = vae.to(device=devices.device, dtype=devices.dtype) latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype) images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1) diff --git a/wiki b/wiki index d57624e45..d9fcc40da 160000 --- a/wiki +++ b/wiki @@ -1 +1 @@ -Subproject commit d57624e45e2fa6f827e1e3134ee9a1a996476fd5 +Subproject commit d9fcc40daad2a7d6ff768865e829f14235da4078