mirror of https://github.com/vladmandic/automatic
fix video vae selector
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/3892/head
parent
8c37607b82
commit
cfcfa4e4d2
|
|
@ -14,6 +14,7 @@
|
|||
install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md)
|
||||
- I2V & FLF2V support with explicit strength controls
|
||||
- complex actions: modify prompts for each section of the video
|
||||
- decode: use local, tiny or remote VAE
|
||||
- custom models: e.g. replace llama with one of your choice
|
||||
- video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation
|
||||
- compute: quantization support, new offloading, more configuration options, cross-platform, etc.
|
||||
|
|
@ -94,6 +95,7 @@
|
|||
- extension installer handling of PYTHONPATH
|
||||
- trace logging
|
||||
- api logging
|
||||
- video vae selection load correct vae
|
||||
|
||||
## Update for 2025-04-12
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
|||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
timer.startup.record("torch")
|
||||
|
||||
|
||||
try:
|
||||
import bitsandbytes # pylint: disable=W0611,C0411
|
||||
except Exception:
|
||||
|
|
@ -73,9 +75,11 @@ timer.startup.record("pydantic")
|
|||
import diffusers.utils.import_utils # pylint: disable=W0611,C0411
|
||||
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
|
||||
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
|
||||
|
||||
import diffusers # pylint: disable=W0611,C0411
|
||||
import diffusers.loaders.single_file # pylint: disable=W0611,C0411
|
||||
import huggingface_hub # pylint: disable=W0611,C0411
|
||||
|
||||
logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR)
|
||||
timer.startup.record("diffusers")
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
|
|||
latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype)
|
||||
latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
|
||||
if model_type == 'hunyuanvideo':
|
||||
latent_copy = latent_copy.unsqueeze(0)
|
||||
latent_copy = latent_copy.unsqueeze(0) if len(latents.shape) == 4 else latents
|
||||
|
||||
for i in range(latent_copy.shape[0]):
|
||||
params = {}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ prev_cls = ''
|
|||
prev_type = ''
|
||||
prev_model = ''
|
||||
lock = threading.Lock()
|
||||
supported = ['sd', 'sdxl', 'f1', 'h1', 'hunyuanvideo', 'wanvideo', 'mochivideo']
|
||||
|
||||
|
||||
def warn_once(msg, variant=None):
|
||||
|
|
@ -56,38 +57,37 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
cls = 'sd'
|
||||
if cls == 'h1': # hidream uses flux vae
|
||||
cls = 'f1'
|
||||
if cls not in supported:
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||
variant = variant or shared.opts.taesd_variant
|
||||
folder = os.path.join(paths.models_path, "TAESD")
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
if 'video' in cls:
|
||||
return None
|
||||
if ('sd' not in cls) and ('f1' not in cls):
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
||||
return None
|
||||
if variant.startswith('TAESD'):
|
||||
if variant.startswith('TAE'):
|
||||
cfg = TAESD_MODELS[variant]
|
||||
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
||||
return cfg['model']
|
||||
fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth')
|
||||
if not os.path.exists(fn):
|
||||
uri = cfg['uri'] + '/tae' + cls + '_' + model_type + '.pth'
|
||||
uri = cfg['uri']
|
||||
if not uri.endswith('.pth'):
|
||||
uri += '/tae' + cls + '_' + model_type + '.pth'
|
||||
try:
|
||||
shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download')
|
||||
torch.hub.download_url_to_file(uri, fn)
|
||||
except Exception as e:
|
||||
warn_once(f'download uri={uri} {e}')
|
||||
warn_once(f'download uri={uri} {e}', variant=variant)
|
||||
if os.path.exists(fn):
|
||||
prev_cls = cls
|
||||
prev_type = model_type
|
||||
prev_model = variant
|
||||
shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load')
|
||||
if 'TAEHV' in variant:
|
||||
if 'TAE HunyuanVideo' in variant:
|
||||
from modules.taesd.taehv import TAEHV
|
||||
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
||||
if 'TAEW2' in variant:
|
||||
elif 'TAE WanVideo' in variant:
|
||||
from modules.taesd.taehv import TAEHV
|
||||
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
||||
elif 'TAEM1' in variant:
|
||||
elif 'TAE MochiVideo' in variant:
|
||||
from modules.taesd.taem1 import TAEM1
|
||||
TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn)
|
||||
else:
|
||||
|
|
@ -99,7 +99,7 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
||||
return cfg['model']
|
||||
if cfg is None:
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||
return None
|
||||
repo = cfg['repo']
|
||||
prev_cls = cls
|
||||
|
|
@ -117,7 +117,7 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
CQYAN_MODELS[variant][cls]['model'] = vae
|
||||
return vae
|
||||
else:
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,10 +36,10 @@ def vae_decode_tiny(latents):
|
|||
shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported')
|
||||
return None
|
||||
from modules import sd_vae_taesd
|
||||
vae = sd_vae_taesd.get_model(variant)
|
||||
vae = sd_vae_taesd.get_model(variant=variant)
|
||||
if vae is None:
|
||||
return None
|
||||
debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}')
|
||||
shared.log.debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}')
|
||||
vae = vae.to(device=devices.device, dtype=devices.dtype)
|
||||
latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype)
|
||||
images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit d57624e45e2fa6f827e1e3134ee9a1a996476fd5
|
||||
Subproject commit d9fcc40daad2a7d6ff768865e829f14235da4078
|
||||
Loading…
Reference in New Issue