Fix LTXV Reference Audio node (#13531)
parent
abf3d56f27
commit
6fbb6b6f49
|
|
@ -1,6 +1,7 @@
|
||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
|
import torchaudio
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
|
@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
# Encode reference audio to latents and patchify
|
# Encode reference audio to latents and patchify
|
||||||
audio_latents = audio_vae.encode(reference_audio)
|
sample_rate = reference_audio["sample_rate"]
|
||||||
|
vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100)
|
||||||
|
if vae_sample_rate != sample_rate:
|
||||||
|
waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate)
|
||||||
|
else:
|
||||||
|
waveform = reference_audio["waveform"]
|
||||||
|
|
||||||
|
audio_latents = audio_vae.encode(waveform.movedim(1, -1))
|
||||||
b, c, t, f = audio_latents.shape
|
b, c, t, f = audio_latents.shape
|
||||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||||
ref_audio = {"tokens": ref_tokens}
|
ref_audio = {"tokens": ref_tokens}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue