mirror of https://github.com/vladmandic/automatic
fix video vae selector
Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3892/head
parent
8c37607b82
commit
cfcfa4e4d2
|
|
@ -14,6 +14,7 @@
|
||||||
install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md)
|
install as any other extension and for details see extension [README](https://github.com/vladmandic/sd-extension-framepack/blob/main/README.md)
|
||||||
- I2V & FLF2V support with explicit strength controls
|
- I2V & FLF2V support with explicit strength controls
|
||||||
- complex actions: modify prompts for each section of the video
|
- complex actions: modify prompts for each section of the video
|
||||||
|
- decode: use local, tiny or remote VAE
|
||||||
- custom models: e.g. replace llama with one of your choice
|
- custom models: e.g. replace llama with one of your choice
|
||||||
- video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation
|
- video: multiple codecs and with hw acceleration, raw export, frame export, frame interpolation
|
||||||
- compute: quantization support, new offloading, more configuration options, cross-platform, etc.
|
- compute: quantization support, new offloading, more configuration options, cross-platform, etc.
|
||||||
|
|
@ -94,6 +95,7 @@
|
||||||
- extension installer handling of PYTHONPATH
|
- extension installer handling of PYTHONPATH
|
||||||
- trace logging
|
- trace logging
|
||||||
- api logging
|
- api logging
|
||||||
|
- video vae selection: load the correct vae
|
||||||
|
|
||||||
## Update for 2025-04-12
|
## Update for 2025-04-12
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,8 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
torch.__long_version__ = torch.__version__
|
torch.__long_version__ = torch.__version__
|
||||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
timer.startup.record("torch")
|
timer.startup.record("torch")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import bitsandbytes # pylint: disable=W0611,C0411
|
import bitsandbytes # pylint: disable=W0611,C0411
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -73,9 +75,11 @@ timer.startup.record("pydantic")
|
||||||
import diffusers.utils.import_utils # pylint: disable=W0611,C0411
|
import diffusers.utils.import_utils # pylint: disable=W0611,C0411
|
||||||
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
|
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
|
||||||
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
|
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
|
||||||
|
|
||||||
import diffusers # pylint: disable=W0611,C0411
|
import diffusers # pylint: disable=W0611,C0411
|
||||||
import diffusers.loaders.single_file # pylint: disable=W0611,C0411
|
import diffusers.loaders.single_file # pylint: disable=W0611,C0411
|
||||||
import huggingface_hub # pylint: disable=W0611,C0411
|
import huggingface_hub # pylint: disable=W0611,C0411
|
||||||
|
|
||||||
logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR)
|
logging.getLogger("diffusers.loaders.single_file").setLevel(logging.ERROR)
|
||||||
timer.startup.record("diffusers")
|
timer.startup.record("diffusers")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
|
||||||
latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype)
|
latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype)
|
||||||
latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
|
latent_copy = latents.unsqueeze(0) if len(latents.shape) == 3 else latents
|
||||||
if model_type == 'hunyuanvideo':
|
if model_type == 'hunyuanvideo':
|
||||||
latent_copy = latent_copy.unsqueeze(0)
|
latent_copy = latent_copy.unsqueeze(0) if len(latents.shape) == 4 else latents
|
||||||
|
|
||||||
for i in range(latent_copy.shape[0]):
|
for i in range(latent_copy.shape[0]):
|
||||||
params = {}
|
params = {}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ prev_cls = ''
|
||||||
prev_type = ''
|
prev_type = ''
|
||||||
prev_model = ''
|
prev_model = ''
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
supported = ['sd', 'sdxl', 'f1', 'h1', 'hunyuanvideo', 'wanvideo', 'mochivideo']
|
||||||
|
|
||||||
|
|
||||||
def warn_once(msg, variant=None):
|
def warn_once(msg, variant=None):
|
||||||
|
|
@ -56,38 +57,37 @@ def get_model(model_type = 'decoder', variant = None):
|
||||||
cls = 'sd'
|
cls = 'sd'
|
||||||
if cls == 'h1': # hidream uses flux vae
|
if cls == 'h1': # hidream uses flux vae
|
||||||
cls = 'f1'
|
cls = 'f1'
|
||||||
|
if cls not in supported:
|
||||||
|
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||||
variant = variant or shared.opts.taesd_variant
|
variant = variant or shared.opts.taesd_variant
|
||||||
folder = os.path.join(paths.models_path, "TAESD")
|
folder = os.path.join(paths.models_path, "TAESD")
|
||||||
os.makedirs(folder, exist_ok=True)
|
os.makedirs(folder, exist_ok=True)
|
||||||
if 'video' in cls:
|
if variant.startswith('TAE'):
|
||||||
return None
|
|
||||||
if ('sd' not in cls) and ('f1' not in cls):
|
|
||||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
|
||||||
return None
|
|
||||||
if variant.startswith('TAESD'):
|
|
||||||
cfg = TAESD_MODELS[variant]
|
cfg = TAESD_MODELS[variant]
|
||||||
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
||||||
return cfg['model']
|
return cfg['model']
|
||||||
fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth')
|
fn = os.path.join(folder, cfg['fn'] + cls + '_' + model_type + '.pth')
|
||||||
if not os.path.exists(fn):
|
if not os.path.exists(fn):
|
||||||
uri = cfg['uri'] + '/tae' + cls + '_' + model_type + '.pth'
|
uri = cfg['uri']
|
||||||
|
if not uri.endswith('.pth'):
|
||||||
|
uri += '/tae' + cls + '_' + model_type + '.pth'
|
||||||
try:
|
try:
|
||||||
shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download')
|
shared.log.info(f'Decode: type="taesd" variant="{variant}": uri="{uri}" fn="{fn}" download')
|
||||||
torch.hub.download_url_to_file(uri, fn)
|
torch.hub.download_url_to_file(uri, fn)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warn_once(f'download uri={uri} {e}')
|
warn_once(f'download uri={uri} {e}', variant=variant)
|
||||||
if os.path.exists(fn):
|
if os.path.exists(fn):
|
||||||
prev_cls = cls
|
prev_cls = cls
|
||||||
prev_type = model_type
|
prev_type = model_type
|
||||||
prev_model = variant
|
prev_model = variant
|
||||||
shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load')
|
shared.log.debug(f'Decode: type="taesd" variant="{variant}" fn="{fn}" load')
|
||||||
if 'TAEHV' in variant:
|
if 'TAE HunyuanVideo' in variant:
|
||||||
from modules.taesd.taehv import TAEHV
|
from modules.taesd.taehv import TAEHV
|
||||||
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
||||||
if 'TAEW2' in variant:
|
elif 'TAE WanVideo' in variant:
|
||||||
from modules.taesd.taehv import TAEHV
|
from modules.taesd.taehv import TAEHV
|
||||||
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
TAESD_MODELS[variant]['model'] = TAEHV(checkpoint_path=fn)
|
||||||
elif 'TAEM1' in variant:
|
elif 'TAE MochiVideo' in variant:
|
||||||
from modules.taesd.taem1 import TAEM1
|
from modules.taesd.taem1 import TAEM1
|
||||||
TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn)
|
TAESD_MODELS[variant]['model'] = TAEM1(checkpoint_path=fn)
|
||||||
else:
|
else:
|
||||||
|
|
@ -99,7 +99,7 @@ def get_model(model_type = 'decoder', variant = None):
|
||||||
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
if (cls == prev_cls) and (model_type == prev_type) and (variant == prev_model) and (cfg['model'] is not None):
|
||||||
return cfg['model']
|
return cfg['model']
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||||
return None
|
return None
|
||||||
repo = cfg['repo']
|
repo = cfg['repo']
|
||||||
prev_cls = cls
|
prev_cls = cls
|
||||||
|
|
@ -117,7 +117,7 @@ def get_model(model_type = 'decoder', variant = None):
|
||||||
CQYAN_MODELS[variant][cls]['model'] = vae
|
CQYAN_MODELS[variant][cls]['model'] = vae
|
||||||
return vae
|
return vae
|
||||||
else:
|
else:
|
||||||
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported')
|
warn_once(f'cls={shared.sd_model.__class__.__name__} type={cls} unsuppported', variant=variant)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,10 +36,10 @@ def vae_decode_tiny(latents):
|
||||||
shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported')
|
shared.log.warning(f'Video VAE: type=Tiny cls={shared.sd_model.__class__.__name__} not supported')
|
||||||
return None
|
return None
|
||||||
from modules import sd_vae_taesd
|
from modules import sd_vae_taesd
|
||||||
vae = sd_vae_taesd.get_model(variant)
|
vae = sd_vae_taesd.get_model(variant=variant)
|
||||||
if vae is None:
|
if vae is None:
|
||||||
return None
|
return None
|
||||||
debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}')
|
shared.log.debug(f'Video VAE: type=Tiny cls={vae.__class__.__name__} variant="{variant}" latents={latents.shape}')
|
||||||
vae = vae.to(device=devices.device, dtype=devices.dtype)
|
vae = vae.to(device=devices.device, dtype=devices.dtype)
|
||||||
latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype)
|
latents = latents.transpose(1, 2).to(device=devices.device, dtype=devices.dtype)
|
||||||
images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1)
|
images = vae.decode_video(latents, parallel=False).transpose(1, 2).mul_(2).sub_(1)
|
||||||
|
|
|
||||||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
||||||
Subproject commit d57624e45e2fa6f827e1e3134ee9a1a996476fd5
|
Subproject commit d9fcc40daad2a7d6ff768865e829f14235da4078
|
||||||
Loading…
Reference in New Issue