fix video vae selector

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3892/head
Vladimir Mandic 2025-04-25 21:05:56 -04:00
parent 8c37607b82
commit cfcfa4e4d2
6 changed files with 23 additions and 17 deletions

View File

@@ -14,6 +14,7 @@
install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md) install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md)
- I2V & FLF2V support with explicit strength controls - I2V & FLF2V support with explicit strength controls
- complex actions: modify prompts for each section of the video - complex actions: modify prompts for each section of the video
- decode: use local, tiny or remote VAE
- custom models: e.g. replace llama with one of your choice - custom models: e.g. replace llama with one of your choice
- video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation - video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation
- compute: quantization support, new offloading, more configuration options, cross-platform, etc. - compute: quantization support, new offloading, more configuration options, cross-platform, etc.
@@ -94,6 +95,7 @@
- extension installer handling of PYTHONPATH - extension installer handling of PYTHONPATH
- trace logging - trace logging
- api logging - api logging
- video vae selection: load correct vae
## Update for 2025-04-12 ## Update for 2025-04-12

View File

@@ -42,6 +42,8 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__ torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
timer.startup.record("torch") timer.startup.record("torch")
try: try:
import bitsandbytes # pylint: disable=W0611,C0411 import bitsandbytes # pylint: disable=W0611,C0411
except Exception: except Exception:
@@ -73,9 +75,11 @@ timer.startup.record("pydantic")
import diffusers.utils.import_utils # pylint: disable=W0611,C0411 import diffusers.utils.import_utils # pylint: disable=W0611,C0411
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
import diffusers # pylint: disable=W0611,C0411 import diffusers # pylint: disable=W0611,C0411
import diffusers.loaders.single_file # pylint: disable=W0611,C0411 import diffusers.loaders.single_file # pylint: disable=W0611,C0411
import huggingface_hub # pylint: disable=W0611,C0411 import huggingface_hub # pylint: disable=W0611,C0411
logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR) logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR)
timer.startup.record("diffusers") timer.startup.record("diffusers")

View File

@@ -49,7 +49,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype) latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype)
latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
if model_type == 'hunyuanvideo': if model_type == 'hunyuanvideo':
latent_copy = latent_copy.unsqueeze(0) latent_copy = latent_copy.unsqueeze(0) if len(latents.shape) == 4 else latents
for i in range(latent_copy.shape[0]): for i in range(latent_copy.shape[0]):
params = {} params = {}

View File

@@ -36,6 +36,7 @@ prev_cls = ''
prev_type = '' prev_type = ''
prev_model = '' prev_model = ''
lock = threading.Lock() lock = threading.Lock()
supported = ['sd', 'sdxl', 'f1', 'h1', 'hunyuanvideo', 'wanvideo', 'mochivideo']
def warn_once(msg, variant=None): def warn_once(msg, variant=None):
@@ -56,38 +57,37 @@ def get_model(model_type = 'decoder', variant = None):
cls = 'sd' cls = 'sd'
if cls == 'h1': # hidream uses flux vae if cls == 'h1': # hidream uses flux vae
cls = 'f1' cls = 'f1'
if cls not in supported:
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
variant = variant or shared.opts.taesd_variant variant = variant or shared.opts.taesd_variant
folder = os.path.join(paths.models_path, "TAESD") folder = os.path.join(paths.models_path, "TAESD")
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
if 'video' in cls: if variant.startswith('TAE'):
return None
if ('sd' not in cls) and ('f1' not in cls):
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
return None
if variant.startswith('TAESD'):
cfg = TAESD_MODELS[variant] cfg = TAESD_MODELS[variant]
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None): if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
return cfg['model'] return cfg['model']
fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth') fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth')
if not os.path.exists(fn): if not os.path.exists(fn):
uri = cfg['uri'] + '/tae' + cls + '_' + model_type + '.pth' uri = cfg['uri']
if not uri.endswith('.pth'):
uri += '/tae' + cls + '_' + model_type + '.pth'
try: try:
shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download') shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download')
torch.hub.download_url_to_file(uri, fn) torch.hub.download_url_to_file(uri, fn)
except Exception as e: except Exception as e:
warn_once(f'download uri={uri} {e}') warn_once(f'download uri={uri} {e}', variant=variant)
if os.path.exists(fn): if os.path.exists(fn):
prev_cls = cls prev_cls = cls
prev_type = model_type prev_type = model_type
prev_model = variant prev_model = variant
shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load') shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load')
if 'TAEHV' in variant: if 'TAE HunyuanVideo' in variant:
from modules.taesd.taehv import TAEHV from modules.taesd.taehv import TAEHV
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn) TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
if 'TAEW2' in variant: elif 'TAE WanVideo' in variant:
from modules.taesd.taehv import TAEHV from modules.taesd.taehv import TAEHV
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn) TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
elif 'TAEM1' in variant: elif 'TAE MochiVideo' in variant:
from modules.taesd.taem1 import TAEM1 from modules.taesd.taem1 import TAEM1
TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn) TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn)
else: else:
@@ -99,7 +99,7 @@ def get_model(model_type = 'decoder', variant = None):
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None): if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
return cfg['model'] return cfg['model']
if cfg is None: if cfg is None:
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
return None return None
repo = cfg['repo'] repo = cfg['repo']
prev_cls = cls prev_cls = cls
@@ -117,7 +117,7 @@ def get_model(model_type = 'decoder', variant = None):
CQYAN_MODELS[variant][cls]['model'] = vae CQYAN_MODELS[variant][cls]['model'] = vae
return vae return vae
else: else:
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported') warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
return None return None

View File

@@ -36,10 +36,10 @@ def vae_decode_tiny(latents):
shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported') shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported')
return None return None
from modules import sd_vae_taesd from modules import sd_vae_taesd
vae = sd_vae_taesd.get_model(variant) vae = sd_vae_taesd.get_model(variant=variant)
if vae is None: if vae is None:
return None return None
debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}') shared.log.debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}')
vae = vae.to(device=devices.device, dtype=devices.dtype) vae = vae.to(device=devices.device, dtype=devices.dtype)
latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype) latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype)
images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1) images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1)

2
wiki

@@ -1 +1 @@
Subproject commit d57624e45e2fa6f827e1e3134ee9a1a996476fd5 Subproject commit d9fcc40daad2a7d6ff768865e829f14235da4078