Merge branch 'master' into range-type

pull/13322/head
guill 2026-04-23 15:20:40 -07:00 committed by GitHub
commit 4a8ada2d15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
67 changed files with 7370 additions and 420 deletions

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@ -139,9 +139,9 @@ Example:
"_quantization_metadata": { "_quantization_metadata": {
"format_version": "1.0", "format_version": "1.0",
"layers": { "layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn", "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
"model.layers.0.mlp.down_proj": "float8_e4m3fn", "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
"model.layers.1.mlp.up_proj": "float8_e4m3fn" "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
} }
} }
} }
@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
3. **Compute scales**: Derive `input_scale` from collected statistics 3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
#### Alternative Downloads: #### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) [Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

View File

@ -67,7 +67,7 @@ class InternalRoutes:
(entry for entry in os.scandir(directory) if is_visible_file(entry)), (entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime key=lambda entry: -entry.stat().st_mtime
) )
return web.json_response([entry.name for entry in sorted_files], status=200) return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)
def get_app(self): def get_app(self):

View File

@ -182,7 +182,7 @@
] ]
}, },
"widgets_values": [ "widgets_values": [
50 0
] ]
}, },
{ {

View File

@ -316,7 +316,7 @@
"step": 1 "step": 1
}, },
"widgets_values": [ "widgets_values": [
30 0
] ]
}, },
{ {

301
comfy/ldm/ernie/model.py Normal file
View File

@ -0,0 +1,301 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build a rotary-embedding cos/sin table for positions ``pos``.

    Returns a float32 tensor whose leading dim indexes [cos, sin], computed in
    float64 and moved back to ``pos``'s device.
    """
    assert dim % 2 == 0
    # Some backends cannot do fp64 math; compute the table on CPU there.
    if comfy.model_management.supports_fp64(pos.device):
        compute_device = pos.device
    else:
        compute_device = torch.device("cpu")
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=compute_device) / dim
    inv_freq = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos.to(compute_device), inv_freq)
    table = torch.stack((torch.cos(angles), torch.sin(angles)), dim=0)
    return table.to(dtype=torch.float32, device=pos.device)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding to the leading ``freqs_cis.shape[-1]`` channels.

    ``freqs_cis[0]`` holds cos terms, ``freqs_cis[1]`` sin terms; channels past
    the rotated span are passed through unchanged.
    """
    rot_dim = freqs_cis.shape[-1]
    rotated_span = x_in[..., :rot_dim]
    passthrough = x_in[..., rot_dim:]
    cos_table, sin_table = freqs_cis[0], freqs_cis[1]
    first_half, second_half = rotated_span.chunk(2, dim=-1)
    quarter_turn = torch.cat((-second_half, first_half), dim=-1)
    rotated = rotated_span * cos_table + quarter_turn * sin_table
    return torch.cat((rotated, passthrough), dim=-1)
class ErnieImageEmbedND3(nn.Module):
    """Multi-axis rotary position embedding (default 3 axes: index/h/w).

    Each axis ``i`` of ``ids`` gets a rope table of width ``axes_dim[i]``;
    tables are concatenated along the feature dim and each frequency is
    duplicated so cos/sin pairs line up with the halves chunked by
    ``apply_rotary_emb``.
    """
    def __init__(self, dim: int, theta: int, axes_dim: tuple):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Generalized from a hard-coded `range(3)` to `len(self.axes_dim)` axes;
        # behavior is unchanged for the default 3-axis configuration.
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
            dim=-1,
        )
        emb = emb.unsqueeze(3)  # [2, B, S, 1, head_dim//2]
        return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)  # [B, S, 1, head_dim]
class ErnieImagePatchEmbedDynamic(nn.Module):
    """Conv2d patchifier that flattens spatial patches into a token sequence."""
    def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
        super().__init__()
        self.patch_size = patch_size
        self.proj = operations.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        b, c, h, w = projected.shape
        # [B, C, H, W] -> [B, H*W, C]: one token per spatial patch.
        return projected.reshape(b, c, h * w).transpose(1, 2).contiguous()
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding.

    Produces ``num_channels`` features per timestep: half sin followed by half
    cos, or the reverse when ``flip_sin_to_cos`` is set.
    """
    def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
        freqs = torch.exp(exponent)
        angles = timesteps[:, None].float() * freqs[None, :]
        sin_part, cos_part = torch.sin(angles), torch.cos(angles)
        if self.flip_sin_to_cos:
            return torch.cat([cos_part, sin_part], dim=-1)
        return torch.cat([sin_part, cos_part], dim=-1)
class TimestepEmbedding(nn.Module):
    """Two-layer SiLU MLP projecting sinusoidal features to the embed width."""
    def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
        super().__init__()
        make_linear = operations.Linear
        self.linear_1 = make_linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype)
        self.act = nn.SiLU()
        self.linear_2 = make_linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        return self.linear_2(self.act(self.linear_1(sample)))
class ErnieImageAttention(nn.Module):
    """Self-attention with per-head QK RMSNorm and rotary position embedding."""
    def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = heads * dim_head
        Linear, RMSNorm = operations.Linear, operations.RMSNorm
        self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
        self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
        self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
        self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
        self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
        # ModuleList wrapper so checkpoint keys keep the `to_out.0.*` layout.
        self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)])

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
        batch, seq, _ = x.shape
        # Per-head view for QK normalization and rotary application.
        q = self.norm_q(self.to_q(x).view(batch, seq, self.heads, self.head_dim))
        k = self.norm_k(self.to_k(x).view(batch, seq, self.heads, self.head_dim))
        v = self.to_v(x)
        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb)
            k = apply_rotary_emb(k, image_rotary_emb)
        out = optimized_attention(
            q.reshape(batch, seq, -1),
            k.reshape(batch, seq, -1),
            v,
            self.heads,
            mask=attention_mask,
        )
        return self.to_out[0](out)
class ErnieImageFeedForward(nn.Module):
    """GELU-gated feed-forward: fc2(up(x) * gelu(gate(x)))."""
    def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
        super().__init__()
        make_linear = operations.Linear
        self.gate_proj = make_linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
        self.up_proj = make_linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
        self.linear_fc2 = make_linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.gelu(self.gate_proj(x))
        return self.linear_fc2(self.up_proj(x) * gated)
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block modulated by shared adaptive-LayerNorm parameters.

    ``temb`` carries six tensors (shift/scale/gate for the attention path and
    for the MLP path) computed once by the parent model and shared by layers.
    """
    def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        RMSNorm = operations.RMSNorm
        self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            eps=eps,
            operations=operations,
            device=device,
            dtype=dtype,
        )
        self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype)

    def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
        # Attention sub-block: modulated norm, then gated residual add.
        attn_in = self.adaLN_sa_ln(x) * (1 + scale_msa) + shift_msa
        x = x + gate_msa * self.self_attention(attn_in, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        # MLP sub-block with the same modulate/gate pattern.
        mlp_in = self.adaLN_mlp_ln(x) * (1 + scale_mlp) + shift_mlp
        return x + gate_mlp * self.mlp(mlp_in)
class ErnieImageAdaLNContinuous(nn.Module):
    """Final adaptive LayerNorm: scale/shift predicted from the conditioning."""
    def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        self.norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
        self.linear = operations.Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(conditioning).chunk(2, dim=-1)
        normed = self.norm(x)
        # addcmul fuses shift + normed * (1 + scale) into a single op.
        return torch.addcmul(shift.unsqueeze(1), normed, 1 + scale.unsqueeze(1))
class ErnieImageModel(nn.Module):
    """ERNIE image diffusion transformer.

    Patchifies the input latent into image tokens, appends (optionally
    projected) text tokens, runs a stack of shared-AdaLN transformer blocks
    with 3-axis rotary position embeddings, and unpatchifies the image tokens
    back to the latent layout.
    """
    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: tuple = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
        device=None,
        dtype=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.patch_size = patch_size
        self.out_channels = out_channels
        Linear = operations.Linear
        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype)
        # Project text features only when their width differs from the model width.
        self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None
        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype)
        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
        # Produces the six shared AdaLN modulation tensors from the time embedding.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype)
        )
        self.layers = nn.ModuleList([
            ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype)
            for _ in range(num_layers)
        ])
        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype)
        self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype)

    def forward(self, x, timesteps, context, **kwargs):
        # x: latent image [B, C, H, W]; context: text tokens [B, T, text_in_dim].
        device, dtype = x.device, x.dtype
        B, C, H, W = x.shape
        p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
        N_img = Hp * Wp
        img_bsh = self.x_embedder(x)
        text_bth = context
        if self.text_proj is not None and text_bth.numel() > 0:
            text_bth = self.text_proj(text_bth)
        Tmax = text_bth.shape[1]
        # Joint sequence: image tokens first, then text tokens.
        hidden_states = torch.cat([img_bsh, text_bth], dim=1)
        # Text tokens occupy axis 0 (sequence index) only; spatial axes stay zero.
        text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
        text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
        # Image tokens sit after the text range on axis 0.
        index = float(Tmax)
        transformer_options = kwargs.get("transformer_options", {})
        rope_options = transformer_options.get("rope_options", None)
        h_len, w_len = float(Hp), float(Wp)
        h_offset, w_offset = 0.0, 0.0
        if rope_options is not None:
            # Optional RoPE rescaling/shifting supplied by the caller.
            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
            index += rope_options.get("shift_t", 0.0)
            h_offset += rope_options.get("shift_y", 0.0)
            w_offset += rope_options.get("shift_x", 0.0)
        image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
        # Channel 1 is still all-zero here, so this just fills axis 0 with `index`.
        image_ids[:, :, 0] = image_ids[:, :, 1] + index
        image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1)
        image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0)
        image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
        del image_ids, text_ids
        sample = self.time_proj(timesteps).to(dtype)
        c = self.time_embedding(sample)
        # Six shared AdaLN modulation tensors, broadcast over the sequence dim.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
        ]
        temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
        for layer in self.layers:
            hidden_states = layer(hidden_states, rotary_pos_emb, temb)
        hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)
        # Only the image tokens (first N_img positions) are unpatchified.
        patches = self.final_linear(hidden_states)[:, :N_img, :]
        output = (
            patches.view(B, Hp, Wp, p, p, self.out_channels)
            .permute(0, 5, 1, 3, 2, 4)
            .contiguous()
            .view(B, self.out_channels, H, W)
        )
        return output

View File

@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0 assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): if not comfy.model_management.supports_fp64(pos.device):
device = torch.device("cpu") device = torch.device("cpu")
else: else:
device = pos.device device = pos.device

View File

@ -4,9 +4,6 @@ import math
import torch import torch
import torchaudio import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
return cls(autoencoder=audio_config, vocoder=vocoder_config) return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer: class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout.""" """Applies per-channel statistics in patch space and restores original layout."""
@ -132,23 +105,17 @@ class AudioPreprocessor:
class AudioVAE(torch.nn.Module): class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points.""" """High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict): def __init__(self, metadata: dict):
super().__init__() super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata) component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder: if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder) self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else: else:
self.vocoder = Vocoder(config=component_config.vocoder) self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config() autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer( self.normalizer = AudioLatentNormalizer(
AudioPatchifier( AudioPatchifier(
@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"], n_fft=autoencoder_config["n_fft"],
) )
self.device_manager = ModelDeviceManager(self) def encode(self, audio, sample_rate=44100) -> torch.Tensor:
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors.""" """Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"] waveform = audio
waveform_sample_rate = audio["sample_rate"] waveform_sample_rate = sample_rate
input_device = waveform.device input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels: if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1: if waveform.shape[1] == 1:
@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
) )
mel_spec = self.preprocessor.waveform_to_mel( mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device waveform, waveform_sample_rate, device=waveform.device
) )
latents = self.autoencoder.encode(mel_spec) latents = self.autoencoder.encode(mel_spec)
@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform.""" """Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents) latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape) target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec) waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform) return waveform
def target_shape_from_latents(self, latents_shape): def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape batch, _, time, _ = latents_shape

View File

@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts: for layer in ts:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
if isinstance(layer, VideoResBlock): if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator) x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock): elif isinstance(layer, TimestepBlock):
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample): elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape) x = layer(x, output_shape=output_shape)
else: else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x) x = layer(x)
return x return x
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')
if "middle_block_after_patch" in transformer_patches:
patch = transformer_patches["middle_block_after_patch"]
for p in patch:
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
"timesteps": timesteps, "transformer_options": transformer_options})
h = out["h"]
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
for p in patch: for p in patch:
h, hsp = p(h, hsp, transformer_options) h, hsp = p(h, hsp, transformer_options)
h = th.cat([h, hsp], dim=1) if hsp is not None:
del hsp h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0: if len(hs) > 0:
output_shape = hs[-1].shape output_shape = hs[-1].shape
else: else:

View File

@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module):
origin_max = np.max(hm[k]) origin_max = np.max(hm[k])
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
dr[border:-border, border:-border] = hm[k].copy() dr[border:-border, border:-border] = hm[k].copy()
dr = gaussian_filter(dr, sigma=2.0) dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
hm[k] = dr[border:-border, border:-border].copy() hm[k] = dr[border:-border, border:-border].copy()
cur_max = np.max(hm[k]) cur_max = np.max(hm[k])
if cur_max > 0: if cur_max > 0:

596
comfy/ldm/sam3/detector.py Normal file
View File

@ -0,0 +1,596 @@
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
from comfy.ops import cast_to_input
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
    return torch.stack(corners, dim=-1)
def gen_sineembed_for_position(pos_tensor, num_feats=256):
    """Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats // 2).

    Each coordinate contributes ``num_feats // 2`` features of interleaved
    sin/cos pairs sharing paired frequencies.
    """
    assert num_feats % 2 == 0
    half = num_feats // 2
    dim_idx = torch.arange(half, dtype=torch.float32, device=pos_tensor.device)
    # floor(k/2) pairs consecutive indices so a sin/cos pair shares a wavelength.
    freqs = 10000.0 ** (2 * (dim_idx // 2) / half)
    per_coord = []
    for coord in range(pos_tensor.shape[-1]):
        scaled = (pos_tensor[..., coord].float() * 2 * math.pi).unsqueeze(-1) / freqs
        interleaved = torch.stack([scaled[..., 0::2].sin(), scaled[..., 1::2].cos()], dim=-1)
        per_coord.append(interleaved.flatten(-2))
    return torch.cat(per_coord, dim=-1).to(pos_tensor.dtype)
class SplitMHA(nn.Module):
    """Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
    def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        make_linear = operations.Linear
        self.q_proj = make_linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = make_linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = make_linear(d_model, d_model, device=device, dtype=dtype)
        self.out_proj = make_linear(d_model, d_model, device=device, dtype=dtype)

    def forward(self, q_input, k_input=None, v_input=None, mask=None):
        # Self-attention when only q_input is given; cross-attention otherwise.
        if k_input is None:
            k_input = q_input
            v_input = q_input
        elif v_input is None:
            v_input = k_input
        q = self.q_proj(q_input)
        k = self.k_proj(k_input)
        v = self.v_proj(v_input)
        if mask is not None and mask.ndim == 2:
            # [B, T] -> [B, 1, 1, T] so SDPA can broadcast over heads/queries.
            mask = mask[:, None, None, :]
        dtype = q.dtype  # manual_cast may produce mixed dtypes
        attended = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
        return self.out_proj(attended)
class MLPWithNorm(nn.Module):
    """MLP with an optional residual connection and an output LayerNorm."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
        super().__init__()
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            operations.Linear(w_in, w_out, device=device, dtype=dtype)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        )
        self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
        # Residual is only valid when input and output widths match.
        self.residual = residual and (input_dim == output_dim)

    def forward(self, x):
        skip = x
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx != last:
                x = F.relu(x)
        if self.residual:
            x = x + skip
        return self.out_norm(x)
class EncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention, optional text cross-attention, FFN."""
    def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
        super().__init__()
        self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
        self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, x, pos, text_memory=None, text_mask=None):
        # Self-attention: positions are added to Q/K while values stay un-positioned.
        pre = self.norm1(x)
        with_pos = pre + pos
        x = x + self.self_attn(with_pos, with_pos, pre)
        # Cross-attend into text memory when provided.
        if text_memory is not None:
            pre = self.norm2(x)
            x = x + self.cross_attn_image(pre, text_memory, text_memory, mask=text_mask)
        # Feed-forward.
        pre = self.norm3(x)
        return x + self.linear2(F.relu(self.linear1(pre)))
class TransformerEncoder(nn.Module):
    """Stack of EncoderLayers. Checkpoint: transformer.encoder.layers.N.*"""
    def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        )

    def forward(self, x, pos, text_memory=None, text_mask=None):
        for encoder_layer in self.layers:
            x = encoder_layer(x, pos, text_memory, text_mask)
        return x
class DecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attn, text cross-attn, image cross-attn, FFN."""
    def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
        super().__init__()
        self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
        self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)

    def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
        # Self-attention among queries (positions added to Q/K only).
        queries = x + x_pos
        x = self.norm2(x + self.self_attn(queries, queries, x))
        # Optional cross-attention into text tokens.
        if text_memory is not None:
            x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
        # Cross-attention into image memory, optionally biased by box RPB.
        x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
        # Feed-forward.
        return self.norm3(x + self.linear2(F.relu(self.linear1(x))))
class TransformerDecoder(nn.Module):
    def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
                 num_queries=200, device=None, dtype=None, operations=None):
        # Detection decoder with learned object queries, iterative box
        # refinement, box-relative position bias, and a presence token.
        super().__init__()
        self.d_model = d_model
        self.num_queries = num_queries
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        ])
        self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
        self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype)  # Reference points: Embedding(num_queries, 4) — learned anchor boxes
        self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations)  # ref_point_head input: 512 (4 coords * 128 sine features each)
        self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
        # Box-relative position bias: map log-scaled x/y edge deltas to per-head biases.
        self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
        self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
        # Single learned "presence" token carried alongside the object queries.
        self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
        self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
        self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
@staticmethod
def _inverse_sigmoid(x):
return torch.log(x / (1 - x + 1e-6) + 1e-6)
def _compute_box_rpb(self, ref_points, H, W):
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
B, Q, _ = boxes_xyxy.shape
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
log2_8 = float(math.log2(8))
def log_scale(d):
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
return torch.cat([pres_bias, bias], dim=2)
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
B = memory.shape[0]
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
for layer_idx, layer in enumerate(self.layers):
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
if layer_idx < len(self.layers) - 1:
ref_inv = self._inverse_sigmoid(ref_points)
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
query_out = self.norm(tgt)
ref_inv = self._inverse_sigmoid(ref_points)
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
class Transformer(nn.Module):
    """Container pairing the detection encoder with the detection decoder."""

    def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
                 num_queries=200, device=None, dtype=None, operations=None):
        super().__init__()
        common = dict(device=device, dtype=dtype, operations=operations)
        self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, **common)
        self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, **common)
class GeometryEncoder(nn.Module):
    """Encodes point/box prompts into d_model tokens by combining a direct
    coordinate projection, features pooled from the backbone at the prompt
    location, a sinusoidal positional encoding, and a label embedding."""

    def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
        super().__init__()
        self.d_model = d_model
        self.roi_size = roi_size
        self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
        self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
        self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
        # Conv with kernel == roi_size collapses each pooled ROI to a single vector.
        self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
        self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
        # Two label slots — used for positive/negative prompt labels; verify against callers.
        self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
        self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.encode = nn.ModuleList([
            EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        ])
        self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)

    def _encode_points(self, coords, labels, img_feat_2d):
        """Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
        B, N, _ = coords.shape
        embed = self.points_direct_project(coords)
        # Pool features from backbone at point locations via grid_sample
        grid = (coords * 2 - 1).unsqueeze(2)  # [B, N, 1, 2] in [-1, 1]
        sampled = F.grid_sample(img_feat_2d, grid, align_corners=False)  # [B, C, N, 1]
        embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1))  # [B, N, C]
        # Positional encoding of coordinates
        x, y = coords[:, :, 0], coords[:, :, 1]  # [B, N]
        pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
        enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
        embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
        embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
        return embed

    def _encode_boxes(self, boxes, labels, img_feat_2d):
        """Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
        B, N, _ = boxes.shape
        embed = self.boxes_direct_project(boxes)
        # ROI align from backbone at box regions
        H, W = img_feat_2d.shape[-2:]
        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
        # Scale normalized boxes to feature-map pixel units for roi_align.
        scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
        boxes_scaled = boxes_xyxy * scale
        # split(N) yields one [N, 4] tensor per batch element, as roi_align expects.
        sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
        proj = self.boxes_pool_project(sampled).view(B, N, -1)  # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
        embed = embed + proj
        # Positional encoding of box center + size
        cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
        enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
        enc = enc.view(B, N, -1)
        embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
        embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
        return embed

    def forward(self, points=None, boxes=None, image_features=None):
        """Encode geometry prompts. image_features: [B, HW, C] flattened backbone features.

        points: optional (coords [B, N, 2], labels [B, N]) tuple.
        boxes: optional [B, N, 4] normalized cxcywh (all labeled positive).
        Returns [B, N_total, C] prompt tokens, or None when no prompts given.
        """
        # Prepare 2D image features for pooling
        img_feat_2d = None
        if image_features is not None:
            B = image_features.shape[0]
            HW, C = image_features.shape[1], image_features.shape[2]
            hw = int(math.sqrt(HW))  # assumes a square feature grid — TODO confirm
            img_normed = self.img_pre_norm(image_features)
            img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
        embeddings = []
        if points is not None:
            coords, labels = points
            embeddings.append(self._encode_points(coords, labels, img_feat_2d))
        if boxes is not None:
            B = boxes.shape[0]
            # Box prompts always carry label 1 (positive).
            box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
            embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
        if not embeddings:
            return None
        geo = torch.cat(embeddings, dim=1)
        geo = self.norm(geo)
        # Refine prompt tokens against the image features with the encoder stack
        # (zero positional encoding for the prompt tokens themselves).
        if image_features is not None:
            for layer in self.encode:
                geo = layer(geo, torch.zeros_like(geo), image_features)
            geo = self.encode_norm(geo)
        return self.final_proj(geo)
class PixelDecoder(nn.Module):
    """Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""

    def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
        super().__init__()
        self.conv_layers = nn.ModuleList([
            operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype)
            for _ in range(num_stages)
        ])
        self.norms = nn.ModuleList([
            operations.GroupNorm(8, d_model, device=device, dtype=dtype)
            for _ in range(num_stages)
        ])

    def forward(self, backbone_features):
        """Fuse features top-down: upsample the running map to each finer
        level's size, add, then 3x3 conv + GroupNorm + ReLU. Returns the
        finest fused map."""
        fused = backbone_features[-1]
        finer_levels = backbone_features[:-1][::-1]
        for conv, norm, feat in zip(self.conv_layers, self.norms, finer_levels):
            upsampled = F.interpolate(fused, size=feat.shape[-2:], mode="nearest")
            fused = F.relu(norm(conv(feat + upsampled)))
        return fused
class MaskPredictor(nn.Module):
    """Produces per-query mask logits by dotting embedded queries with pixel features."""

    def __init__(self, d_model=256, device=None, dtype=None, operations=None):
        super().__init__()
        # 3-layer MLP mapping decoder queries into the pixel-feature space.
        self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)

    def forward(self, query_embeddings, pixel_features):
        """[B, Q, C] queries x [B, C, H, W] pixels -> [B, Q, H, W] logits."""
        embedded = self.mask_embed(query_embeddings)
        return torch.einsum("bqc,bchw->bqhw", embedded, pixel_features)
class SegmentationHead(nn.Module):
    """Produces per-query mask logits from backbone features, optionally
    conditioning the encoder memory on prompt tokens via cross-attention."""

    def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
        super().__init__()
        self.d_model = d_model
        self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
        self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
        self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
        # NOTE(review): semantic_seg_head is constructed but not used in this forward.
        self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)

    def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
        """Return [B, Q, H, W] mask logits.

        query_embeddings: [B, Q, C] decoder outputs.
        backbone_features: list of feature maps; the last entry is replaced by
        the (reshaped) encoder memory when encoder_hidden_states is given.
        prompt / prompt_mask: optional prompt tokens cross-attended into the
        encoder memory before pixel decoding.
        """
        # Condition encoder memory on prompt tokens (pre-norm cross-attention, residual add).
        if encoder_hidden_states is not None and prompt is not None:
            enc_normed = self.cross_attn_norm(encoder_hidden_states)
            enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
            encoder_hidden_states = enc_cross + encoder_hidden_states
        # Swap the last feature level for the encoder's visual tokens
        # (memory may also contain trailing text tokens; only the first H*W are visual).
        if encoder_hidden_states is not None:
            B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
            encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
            backbone_features = list(backbone_features)
            backbone_features[-1] = encoder_visual
        pixel_features = self.pixel_decoder(backbone_features)
        instance_features = self.instance_seg_head(pixel_features)
        masks = self.mask_predictor(query_embeddings, instance_features)
        return masks
class DotProductScoring(nn.Module):
    """Scores each query against a pooled prompt embedding via a scaled dot product."""

    def __init__(self, d_model=256, device=None, dtype=None, operations=None):
        super().__init__()
        self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
        # Standard 1/sqrt(d) attention-style scaling.
        self.scale = 1.0 / (d_model ** 0.5)

    def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
        """Return [B, Q] logits, clamped to [-12, 12].

        prompt_mask (True = valid token) selects which prompt tokens
        contribute to the masked mean pooling; otherwise a plain mean is used.
        """
        prompt = self.prompt_mlp(prompt_embeddings)
        if prompt_mask is None:
            pooled = prompt.mean(dim=1)
        else:
            weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
            denom = weight.sum(dim=1).clamp(min=1)
            pooled = (prompt * weight).sum(dim=1) / denom
        hs = self.hs_proj(query_embeddings)
        pooled_proj = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
        logits = torch.matmul(hs, pooled_proj) * self.scale
        return logits.clamp(-12.0, 12.0).squeeze(-1)
class SAM3Detector(nn.Module):
    """Open-vocabulary detector: vision backbone + prompt (text/point/box)
    fusion + DETR-style transformer + dot-product scoring + segmentation head."""

    def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        image_model = kwargs.pop("image_model", "SAM3")
        # Drop kwargs not understood by the vision backbone.
        for k in ("num_heads", "num_head_channels"):
            kwargs.pop(k, None)
        multiplex = image_model == "SAM31"
        # SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
        self.scalp = 0 if multiplex else 1
        self.backbone = nn.ModuleDict({
            "vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
            "language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
        })
        self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
        self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
        self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
        self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)

    def _get_backbone_features(self, images):
        """Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
        bb = self.backbone["vision_backbone"]
        if bb.multiplex:
            all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
        else:
            all_f, all_p, tf, tp = bb(images, need_tracker=True)
        return all_f, all_p, tf, tp

    @staticmethod
    def _run_geo_layer(layer, x, memory, memory_pos):
        # Pre-norm variant of an encoder layer: self-attn, image cross-attn, FFN.
        x = x + layer.self_attn(layer.norm1(x))
        x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
        x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
        return x

    def _detect(self, features, positions, text_embeddings=None, text_mask=None,
                points=None, boxes=None):
        """Shared detection: geometry encoding, transformer, scoring, segmentation.

        Returns (boxes_xyxy [B, Q, 4] normalized, scores [B, Q],
        masks [B, Q, H, W], raw decoder output dict).
        """
        B = features[0].shape[0]
        # Scalp for encoder (use top-level feature), but keep all levels for segmentation head
        seg_features = features
        if self.scalp > 0:
            features = features[:-self.scalp]
            positions = positions[:-self.scalp]
        enc_feat, enc_pos = features[-1], positions[-1]
        _, _, H, W = enc_feat.shape
        # Flatten to [B, H*W, C] token sequences for the transformer.
        img_flat = enc_feat.flatten(2).permute(0, 2, 1)
        pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
        has_prompts = text_embeddings is not None or points is not None or boxes is not None
        if has_prompts:
            geo_enc = self.geometry_encoder
            geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
            # A learned CLS token refined against the image via the geometry encoder stack.
            geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
            for layer in geo_enc.encode:
                geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
            geo_cls = geo_enc.encode_norm(geo_cls)
            # Broadcast text prompts across the batch if needed.
            if text_embeddings is not None and text_embeddings.shape[0] != B:
                text_embeddings = text_embeddings.expand(B, -1, -1)
            if text_mask is not None and text_mask.shape[0] != B:
                text_mask = text_mask.expand(B, -1)
            # Concatenate text + geometry + CLS tokens into one prompt sequence;
            # newly appended tokens are always valid (mask = True).
            parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
            text_embeddings = torch.cat(parts, dim=1)
            n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
            if text_mask is not None:
                text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
            else:
                text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
        memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
        dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
        query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
        if text_embeddings is not None:
            scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
        else:
            # No prompts at all: neutral zero logits.
            scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
        masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
        return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out

    def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
        """Detect. Returns dict with boxes (xyxy; pixel units when orig_size
        is given, else normalized), scores, masks, presence.

        NOTE(review): `threshold` is accepted but unused here — presumably
        applied by the caller.
        """
        features, positions, _, _ = self._get_backbone_features(images)
        if text_embeddings is not None:
            # Project language features into the detector's d_model width.
            text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
        if text_mask is not None:
            text_mask = text_mask.bool()
        boxes_xyxy, scores, masks, dec_out = self._detect(
            features, positions, text_embeddings, text_mask, points, boxes)
        if orig_size is not None:
            # Rescale normalized boxes/masks to the original image size.
            oh, ow = orig_size
            boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
            masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
        return {
            "boxes": boxes_xyxy,
            "scores": scores,
            "masks": masks,
            "presence": dec_out.get("presence"),
        }

    def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
        """Run detection using a pre-computed ViTDet trunk output.
        text_embeddings must already be resized through language_backbone.resizer.
        Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
        """
        bb = self.backbone["vision_backbone"]
        # Rebuild the FPN levels and their positional encodings from the trunk output.
        features = [conv(trunk_out) for conv in bb.convs]
        positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
        if text_mask is not None:
            text_mask = text_mask.bool()
        boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
        return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
class SAM3Model(nn.Module):
    """Top-level SAM3 wrapper combining the detector with a tracker
    (tracker class chosen by the `image_model` kwarg via TRACKER_CLASSES)."""

    def __init__(self, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        image_model = kwargs.get("image_model", "SAM3")
        tracker_cls = TRACKER_CLASSES[image_model]
        self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
        self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)

    def forward(self, images, **kwargs):
        # Plain image detection path; all prompt kwargs forwarded to the detector.
        return self.detector(images, **kwargs)

    def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
        """Interactive segmentation using SAM decoder with point/box/mask prompts.
        Args:
            images: [B, 3, 1008, 1008] preprocessed images
            point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
            box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
            mask_inputs: [B, 1, H, W] coarse mask logits to refine
        Returns:
            [B, 1, image_size, image_size] high-res mask logits
        """
        bb = self.detector.backbone["vision_backbone"]
        # SAM3.1 uses a dedicated interactive mode; SAM3 just requests tracker features.
        if bb.multiplex:
            _, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
        else:
            _, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
        if self.detector.scalp > 0:
            tracker_features = tracker_features[:-self.detector.scalp]
            tracker_positions = tracker_positions[:-self.detector.scalp]
        # Finer levels feed the SAM mask decoder as high-res skip features.
        high_res = list(tracker_features[:-1])
        backbone_feat = tracker_features[-1]
        B, C, H, W = backbone_feat.shape
        # Add no-memory embedding (init frame path)
        no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
        if no_mem is None:
            no_mem = getattr(self.tracker, 'no_mem_embed', None)
        if no_mem is not None:
            feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
            feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
            backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
        # With a single point the prompt is ambiguous -> ask for multiple mask candidates.
        _, high_res_masks, _, _ = self.tracker._forward_sam_heads(
            backbone_features=backbone_feat,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            box_inputs=box_inputs,
            high_res_features=high_res,
            multimask_output=(0 < num_pts <= 1),
        )
        return high_res_masks

    def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
                      new_det_thresh=0.5, max_objects=0, detect_interval=1):
        """Track video with optional per-frame text-prompted detection."""
        bb = self.detector.backbone["vision_backbone"]
        # Per-frame backbone runner; the trunk output is reused for detection.
        def backbone_fn(frame, frame_idx=None):
            trunk_out = bb.trunk(frame)
            if bb.multiplex:
                _, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
            else:
                _, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
            return tf, tp, trunk_out
        detect_fn = None
        if text_prompts:
            # Resize each text prompt once up front, then detect per frame from the cached trunk.
            resizer = self.detector.backbone["language_backbone"]["resizer"]
            resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
            def detect_fn(trunk_out):
                all_scores, all_masks = [], []
                for emb, mask in resized:
                    det = self.detector.forward_from_trunk(trunk_out, emb, mask)
                    all_scores.append(det["scores"])
                    all_masks.append(det["masks"])
                return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
        if hasattr(self.tracker, 'track_video_with_detection'):
            return self.tracker.track_video_with_detection(
                backbone_fn, images, initial_masks, detect_fn,
                new_det_thresh=new_det_thresh, max_objects=max_objects,
                detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
        # SAM3 (non-multiplex) — no detection support, requires initial masks
        if initial_masks is None:
            raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
        return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)

425
comfy/ldm/sam3/sam.py Normal file
View File

@ -0,0 +1,425 @@
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.layers import EmbedND
from comfy.ops import cast_to_input
class MLP(nn.Module):
    """Plain multi-layer perceptron: ReLU between layers, optional final sigmoid."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
        super().__init__()
        sizes = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        linears = [operations.Linear(n_in, n_out, device=device, dtype=dtype)
                   for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        self.layers = nn.ModuleList(linears)
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
class SAMAttention(nn.Module):
    """SAM-style attention with optional channel downsampling and a separate K/V input width."""

    def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        # Attention runs in a (possibly reduced) internal width.
        inner = embedding_dim // downsample_rate
        if kv_in_dim is None:
            kv_in_dim = embedding_dim
        self.q_proj = operations.Linear(embedding_dim, inner, device=device, dtype=dtype)
        self.k_proj = operations.Linear(kv_in_dim, inner, device=device, dtype=dtype)
        self.v_proj = operations.Linear(kv_in_dim, inner, device=device, dtype=dtype)
        self.out_proj = operations.Linear(inner, embedding_dim, device=device, dtype=dtype)

    def forward(self, q, k, v):
        attn_out = optimized_attention(self.q_proj(q), self.k_proj(k), self.v_proj(v),
                                       self.num_heads, low_precision_attention=False)
        return self.out_proj(attn_out)
class TwoWayAttentionBlock(nn.Module):
    """SAM decoder block: token self-attention, token->image cross-attention,
    MLP, then image->token cross-attention (hence "two-way")."""

    def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.skip_first_layer_pe = skip_first_layer_pe
        self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
        # Cross-attentions run at a reduced internal width (downsample_rate).
        self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
        self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)

    def forward(self, queries, keys, query_pe, key_pe):
        """queries/keys: token and image embeddings; query_pe/key_pe: their
        positional encodings, added to attention inputs but not to residuals.
        Returns the updated (queries, keys) pair."""
        # Token self-attention; the first block is built with
        # skip_first_layer_pe=True and also skips the residual here.
        if self.skip_first_layer_pe:
            queries = self.norm1(self.self_attn(queries, queries, queries))
        else:
            q = queries + query_pe
            queries = self.norm1(queries + self.self_attn(q, q, queries))
        # Tokens attend to the image.
        q, k = queries + query_pe, keys + key_pe
        queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
        queries = self.norm3(queries + self.mlp(queries))
        # Image attends back to the tokens.
        q, k = queries + query_pe, keys + key_pe
        keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
        return queries, keys
class TwoWayTransformer(nn.Module):
    """Stack of TwoWayAttentionBlocks plus a final token->image attention (SAM decoder core)."""

    def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
        super().__init__()
        blocks = []
        for i in range(depth):
            blocks.append(TwoWayAttentionBlock(
                embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
                skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations))
        self.layers = nn.ModuleList(blocks)
        self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)

    def forward(self, image_embedding, image_pe, point_embedding):
        """Returns the refined (point queries, image keys) pair."""
        queries = point_embedding
        keys = image_embedding
        for block in self.layers:
            queries, keys = block(queries, keys, point_embedding, image_pe)
        # One last token->image attention with positional encodings on Q/K.
        attended = self.final_attn_token_to_image(queries + point_embedding, keys + image_pe, keys)
        queries = self.norm_final(queries + attended)
        return queries, keys
class PositionEmbeddingRandom(nn.Module):
    """Fourier-feature positional encoding using a fixed random Gaussian projection."""

    def __init__(self, num_pos_feats=64, scale=None):
        super().__init__()
        if not scale:
            scale = 1.0
        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn(2, num_pos_feats))

    def _encode(self, normalized_coords):
        """Map normalized [0, 1] coords to [sin, cos] Fourier features (computed in fp32)."""
        in_dtype = normalized_coords.dtype
        gauss = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
        angles = 2 * math.pi * (2 * normalized_coords.float() - 1) @ gauss
        return torch.cat([angles.sin(), angles.cos()], dim=-1).to(in_dtype)

    def forward(self, size, device=None):
        """Dense encoding for an (h, w) grid -> [1, 2*num_pos_feats, h, w]."""
        h, w = size
        dev = self.positional_encoding_gaussian_matrix.device if device is None else device
        grid = torch.ones((h, w), device=dev, dtype=torch.float32)
        # Pixel-center coordinates normalized to (0, 1).
        x_embed = (grid.cumsum(1) - 0.5) / w
        y_embed = (grid.cumsum(0) - 0.5) / h
        pe = self._encode(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1).unsqueeze(0)

    def forward_with_coords(self, pixel_coords, image_size):
        """Encode arbitrary [B, N, 2] pixel coords given (H, W) image_size."""
        scaled = pixel_coords.clone()
        scaled[:, :, 0] = scaled[:, :, 0] / image_size[1]
        scaled[:, :, 1] = scaled[:, :, 1] / image_size[0]
        return self._encode(scaled)
# ViTDet backbone + FPN neck
def window_partition(x: torch.Tensor, window_size: int):
    """Split a [B, H, W, C] tensor into non-overlapping square windows.

    H and W are zero-padded (bottom/right) up to multiples of window_size.
    Returns ([B * num_windows, window_size, window_size, C], (Hp, Wp)) where
    (Hp, Wp) is the padded size needed by window_unpartition.
    """
    batch, height, width, channels = x.shape
    pad_h = -height % window_size
    pad_w = -width % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    padded_h, padded_w = height + pad_h, width + pad_w
    grid = x.view(batch, padded_h // window_size, window_size,
                  padded_w // window_size, window_size, channels)
    windows = grid.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
    """Inverse of window_partition: reassemble windows into [B, H, W, C].

    pad_hw: the padded (Hp, Wp) returned by window_partition.
    hw: the original (H, W); any padding is cropped off.
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image
    x = windows.view(batch, padded_h // window_size, padded_w // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(batch, padded_h, padded_w, -1)
    if padded_h > height or padded_w > width:
        x = x[:, :height, :width, :].contiguous()
    return x
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
    """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2].

    Positions enumerate the grid row-major (x fastest); scale_pos rescales the
    coordinates (used to match a pretraining grid size).
    """
    idx = torch.arange(end_x * end_y, dtype=torch.float32)
    xs = (idx % end_x) * scale_pos
    ys = torch.div(idx, end_x, rounding_mode="floor") * scale_pos
    coords = torch.stack([xs, ys], dim=-1).unsqueeze(0)
    embedder = EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])
    return embedder(coords)
class _ViTMLP(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
super().__init__()
hidden = int(dim * mlp_ratio)
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class Attention(nn.Module):
    """ViTDet multi-head attention with a fused QKV projection and optional 2D RoPE."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.use_rope = use_rope
        # Single projection producing Q, K and V in one matmul.
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
        self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)

    def forward(self, x, freqs_cis=None):
        B, N, _ = x.shape
        # [B, N, 3*C] -> [3, B, heads, N, head_dim]
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]
        if self.use_rope and freqs_cis is not None:
            q, k = apply_rope(q, k, freqs_cis)
        out = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False)
        return self.proj(out)
class Block(nn.Module):
    """ViTDet transformer block: (optionally windowed) attention + MLP, both residual."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
        super().__init__()
        # window_size == 0 means full global attention over the token grid.
        self.window_size = window_size
        self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
        self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
        self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
        self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)

    def forward(self, x, freqs_cis=None):
        """x: [B, H, W, C] token grid; freqs_cis: optional RoPE frequencies."""
        residual = x
        x = self.norm1(x)
        if self.window_size > 0:
            # Windowed attention: partition into [num_windows, ws*ws, C], attend, reassemble.
            H, W = x.shape[1], x.shape[2]
            x, padded_hw = window_partition(x, self.window_size)
            tokens_per_window = self.window_size * self.window_size
            x = self.attn(x.view(x.shape[0], tokens_per_window, -1), freqs_cis=freqs_cis)
            x = window_unpartition(x.view(-1, self.window_size, self.window_size, x.shape[-1]),
                                   self.window_size, padded_hw, (H, W))
        else:
            # Global attention over the full H*W token sequence.
            B, H, W, C = x.shape
            x = self.attn(x.view(B, H * W, C), freqs_cis=freqs_cis).view(B, H, W, C)
        x = residual + x
        return x + self.mlp(self.norm2(x))
class PatchEmbed(nn.Module):
    """Non-overlapping patch embedding: a strided conv from image to patch tokens."""

    def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
        super().__init__()
        # kernel == stride -> each patch is embedded independently; no bias term.
        self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                                      stride=patch_size, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        """[B, C_in, H, W] -> [B, embed_dim, H // patch_size, W // patch_size]."""
        return self.proj(x)
class ViTDet(nn.Module):
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.global_att_blocks = set(global_att_blocks)
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
grid_size = img_size // patch_size
pretrain_grid = pretrain_img_size // patch_size
self.blocks = nn.ModuleList()
for i in range(depth):
is_global = i in self.global_att_blocks
self.blocks.append(Block(
embed_dim, num_heads, mlp_ratio, qkv_bias,
window_size=0 if is_global else window_size,
use_rope=use_rope,
device=device, dtype=dtype, operations=operations,
))
if use_rope:
rope_scale = pretrain_grid / grid_size
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
else:
self.freqs_cis = None
self.freqs_cis_window = None
def _get_pos_embed(self, num_tokens):
pos = self.pos_embed
if pos.shape[1] == num_tokens:
return pos
cls_pos = pos[:, :1]
spatial_pos = pos[:, 1:]
old_size = int(math.sqrt(spatial_pos.shape[1]))
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
tiles_h = new_size // old_size + 1
tiles_w = new_size // old_size + 1
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
return torch.cat([cls_pos, tiled], dim=1)
def forward(self, x):
    """Run the trunk on an NCHW image; returns NCHW patch features."""
    x = self.patch_embed(x)
    B, C, Hp, Wp = x.shape
    num_spatial = Hp * Wp
    tokens = x.permute(0, 2, 3, 1).reshape(B, num_spatial, C)
    # Add positional embeddings, skipping the cls slot at index 0.
    pos_table = cast_to_input(self._get_pos_embed(num_spatial + 1), tokens)
    tokens = tokens + pos_table[:, 1:num_spatial + 1]
    h = self.ln_pre(tokens.view(B, Hp, Wp, C))
    rope_global = self.freqs_cis
    rope_window = self.freqs_cis_window
    if rope_global is not None:
        rope_global = cast_to_input(rope_global, h)
    if rope_window is not None:
        rope_window = cast_to_input(rope_window, h)
    for blk in self.blocks:
        h = blk(h, freqs_cis=rope_window if blk.window_size > 0 else rope_global)
    return h.permute(0, 3, 1, 2)
class FPNScaleConv(nn.Module):
    """One FPN level: rescale trunk features by `scale`, then project.

    scale 4.0 -> two stride-2 transposed convs (GELU between them)
    scale 2.0 -> one stride-2 transposed conv
    scale 1.0 -> identity
    scale 0.5 -> 2x2 max-pool
    followed by a 1x1 projection to `out_dim` and a 3x3 refinement conv.
    """

    def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
        super().__init__()
        if scale == 4.0:
            self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
            self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
            proj_in = in_dim // 4
        elif scale == 2.0:
            self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
            proj_in = in_dim // 2
        elif scale == 1.0:
            proj_in = in_dim
        elif scale == 0.5:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            proj_in = in_dim
        else:
            # Previously an unsupported scale fell through and crashed later
            # with an UnboundLocalError on proj_in; fail fast with a clear error.
            raise ValueError("FPNScaleConv: unsupported scale {} (expected 4.0, 2.0, 1.0 or 0.5)".format(scale))
        self.scale = scale
        self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
        self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)

    def forward(self, x):
        if self.scale == 4.0:
            x = F.gelu(self.dconv_2x2_0(x))
            x = self.dconv_2x2_1(x)
        elif self.scale == 2.0:
            x = self.dconv_2x2(x)
        elif self.scale == 0.5:
            x = self.pool(x)
        # scale == 1.0 passes through unchanged
        x = self.conv_1x1(x)
        return self.conv_3x3(x)
class PositionEmbeddingSine(nn.Module):
    """2D sinusoidal position encoding (DETR-style) with result caching."""

    def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
        super().__init__()
        assert num_pos_feats % 2 == 0
        self.half_dim = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale
        self._cache = {}

    def _freqs(self, device):
        # Frequency ladder shared by all encoders: consecutive channel pairs
        # share one frequency (// 2) so sin/cos interleave per frequency.
        idx = torch.arange(self.half_dim, dtype=torch.float32, device=device)
        return self.temperature ** (2 * (idx // 2) / self.half_dim)

    def _sincos(self, vals):
        """Encode 1D values to interleaved sin/cos features."""
        raw = vals[..., None] * self.scale / self._freqs(vals.device)
        sin_part = raw[..., 0::2].sin()
        cos_part = raw[..., 1::2].cos()
        return torch.stack((sin_part, cos_part), dim=-1).flatten(-2)

    def _encode_xy(self, x, y):
        """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
        dim_t = self._freqs(x.device)
        encoded = []
        for coord in (x, y):
            raw = coord[:, None] * self.scale / dim_t
            encoded.append(torch.stack((raw[:, 0::2].sin(), raw[:, 1::2].cos()), dim=2).flatten(1))
        return encoded[0], encoded[1]

    def encode_boxes(self, cx, cy, w, h):
        """Encode box center + size to [N, d_model+2] features."""
        pos_x, pos_y = self._encode_xy(cx, cy)
        return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

    def forward(self, x):
        B, _, H, W = x.shape
        key = (H, W, x.device)
        cached = self._cache.get(key)
        if cached is None:
            ys = torch.arange(H, dtype=torch.float32, device=x.device)
            xs = torch.arange(W, dtype=torch.float32, device=x.device)
            if self.normalize:
                ys = ys / (H - 1 + 1e-6)
                xs = xs / (W - 1 + 1e-6)
            yy, xx = torch.meshgrid(ys, xs, indexing="ij")
            cached = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
            self._cache[key] = cached
        return cached.expand(B, -1, -1, -1)
class SAM3VisionBackbone(nn.Module):
    """ViTDet trunk + FPN heads producing detector and (optional) tracker pyramids."""

    def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
        self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
        self.multiplex = multiplex
        common = dict(device=device, dtype=dtype, operations=operations)

        def make_fpn(scales):
            return nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **common) for s in scales])

        if multiplex:
            # Separate tracker FPNs per tracking mode (propagation / interactive).
            self.convs = make_fpn([4.0, 2.0, 1.0])
            self.propagation_convs = make_fpn([4.0, 2.0, 1.0])
            self.interactive_convs = make_fpn([4.0, 2.0, 1.0])
        else:
            self.convs = make_fpn([4.0, 2.0, 1.0, 0.5])
            self.sam2_convs = make_fpn([4.0, 2.0, 1.0, 0.5])

    def _run_fpn(self, convs, trunk_out):
        # Project trunk features at every level and pair with position encodings.
        feats = [conv(trunk_out) for conv in convs]
        poses = [cast_to_input(self.position_encoding(f), f) for f in feats]
        return feats, poses

    def _tracker_fpn(self, tracker_mode):
        if not self.multiplex:
            return self.sam2_convs
        return self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs

    def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
        trunk_out = self.trunk(images) if cached_trunk is None else cached_trunk
        if tracker_only:
            # Skip detector FPN when only tracker features are needed (video tracking)
            t_feats, t_pos = self._run_fpn(self._tracker_fpn(tracker_mode), trunk_out)
            return None, None, t_feats, t_pos
        feats, poses = self._run_fpn(self.convs, trunk_out)
        if self.multiplex:
            if tracker_mode == "propagation":
                tracker_convs = self.propagation_convs
            elif tracker_mode == "interactive":
                tracker_convs = self.interactive_convs
            else:
                return feats, poses, None, None
        elif need_tracker:
            tracker_convs = self.sam2_convs
        else:
            return feats, poses, None, None
        t_feats, t_pos = self._run_fpn(tracker_convs, trunk_out)
        return feats, poses, t_feats, t_pos

1785
comfy/ldm/sam3/tracker.py Normal file

File diff suppressed because it is too large Load Diff

View File

View File

@ -0,0 +1,226 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
from comfy.ldm.modules.attention import optimized_attention
class ZeroSFT(nn.Module):
    """SPADE-style zero-initialized SFT modulation adapter (SUPIR).

    Projects the control signal `c` to per-channel (gamma, beta), applies
    group-normalized modulation to `h`, and blends the modulated result with
    the unmodulated path via `control_scale`.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()
        kernel = 3
        pad = kernel // 2
        hidden = 128
        self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)
        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, hidden, kernel_size=kernel, padding=pad, dtype=dtype, device=device),
            nn.SiLU(),
        )
        self.zero_mul = operations.Conv2d(hidden, norm_nc + concat_channels, kernel_size=kernel, padding=pad, dtype=dtype, device=device)
        self.zero_add = operations.Conv2d(hidden, norm_nc + concat_channels, kernel_size=kernel, padding=pad, dtype=dtype, device=device)
        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
        self.pre_concat = concat_channels != 0

    def forward(self, c, h, h_ori=None, control_scale=1):
        concat_first = h_ori is not None and self.pre_concat
        h_raw = torch.cat([h_ori, h], dim=1) if concat_first else h
        h = h + self.zero_conv(c)
        if concat_first:
            h = torch.cat([h_ori, h], dim=1)
        shared = self.mlp_shared(c)
        gamma = self.zero_mul(shared)
        beta = self.zero_add(shared)
        h = self.param_free_norm(h)
        # h * (1 + gamma) + beta, fused via addcmul
        h = torch.addcmul(h + beta, h, gamma)
        if h_ori is not None and not self.pre_concat:
            h = torch.cat([h_ori, h], dim=1)
        return torch.lerp(h_raw, h, control_scale)
class _CrossAttnInner(nn.Module):
    """Inner cross-attention module matching the state_dict layout of the original CrossAttention."""

    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_q = operations.Linear(query_dim, inner, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner, bias=False, dtype=dtype, device=device)
        # Wrapped in Sequential so keys read to_out.0.* like the original layout.
        self.to_out = nn.Sequential(
            operations.Linear(inner, query_dim, dtype=dtype, device=device),
        )

    def forward(self, x, context):
        attended = optimized_attention(self.to_q(x), self.to_k(context), self.to_v(context), self.heads)
        return self.to_out(attended)
class ZeroCrossAttn(nn.Module):
    """Cross-attention adapter: adds attended context onto x, scaled by control_scale."""

    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        head_dim = 64
        self.attn = _CrossAttnInner(query_dim, context_dim, query_dim // head_dim, head_dim, dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
        self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)

    def forward(self, context, x, control_scale=1):
        b, c, h, w = x.shape
        queries = self.norm1(x).flatten(2).transpose(1, 2)
        keys = self.norm2(context).flatten(2).transpose(1, 2)
        attended = self.attn(queries, keys).transpose(1, 2).unflatten(2, (h, w))
        return x + attended * control_scale
class GLVControl(nn.Module):
    """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only).

    Mirrors the SDXL UNet's input/middle-block layout so SUPIR checkpoint
    weights load by key, but stops after the middle block: forward() returns
    the list of per-block hidden states, which the SUPIR adapters consume.
    """
    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        num_head_channels=64,
        transformer_depth=(1, 2, 10),
        context_dim=2048,
        adm_in_channels=2816,
        use_linear_in_transformer=True,
        use_checkpoint=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
        )
        # Nested Sequential so state_dict keys read label_emb.0.0.* / label_emb.0.2.*,
        # matching the original checkpoint layout.
        self.label_emb = nn.Sequential(
            nn.Sequential(
                operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
            )
        )
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
            )
        ])
        ch = model_channels
        ds = 1  # current downsample factor; selects which levels get attention
        for level, mult in enumerate(channel_mult):
            for nr in range(num_res_blocks):
                layers = [
                    # third positional arg (0): dropout rate
                    ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
                             dtype=dtype, device=device, operations=operations)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads = ch // num_head_channels
                    layers.append(
                        SpatialTransformer(ch, num_heads, num_head_channels,
                                           depth=transformer_depth[level], context_dim=context_dim,
                                           use_linear=use_linear_in_transformer,
                                           use_checkpoint=use_checkpoint,
                                           dtype=dtype, device=device, operations=operations)
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
            if level != len(channel_mult) - 1:
                # Downsample between levels (second positional arg True —
                # presumably use_conv; matches openaimodel.Downsample signature)
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
                    )
                )
                ds *= 2
        num_heads = ch // num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
            SpatialTransformer(ch, num_heads, num_head_channels,
                               depth=transformer_depth[-1], context_dim=context_dim,
                               use_linear=use_linear_in_transformer,
                               use_checkpoint=use_checkpoint,
                               dtype=dtype, device=device, operations=operations),
            ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
        )
        # Projects the guidance latent into feature space; its output is added
        # onto the first input block's result in forward().
        self.input_hint_block = TimestepEmbedSequential(
            operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
        )

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Encode control features.

        Args:
            x: guidance latent (the hint), fed through input_hint_block.
            timesteps: diffusion timesteps for the time embedding.
            xt: noisy latent that flows through the truncated UNet.
            context: cross-attention context for the SpatialTransformers.
            y: ADM conditioning vector (no None check — required here).

        Returns:
            List of hidden states: one per input block, plus the middle block.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb) + self.label_emb(y)
        guided_hint = self.input_hint_block(x, emb, context)
        hs = []
        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                # The hint is injected once, after the first input block only.
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs
class SUPIR(nn.Module):
    """
    SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
    State dict keys match the original SUPIR checkpoint layout:
    control_model.* -> GLVControl
    project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
    """
    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)
        project_channel_scale = 2
        # Per-adapter channel counts: control-feature channels the adapter reads,
        # projected label channels (scaled x2), and channels concatenated from
        # the skip path (0 disables pre-concat on the last adapter).
        cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
        project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
        # Descending order: inserting at the higher index first keeps the
        # lower target position valid for the second insert.
        cross_attn_insert_idx = [6, 3]
        self.project_modules = nn.ModuleList()
        for i in range(len(cond_output_channels)):
            self.project_modules.append(ZeroSFT(
                project_channels[i], cond_output_channels[i],
                concat_channels=concat_channels[i],
                dtype=dtype, device=device, operations=operations,
            ))
        # Interleave ZeroCrossAttn adapters; channel args index the original
        # (pre-insert) lists, not the growing module list.
        for i in cross_attn_insert_idx:
            self.project_modules.insert(i, ZeroCrossAttn(
                cond_output_channels[i], concat_channels[i],
                dtype=dtype, device=device, operations=operations,
            ))

View File

@ -0,0 +1,103 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
class SUPIRPatch:
    """
    Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
    Runs GLVControl lazily on first patch invocation per step, applies adapters through
    middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
    """
    # NOTE(review): presumably the sampler's maximum sigma (SDXL default),
    # used to normalize sigma for strength interpolation — confirm.
    SIGMA_MAX = 14.6146

    def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
        self.model_patch = model_patch  # CoreModelPatcher wrapping GLVControl
        self.project_modules = project_modules  # nn.ModuleList of ZeroSFT/ZeroCrossAttn
        self.hint_latent = hint_latent  # encoded LQ image latent
        self.strength_start = strength_start  # control strength at sigma >= SIGMA_MAX
        self.strength_end = strength_end  # control strength as sigma -> 0
        self.cached_features = None  # GLVControl outputs, cached for one step
        self.adapter_idx = 0  # walks project_modules from the end backwards
        self.control_idx = 0  # walks cached_features from the end backwards
        self.current_control_idx = 0  # control index snapshot used by pre_upsample
        self.active = True  # False when the current step's control scale is 0

    def _ensure_features(self, kwargs):
        """Run GLVControl on first call per step, cache results."""
        if self.cached_features is not None:
            return
        x = kwargs["x"]
        b = x.shape[0]
        hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
        if hint.shape[0] != b:
            # Broadcast a single hint, or tile a multi-hint batch up to b rows.
            hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
        self.cached_features = self.model_patch.model.control_model(
            hint, kwargs["timesteps"], x,
            kwargs["context"], kwargs["y"]
        )
        # Adapters/features are consumed deepest-first (middle block outward).
        self.adapter_idx = len(self.project_modules) - 1
        self.control_idx = len(self.cached_features) - 1

    def _get_control_scale(self, kwargs):
        # Linear interpolation between strength_start and strength_end by
        # sigma normalized to [0, 1]; falls back to strength_end when the
        # sampler provides no sigmas.
        if self.strength_start == self.strength_end:
            return self.strength_end
        sigma = kwargs["transformer_options"].get("sigmas")
        if sigma is None:
            return self.strength_end
        s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
        t = min(s / self.SIGMA_MAX, 1.0)
        return t * (self.strength_start - self.strength_end) + self.strength_end

    def middle_after(self, kwargs):
        """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
        self.cached_features = None  # reset from previous step
        self.current_scale = self._get_control_scale(kwargs)
        self.active = self.current_scale > 0
        if not self.active:
            return {"h": kwargs["h"]}
        self._ensure_features(kwargs)
        h = kwargs["h"]
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return {"h": h}

    def output_block(self, h, hsp, transformer_options):
        """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
        if not self.active:
            return h, hsp
        # Remember which control feature this level used, for pre_upsample.
        self.current_control_idx = self.control_idx
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return h, None

    def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
        """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
        block_type, _ = transformer_options["block"]
        if block_type == "output" and self.active and self.cached_features is not None:
            # Reuses the control feature snapshotted by output_block; only
            # adapter_idx advances here.
            x = self.project_modules[self.adapter_idx](
                self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
            )
            self.adapter_idx -= 1
        return layer(x, output_shape=output_shape)

    def to(self, device_or_dtype):
        # Device moves invalidate the per-step feature cache; the hint latent
        # follows the model. Dtype-only moves keep the cache.
        if isinstance(device_or_dtype, torch.device):
            self.cached_features = None
        if self.hint_latent is not None:
            self.hint_latent = self.hint_latent.to(device_or_dtype)
        return self

    def models(self):
        # Sub-patchers that model management should load/offload with us.
        return [self.model_patch]

    def register(self, model_patcher):
        """Register all patches on a cloned model patcher."""
        model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
        model_patcher.set_model_output_block_patch(self.output_block)
        model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")

View File

@ -53,6 +53,8 @@ import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15 import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -577,8 +579,8 @@ class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight) self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone())
self.cc_projection.bias.copy_(cc_projection_bias) self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone())
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
@ -1962,3 +1964,18 @@ class Kandinsky5Image(Kandinsky5):
class RT_DETR_v4(BaseModel): class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
class ErnieImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)

View File

@ -713,6 +713,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
dit_config = {}
dit_config["image_model"] = "ernie"
return dit_config
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "SAM3"
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
dit_config["image_model"] = "SAM31"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None
@ -868,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config return model_config
def unet_prefix_from_state_dict(state_dict): def unet_prefix_from_state_dict(state_dict):
# SAM3: detector.* and tracker.* at top level, no common prefix
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
return ""
candidates = ["model.diffusion_model.", #ldm/sgm models candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models "model.model.", #audio models
"net.", #cosmos "net.", #cosmos

View File

@ -1732,6 +1732,21 @@ def supports_mxfp8_compute(device=None):
return True return True
def supports_fp64(device=None):
if is_device_mps(device):
return False
if is_intel_xpu():
return False
if is_directml_enabled():
return False
if is_ixuca():
return False
return True
def extended_fp16_support(): def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older # TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7): if torch_version_numeric < (2, 7):
@ -1786,7 +1801,7 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary() return torch.cuda.memory.memory_summary()
return "" return ""
class InterruptProcessingException(Exception): class InterruptProcessingException(BaseException):
pass pass
interrupt_processing_mutex = threading.RLock() interrupt_processing_mutex = threading.RLock()

View File

@ -506,6 +506,10 @@ class ModelPatcher:
def set_model_noise_refiner_patch(self, patch): def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner") self.set_model_patch(patch, "noise_refiner")
def set_model_middle_block_after_patch(self, patch):
self.set_model_patch(patch, "middle_block_after_patch")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x rope_options["scale_x"] = scale_x
@ -681,9 +685,9 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
weight, set_func, convert_func = get_key_weight(self.model, key) weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.patches: if key not in self.patches and not force_cast:
return weight return weight
inplace_update = self.weight_inplace_update or inplace_update inplace_update = self.weight_inplace_update or inplace_update
@ -691,7 +695,7 @@ class ModelPatcher:
if key not in self.backup and not return_weight: if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
if device_to is not None: if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else: else:
@ -699,9 +703,10 @@ class ModelPatcher:
if convert_func is not None: if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True) temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
if set_func is None: if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if key in self.patches:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight: if return_weight:
return out_weight return out_weight
elif inplace_update: elif inplace_update:
@ -1580,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher):
key = key_param_name_to_key(n, param_key) key = key_param_name_to_key(n, param_key)
if key in self.backup: if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to) self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
weight, _, _ = get_key_weight(self.model, key) weight, _, _ = get_key_weight(self.model, key)
if weight is not None: if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
@ -1605,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher):
m._v = vbar.alloc(v_weight_size) m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size allocated_size += v_weight_size
for param in params:
if param not in ("weight", "bias"):
force_load_param(self, param, device_to)
else: else:
for param in params: for param in params:
key = key_param_name_to_key(n, param) key = key_param_name_to_key(n, param)

View File

@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if param is None: if param is None:
continue continue
p = fn(param) p = fn(param)
if p.is_inference(): if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone() p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():

View File

@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
@ -62,6 +63,7 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15 import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -804,6 +806,24 @@ class VAE:
self.downscale_index_formula = (4, 8, 8) self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = self.first_stage_model.latent_channels
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@ -1235,6 +1255,7 @@ class TEModel(Enum):
QWEN35_4B = 25 QWEN35_4B = 25
QWEN35_9B = 26 QWEN35_9B = 26
QWEN35_27B = 27 QWEN35_27B = 27
MINISTRAL_3_3B = 28
def detect_te_model(sd): def detect_te_model(sd):
@ -1301,6 +1322,8 @@ def detect_te_model(sd):
return TEModel.MISTRAL3_24B return TEModel.MISTRAL3_24B
else: else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2 return TEModel.MISTRAL3_24B_PRUNED_FLUX2
if weight.shape[0] == 3072:
return TEModel.MINISTRAL_3_3B
return TEModel.LLAMA3_8 return TEModel.LLAMA3_8
return None return None
@ -1458,6 +1481,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_06B: elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
elif te_model == TEModel.MINISTRAL_3_3B:
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:

View File

@ -26,6 +26,7 @@ import comfy.text_encoders.z_image
import comfy.text_encoders.anima import comfy.text_encoders.anima
import comfy.text_encoders.ace15 import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -1749,6 +1750,88 @@ class RT_DETR_v4(supported_models_base.BASE):
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return None return None
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
class ErnieImage(supported_models_base.BASE):
    """Checkpoint config for the Ernie image diffusion model.

    Pairs the "ernie" diffusion backbone with a Ministral 3B text encoder and
    the Flux2 latent format.
    """
    unet_config = {
        "image_model": "ernie",
    }

    # Flow sampling settings: sigma multiplier and shift.
    sampling_settings = {
        "multiplier": 1000.0,
        "shift": 3.0,
    }

    memory_usage_factor = 10.0

    unet_extra_config = {}
    latent_format = latent_formats.Flux2

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        # Instantiate the diffusion model wrapper for this config.
        out = model_base.ErnieImage(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        # Detect llama-style text-encoder dtype/quantization from the checkpoint
        # keys under the ministral3_3b transformer prefix, then build the
        # tokenizer/encoder pair for it.
        pref = self.text_encoder_key_prefix[0]
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
class SAM3(supported_models_base.BASE):
    """Checkpoint config for SAM3.

    Splits the combined detector/tracker state dict into model-side and
    CLIP-side tensors and normalizes key names to match the local modules.
    """
    unet_config = {"image_model": "SAM3"}
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    text_encoder_key_prefix = ["detector.backbone.language_backbone."]
    unet_extra_prefix = ""

    def process_clip_state_dict(self, state_dict):
        # Uses the keys stashed by process_unet_state_dict.
        # NOTE(review): assumes process_unet_state_dict already ran on this
        # instance (otherwise the stash is empty) — confirm load order.
        clip_keys = getattr(self, "_clip_stash", {})
        clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
        clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
        # Drop resizer.* tensors; only the transformer text encoder is kept.
        return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}

    def process_unet_state_dict(self, state_dict):
        # Pull the language-backbone tensors out of the main state dict and
        # stash them for process_clip_state_dict.
        self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
        # SAM3.1: remap tracker.model.* -> tracker.*
        for k in list(state_dict.keys()):
            if k.startswith("tracker.model."):
                state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
        # SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
        for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
            state_dict.pop(k)
        # Split fused QKV projections
        for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
            t = state_dict.pop(k)
            base, suffix = k.rsplit(".in_proj_", 1)
            s = ".weight" if suffix == "weight" else ".bias"
            # Fused tensor stacks q, k, v along dim 0 -> each slice is d rows.
            d = t.shape[0] // 3
            state_dict[base + ".q_proj" + s] = t[:d]
            state_dict[base + ".k_proj" + s] = t[d:2*d]
            state_dict[base + ".v_proj" + s] = t[2*d:]
        # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
        for k in list(state_dict.keys()):
            if "sam_mask_decoder.transformer." not in k:
                continue
            new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
            if new_k != k:
                state_dict[new_k] = state_dict.pop(k)
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SAM3(self, device=device)

    def clip_target(self, state_dict={}):
        # Function-scope import; NOTE(review): presumably avoids a circular
        # import at module load time — confirm.
        import comfy.text_encoders.sam3_clip
        return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
class SAM31(SAM3):
    """SAM 3.1 variant; inherits all SAM3 state-dict handling, only the
    detected image_model key differs."""
    unet_config = {"image_model": "SAM31"}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -0,0 +1,38 @@
from .flux import Mistral3Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
class Ministral3_3BTokenizer(Mistral3Tokenizer):
    """Mistral "tekken" tokenizer configured for the Ministral 3B text encoder.

    Identical to Mistral3Tokenizer except the embedding key/size defaults are
    the ministral3_3b values.

    NOTE(review): embedding_size=5120 matches the Mistral3 24B default while
    the Ministral 3B hidden size is 3072 — confirm this default is intended.
    """

    def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
        # __init__ must return None; the previous `return super().__init__(...)`
        # only worked because __init__ itself returns None.
        super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
class ErnieTokenizer(sd1_clip.SD1Tokenizer):
    """Single-tokenizer wrapper exposing the Ministral tekken tokenizer under
    the "ministral3_3b" name.

    NOTE(review): Ministral3_3BTokenizer is defined above but unused; confirm
    that passing Mistral3Tokenizer (mistral3_24b embedding key) is intended.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="ministral3_3b",
            tokenizer=Mistral3Tokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        # Prompt weighting is not supported by this tokenizer; force it off.
        return super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
class Ministral3_3BModel(sd1_clip.SDClipModel):
    """SDClipModel wrapper running the Ministral 3B LLM as a text encoder,
    taking hidden states from the second-to-last layer by default."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        # Empty JSON config: the Ministral3_3BConfig defaults define the model.
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"start": 1, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Ministral3_3B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
class ErnieTEModel(sd1_clip.SD1ClipModel):
    """Single text-encoder container holding one Ministral3_3BModel instance."""

    def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel):
        super().__init__(
            device=device,
            dtype=dtype,
            name=name,
            clip_model=clip_model,
            model_options=model_options,
        )
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return an ErnieTEModel subclass with baked-in dtype / quantization overrides.

    dtype_llama, when given, overrides the dtype passed at construction time;
    llama_quantization_metadata, when given, is injected into model_options.
    """
    class ErnieTEModel_(ErnieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            opts = model_options
            if llama_quantization_metadata is not None:
                # Copy before mutating so the caller's dict is untouched.
                opts = dict(opts)
                opts["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=effective_dtype, model_options=opts)
    return ErnieTEModel_

View File

@ -116,9 +116,9 @@ class MistralTokenizerClass:
return LlamaTokenizerFast(**kwargs) return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer): class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None) self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"tekken_model": self.tekken_data} return {"tekken_model": self.tekken_data}

View File

@ -60,6 +60,30 @@ class Mistral3Small24BConfig:
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
@dataclass
class Ministral3_3BConfig:
    """Architecture hyperparameters for the Ministral 3B text encoder.

    Annotated names are dataclass fields and can be overridden via
    Ministral3_3BConfig(**config_dict); the unannotated assignments are plain
    class attributes shared by all instances.
    """
    vocab_size: int = 131072
    hidden_size: int = 3072
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 32
    num_key_value_heads: int = 8  # GQA: 4 query heads share each KV head
    max_position_embeddings: int = 262144
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    # NOTE(review): the attributes below are unannotated, so they are NOT
    # dataclass fields and cannot be set through config_dict — confirm this
    # matches the convention of the other configs in this file.
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
    rope_dims = None
    q_norm = None
    k_norm = None
    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False
    stop_tokens = [2]  # token ids that terminate generation
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
vocab_size: int = 151936 vocab_size: int = 151936
@ -946,6 +970,15 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
    """Ministral 3B causal LM wrapped for use as a ComfyUI text encoder."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        # config_dict entries override the Ministral3_3BConfig field defaults.
        config = Ministral3_3BConfig(**config_dict)
        self.num_layers = config.num_hidden_layers
        # Reuses the generic llama backbone; the config selects Ministral dims.
        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module): class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()

View File

@ -0,0 +1,97 @@
import re
from comfy import sd1_clip
# Text-model config for SAM3's CLIP-L-style language backbone:
# 24 transformer layers, width 1024, 32-token context, standard CLIP vocab.
SAM3_CLIP_CONFIG = {
    "architectures": ["CLIPTextModel"],
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "max_position_embeddings": 32,
    "projection_dim": 512,
    "vocab_size": 49408,
    "layer_norm_eps": 1e-5,
    "eos_token_id": 49407,
}
class SAM3ClipModel(sd1_clip.SDClipModel):
    """CLIP-L-style text model for SAM3 (see SAM3_CLIP_CONFIG), returning the
    last hidden layer plus attention masks, capped at 32 tokens."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        special = {"start": 49406, "end": 49407, "pad": 0}
        super().__init__(
            device=device,
            dtype=dtype,
            max_length=32,
            layer="last",
            textmodel_json_config=SAM3_CLIP_CONFIG,
            special_tokens=special,
            return_projected_pooled=False,
            return_attention_masks=True,
            enable_attention_masks=True,
            model_options=model_options,
        )
class SAM3Tokenizer(sd1_clip.SDTokenizer):
    """CLIP tokenizer capped at 32 tokens; prompt weighting is disabled."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            max_length=32,
            pad_with_end=False,
            pad_token=0,
            embedding_directory=embedding_directory,
            embedding_size=1024,
            embedding_key="sam3_clip",
            tokenizer_data=tokenizer_data,
        )
        # Set after super().__init__ so the base class cannot clobber it.
        self.disable_weights = True
def _parse_prompts(text):
"""Split comma-separated prompts with optional :N max detections per category"""
text = text.replace("(", "").replace(")", "")
parts = [p.strip() for p in text.split(",") if p.strip()]
result = []
for part in parts:
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
if m:
text_part = m.group(1).strip()
val = m.group(2)
max_det = max(1, round(float(val)))
result.append((text_part, max_det))
else:
result.append((part, 1))
return result
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
    """Tokenizer wrapper adding SAM3's multi-prompt "a, b:3" syntax on top of
    the plain CLIP tokenizer."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        parsed = _parse_prompts(text)
        # Fast path: one prompt (or none) with default max detections.
        if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
            return super().tokenize_with_weights(text, return_word_ids, **kwargs)
        # Tokenize each prompt part separately, store per-part batches and metadata.
        # self.clip is the attribute name of the inner tokenizer instance
        # (set by the SD1Tokenizer base — NOTE(review): confirm).
        inner = getattr(self, self.clip)
        per_prompt = []
        for prompt_text, max_det in parsed:
            batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
            per_prompt.append((batches, max_det))
        # Main output uses first prompt's tokens (for compatibility)
        out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
        return out
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
    """CLIP model wrapper that encodes each SAM3 prompt part separately and
    packs the per-prompt conds plus max-detection counts into the extra dict."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")

    def encode_token_weights(self, token_weight_pairs):
        # pop() mutates the incoming dict: the multi-prompt metadata is
        # consumed here and never forwarded to the base encoder.
        per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
        if per_prompt is None:
            # Single-prompt path: plain CLIP encode.
            return super().encode_token_weights(token_weight_pairs)
        # Encode each prompt separately, pack into extra dict.
        # self.clip is the inner model attribute name (set by the base class).
        inner = getattr(self, self.clip)
        multi_cond = []
        first_pooled = None
        for batches, max_det in per_prompt:
            out = inner.encode_token_weights(batches)
            cond, pooled = out[0], out[1]
            # Older encoders may return only (cond, pooled); extras optional.
            extra = out[2] if len(out) > 2 else {}
            if first_pooled is None:
                first_pooled = pooled
            multi_cond.append({
                "cond": cond,
                "attention_mask": extra.get("attention_mask"),
                "max_detections": max_det,
            })
        # Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
        main = multi_cond[0]
        main_extra = {}
        if main["attention_mask"] is not None:
            main_extra["attention_mask"] = main["attention_mask"]
        main_extra["sam3_multi_cond"] = multi_cond
        return (main["cond"], first_pooled, main_extra)

View File

@ -52,6 +52,26 @@ class TaskImageContent(BaseModel):
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class TaskVideoContentUrl(BaseModel):
    # URL of the reference video.
    url: str = Field(...)


class TaskVideoContent(BaseModel):
    """Video entry in a task-creation content list."""
    type: str = Field("video_url")
    video_url: TaskVideoContentUrl = Field(...)
    role: str = Field("reference_video")
class TaskAudioContentUrl(BaseModel):
    # URL of the reference audio.
    url: str = Field(...)


class TaskAudioContent(BaseModel):
    """Audio entry in a task-creation content list."""
    type: str = Field("audio_url")
    audio_url: TaskAudioContentUrl = Field(...)
    role: str = Field("reference_audio")
class Text2VideoTaskCreationRequest(BaseModel): class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...) model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1) content: list[TaskTextContent] = Field(..., min_length=1)
@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel):
generate_audio: bool | None = Field(...) generate_audio: bool | None = Field(...)
class Seedance2TaskCreationRequest(BaseModel):
    """Request body for creating a Seedance 2.0 generation task."""
    model: str = Field(...)
    # Mixed prompt content: text plus optional image/video/audio references.
    content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
    generate_audio: bool | None = Field(None)
    resolution: str | None = Field(None)
    ratio: str | None = Field(None)
    # Clip length in seconds, 4–15 inclusive.
    duration: int | None = Field(None, ge=4, le=15)
    # Bounded to a non-negative signed 32-bit integer.
    seed: int | None = Field(None, ge=0, le=2147483647)
    watermark: bool | None = Field(None)
class TaskCreationResponse(BaseModel): class TaskCreationResponse(BaseModel):
id: str = Field(...) id: str = Field(...)
@ -77,12 +108,62 @@ class TaskStatusResult(BaseModel):
video_url: str = Field(...) video_url: str = Field(...)
class TaskStatusUsage(BaseModel):
    """Token usage reported with a task's status (defaults to zero)."""
    completion_tokens: int = Field(0)
    total_tokens: int = Field(0)
class TaskStatusResponse(BaseModel): class TaskStatusResponse(BaseModel):
id: str = Field(...) id: str = Field(...)
model: str = Field(...) model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None) error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None) content: TaskStatusResult | None = Field(None)
usage: TaskStatusUsage | None = Field(None)
class GetAssetResponse(BaseModel):
    """Response describing a single uploaded asset and its processing state."""
    id: str = Field(...)
    name: str | None = Field(None)
    url: str | None = Field(None)
    asset_type: str = Field(...)
    group_id: str = Field(...)
    status: str = Field(...)
    # Populated when the asset failed processing.
    error: TaskStatusError | None = Field(None)
class SeedanceCreateVisualValidateSessionResponse(BaseModel):
    """Response for creating a visual-validation session."""
    session_id: str = Field(...)
    # Link to the validation page (h5 = mobile web page).
    h5_link: str = Field(...)
class SeedanceGetVisualValidateSessionResponse(BaseModel):
    """Polling response for a visual-validation session."""
    session_id: str = Field(...)
    status: str = Field(...)
    group_id: str | None = Field(None)
    # Error details, set when validation failed.
    error_code: str | None = Field(None)
    error_message: str | None = Field(None)
class SeedanceCreateAssetRequest(BaseModel):
    """Request body for registering an uploaded file as a Seedance asset."""
    group_id: str = Field(...)
    url: str = Field(...)
    asset_type: str = Field(...)
    name: str | None = Field(None, max_length=64)
    project_name: str | None = Field(None)
class SeedanceCreateAssetResponse(BaseModel):
    """Response containing the id of the newly created asset."""
    asset_id: str = Field(...)
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
# Jobs with video input bill at a lower per-token rate than text/image-only.
SEEDANCE2_PRICE_PER_1K_TOKENS = {
    ("dreamina-seedance-2-0-260128", False): 0.007,
    ("dreamina-seedance-2-0-260128", True): 0.0043,
    ("dreamina-seedance-2-0-fast-260128", False): 0.0056,
    ("dreamina-seedance-2-0-fast-260128", True): 0.0033,
}
RECOMMENDED_PRESETS = [ RECOMMENDED_PRESETS = [
@ -112,6 +193,19 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None), ("Custom", None, None),
] ]
# Seedance 2.0 reference video pixel count limits per model and output resolution.
# Values are total pixels (width * height) of the reference video.
# NOTE(review): the fast model has no "1080p" entry — presumably 1080p output
# is unsupported there; confirm before relying on a KeyError as validation.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
    "dreamina-seedance-2-0-260128": {
        "480p": {"min": 409_600, "max": 927_408},
        "720p": {"min": 409_600, "max": 927_408},
        "1080p": {"min": 409_600, "max": 2_073_600},
    },
    "dreamina-seedance-2-0-fast-260128": {
        "480p": {"min": 409_600, "max": 927_408},
        "720p": {"min": 409_600, "max": 927_408},
    },
}
# The time in this dictionary are given for 10 seconds duration. # The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = { VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": { "seedance-1-0-lite-t2v-250428": {

File diff suppressed because it is too large Load Diff

View File

@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
( (
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration"); $dur := $lookup(widgets, "model.duration");
$refs := inputGroups["model.reference_images"]; $refs := $lookup(inputGroups, "model.reference_images");
$rate := $res = "720p" ? 0.07 : 0.05; $rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43; $price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price} {"type":"usd","usd": $price}

View File

@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
obj_result = None
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
obj_result.obj, obj_result.obj if obj_result else None,
obj_result.texture, obj_result.texture if obj_result else None,
) )
@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
obj_result.obj, None,
obj_result.texture, None,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), None,
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), None,
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), None,
) )

View File

@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode):
f"must equal the global duration ({duration}s)." f"must equal the global duration ({duration}s)."
) )
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
mode="pro" if resolution == "1080p" else "std", mode=mode,
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
shot_type="customize" if multi_shot else None, shot_type="customize" if multi_shot else None,
@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True, optional=True,
tooltip="Up to 6 additional reference images.", tooltip="Up to 6 additional reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images", "reference_images",
tooltip="Up to 7 reference images.", tooltip="Up to 7 reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode):
image_list: list[OmniParamImage] = [] image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=["16:9", "9:16", "1:1"], options=["16:9", "9:16", "1:1"],
@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode):
start_frame: Input.Image | None = None, start_frame: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
_ = seed _ = seed
mode = "pro" if model["resolution"] == "1080p" else "std" if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
custom_multi_shot = False custom_multi_shot = False
if multi_shot["multi_shot"] == "disabled": if multi_shot["multi_shot"] == "disabled":
shot_type = None shot_type = None
@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path=poll_path), ApiEndpoint(path=poll_path),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
], ],
), ),
], ],
@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
image=image_url, image=image_url,
image_tail=image_tail_url, image_tail=image_tail_url,
prompt=prompt, prompt=prompt,
mode="pro" if model["resolution"] == "1080p" else "std", mode=mode,
duration=str(duration), duration=str(duration),
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
), ),
@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -357,13 +357,17 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode): class OpenAIGPTImage1(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="OpenAIGPTImage1", node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1.5", display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI", category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.", description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[ inputs=[
@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"size", "size",
default="auto", default="auto",
options=["auto", "1024x1024", "1024x1536", "1536x1024"], options=[
"auto",
"1024x1024",
"1024x1536",
"1536x1024",
"2048x2048",
"2048x1152",
"1152x2048",
"3840x2160",
"2160x3840",
],
tooltip="Image size", tooltip="Image size",
optional=True, optional=True,
), ),
@ -427,8 +441,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["gpt-image-1", "gpt-image-1.5"], options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"],
default="gpt-image-1.5", default="gpt-image-2",
optional=True, optional=True,
), ),
], ],
@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
expr=""" expr="""
( (
$ranges := { $ranges := {
"low": [0.011, 0.02], "gpt-image-1": {
"medium": [0.046, 0.07], "low": [0.011, 0.02],
"high": [0.167, 0.3] "medium": [0.042, 0.07],
"high": [0.167, 0.25]
},
"gpt-image-1.5": {
"low": [0.009, 0.02],
"medium": [0.034, 0.062],
"high": [0.133, 0.22]
},
"gpt-image-2": {
"low": [0.0048, 0.012],
"medium": [0.041, 0.112],
"high": [0.165, 0.43]
}
}; };
$range := $lookup($ranges, widgets.quality); $range := $lookup($lookup($ranges, widgets.model), widgets.quality);
$n := widgets.n; $nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
($n = 1) ($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
: { : {
"type":"range_usd", "type":"range_usd",
"min_usd": $range[0], "min_usd": $range[0] * $n,
"max_usd": $range[1], "max_usd": $range[1] * $n,
"format": { "suffix": " x " & $string($n) & "/Run" } "format": { "suffix": "/Run", "approximate": true }
} }
) )
""", """,
@ -483,10 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode):
if mask is not None and image is None: if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image") raise ValueError("Cannot use a mask without an input image")
if model in ("gpt-image-1", "gpt-image-1.5"):
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
if model == "gpt-image-1": if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1 price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5": elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5 price_extractor = calculate_tokens_price_image_1_5
elif model == "gpt-image-2":
price_extractor = calculate_tokens_price_image_2_0
if background == "transparent":
raise ValueError("Transparent background is not supported for GPT Image 2 model")
else: else:
raise ValueError(f"Unknown model: {model}") raise ValueError(f"Unknown model: {model}")

View File

@ -17,6 +17,44 @@ from comfy_api_nodes.util import (
) )
from comfy_extras.nodes_images import SVG from comfy_extras.nodes_images import SVG
_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"]
def _arrow_sampling_inputs():
"""Shared sampling inputs for all Arrow model variants."""
return [
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
]
class QuiverTextToSVGNode(IO.ComfyNode): class QuiverTextToSVGNode(IO.ComfyNode):
@classmethod @classmethod
@ -39,6 +77,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
default="", default="",
tooltip="Additional style or formatting guidance.", tooltip="Additional style or formatting guidance.",
optional=True, optional=True,
advanced=True,
), ),
IO.Autogrow.Input( IO.Autogrow.Input(
"reference_images", "reference_images",
@ -53,43 +92,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
), ),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS],
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG generation.", tooltip="Model to use for SVG generation.",
), ),
IO.Int.Input( IO.Int.Input(
@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""", depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model, "max")
? {"type":"usd","usd":0.3575}
: $contains(widgets.model, "preview")
? {"type":"usd","usd":0.429}
: {"type":"usd","usd":0.286}
)
""",
), ),
) )
@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode):
"auto_crop", "auto_crop",
default=False, default=False,
tooltip="Automatically crop to the dominant subject.", tooltip="Automatically crop to the dominant subject.",
advanced=True,
), ),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"arrow-preview", m,
[ [
IO.Int.Input( IO.Int.Input(
"target_size", "target_size",
@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode):
min=128, min=128,
max=4096, max=4096,
tooltip="Square resize target in pixels.", tooltip="Square resize target in pixels.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True, advanced=True,
), ),
*_arrow_sampling_inputs(),
], ],
), )
for m in _ARROW_MODELS
], ],
tooltip="Model to use for SVG vectorization.", tooltip="Model to use for SVG vectorization.",
), ),
@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""", depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model, "max")
? {"type":"usd","usd":0.3575}
: $contains(widgets.model, "preview")
? {"type":"usd","usd":0.429}
: {"type":"usd","usd":0.286}
)
""",
), ),
) )

View File

@ -0,0 +1,287 @@
import base64
import json
import logging
import time
from urllib.parse import urljoin
import aiohttp
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
upload_video_to_comfyapi,
validate_string,
)
from comfy_api_nodes.util._helpers import (
default_base_url,
get_auth_header,
get_node_id,
is_processing_interrupted,
)
from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted
from server import PromptServer
logger = logging.getLogger(__name__)
class SoniloVideoToMusic(IO.ComfyNode):
"""Generate music from video using Sonilo's AI model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music",
category="api node/audio/Sonilo",
description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.",
inputs=[
IO.Video.Input(
"video",
tooltip="Input video to generate music from. Maximum duration: 6 minutes.",
),
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Optional text prompt to guide music generation. "
"Leave empty for best quality - the model will fully analyze the video content.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
"service but kept for graph consistency.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}',
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
video_url = await upload_video_to_comfyapi(cls, video, max_duration=360)
form = aiohttp.FormData()
form.add_field("video_url", video_url)
if prompt.strip():
form.add_field("prompt", prompt.strip())
audio_bytes = await _stream_sonilo_music(
cls,
ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST"),
form,
)
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
class SoniloTextToMusic(IO.ComfyNode):
"""Generate music from a text prompt using Sonilo's AI model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="api node/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt describing the music to generate.",
),
IO.Int.Input(
"duration",
default=0,
min=0,
max=360,
tooltip="Target duration in seconds. Set to 0 to let the model "
"infer the duration from the prompt. Maximum: 6 minutes.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
"service but kept for graph consistency.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""
(
widgets.duration > 0
? {"type":"usd","usd": 0.005 * widgets.duration}
: {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
duration: int = 0,
seed: int = 0,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
form = aiohttp.FormData()
form.add_field("prompt", prompt)
if duration > 0:
form.add_field("duration", str(duration))
audio_bytes = await _stream_sonilo_music(
cls,
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
form,
)
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
async def _stream_sonilo_music(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
form: aiohttp.FormData,
) -> bytes:
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
headers: dict[str, str] = {}
headers.update(get_auth_header(cls))
headers.update(endpoint.headers)
node_id = get_node_id(cls)
start_ts = time.monotonic()
last_chunk_status_ts = 0.0
audio_streams: dict[int, list[bytes]] = {}
title: str | None = None
timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
PromptServer.instance.send_progress_text("Status: Queued", node_id)
async with session.post(url, data=form, headers=headers) as resp:
if resp.status >= 400:
msg = await _extract_error_message(resp)
raise Exception(f"Sonilo API error ({resp.status}): {msg}")
while True:
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
raw_line = await resp.content.readline()
if not raw_line:
break
line = raw_line.decode("utf-8").strip()
if not line:
continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
logger.warning("Sonilo: skipping malformed NDJSON line")
continue
evt_type = evt.get("type")
if evt_type == "error":
code = evt.get("code", "UNKNOWN")
message = evt.get("message", "Unknown error")
raise Exception(f"Sonilo generation error ({code}): {message}")
if evt_type == "duration":
duration_sec = evt.get("duration_sec")
if duration_sec is not None:
PromptServer.instance.send_progress_text(
f"Status: Generating\nVideo duration: {duration_sec:.1f}s",
node_id,
)
elif evt_type in ("titles", "title"):
# v2m sends a "titles" list, t2m sends a scalar "title"
if evt_type == "titles":
titles = evt.get("titles", [])
if titles:
title = titles[0]
else:
title = evt.get("title") or title
if title:
PromptServer.instance.send_progress_text(
f"Status: Generating\nTitle: {title}",
node_id,
)
elif evt_type == "audio_chunk":
stream_idx = evt.get("stream_index", 0)
chunk_data = base64.b64decode(evt["data"])
if stream_idx not in audio_streams:
audio_streams[stream_idx] = []
audio_streams[stream_idx].append(chunk_data)
now = time.monotonic()
if now - last_chunk_status_ts >= 1.0:
total_chunks = sum(len(chunks) for chunks in audio_streams.values())
elapsed = int(now - start_ts)
status_lines = ["Status: Receiving audio"]
if title:
status_lines.append(f"Title: {title}")
status_lines.append(f"Chunks received: {total_chunks}")
status_lines.append(f"Time elapsed: {elapsed}s")
PromptServer.instance.send_progress_text("\n".join(status_lines), node_id)
last_chunk_status_ts = now
elif evt_type == "complete":
break
if not audio_streams:
raise Exception("Sonilo API returned no audio data.")
PromptServer.instance.send_progress_text("Status: Completed", node_id)
selected_stream = 0 if 0 in audio_streams else min(audio_streams)
return b"".join(audio_streams[selected_stream])
async def _extract_error_message(resp: aiohttp.ClientResponse) -> str:
"""Extract a human-readable error message from an HTTP error response."""
try:
error_body = await resp.json()
detail = error_body.get("detail", {})
if isinstance(detail, dict):
return detail.get("message", str(detail))
return str(detail)
except Exception:
return await resp.text()
class SoniloExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [SoniloVideoToMusic, SoniloTextToMusic]
async def comfy_entrypoint() -> SoniloExtension:
return SoniloExtension()

View File

@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""", expr="""{"type":"usd","usd":0.4}""",
), ),
) )
@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""", expr="""{"type":"usd","usd":0.6}""",
), ),
) )
@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""", expr="""{"type":"usd","usd":0.02}""",
), ),
) )

View File

@ -24,8 +24,9 @@ from comfy_api_nodes.util import (
AVERAGE_DURATION_VIDEO_GEN = 32 AVERAGE_DURATION_VIDEO_GEN = 32
MODELS_MAP = { MODELS_MAP = {
"veo-2.0-generate-001": "veo-2.0-generate-001", "veo-2.0-generate-001": "veo-2.0-generate-001",
"veo-3.1-generate": "veo-3.1-generate-preview", "veo-3.1-generate": "veo-3.1-generate-001",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview", "veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
"veo-3.1-lite": "veo-3.1-lite-generate-001",
"veo-3.0-generate-001": "veo-3.0-generate-001", "veo-3.0-generate-001": "veo-3.0-generate-001",
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
} }
@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
raise Exception("Video generation completed but no video was returned") raise Exception("Video generation completed but no video was returned")
class Veo3VideoGenerationNode(VeoVideoGenerationNode): class Veo3VideoGenerationNode(IO.ComfyNode):
""" """Generates videos from text prompts using Google's Veo 3 API."""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
default="16:9", default="16:9",
tooltip="Aspect ratio of the output video", tooltip="Aspect ratio of the output video",
), ),
IO.Combo.Input(
"resolution",
options=["720p", "1080p", "4k"],
default="720p",
tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
optional=True,
),
IO.String.Input( IO.String.Input(
"negative_prompt", "negative_prompt",
multiline=True, multiline=True,
@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Int.Input( IO.Int.Input(
"duration_seconds", "duration_seconds",
default=8, default=8,
min=8, min=4,
max=8, max=8,
step=1, step=2,
display_mode=IO.NumberDisplay.number, display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", tooltip="Duration of the output video in seconds",
optional=True, optional=True,
), ),
IO.Boolean.Input( IO.Boolean.Input(
@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
options=[ options=[
"veo-3.1-generate", "veo-3.1-generate",
"veo-3.1-fast-generate", "veo-3.1-fast-generate",
"veo-3.1-lite",
"veo-3.0-generate-001", "veo-3.0-generate-001",
"veo-3.0-fast-generate-001", "veo-3.0-fast-generate-001",
], ],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation", tooltip="Veo 3 model to use for video generation",
optional=True, optional=True,
), ),
@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
expr=""" expr="""
( (
$m := widgets.model; $m := widgets.model;
$r := widgets.resolution;
$a := widgets.generate_audio; $a := widgets.generate_audio;
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) $seconds := widgets.duration_seconds;
? {"type":"usd","usd": ($a ? 1.2 : 0.8)} $pps :=
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) $contains($m, "lite")
? {"type":"usd","usd": ($a ? 3.2 : 1.6)} ? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2} : $contains($m, "3.1-fast")
? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
: $contains($m, "3.1-generate")
? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
: $contains($m, "3.0-fast")
? ($a ? 0.15 : 0.10)
: ($a ? 0.40 : 0.20);
{"type":"usd","usd": $pps * $seconds}
) )
""", """,
), ),
) )
@classmethod
async def execute(
cls,
prompt,
aspect_ratio="16:9",
resolution="720p",
negative_prompt="",
duration_seconds=8,
enhance_prompt=True,
person_generation="ALLOW",
seed=0,
image=None,
model="veo-3.0-generate-001",
generate_audio=False,
):
if resolution == "4k" and ("lite" in model or "3.0" in model):
raise Exception("4K resolution is not supported by the veo-3.1-lite or veo-3.0 models.")
model = MODELS_MAP[model]
instances = [{"prompt": prompt}]
if image is not None:
image_base64 = tensor_to_base64_string(image)
if image_base64:
instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
parameters = {
"aspectRatio": aspect_ratio,
"personGeneration": person_generation,
"durationSeconds": duration_seconds,
"enhancePrompt": True,
"generateAudio": generate_audio,
}
if negative_prompt:
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
if "veo-3.1" in model:
parameters["resolution"] = resolution
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=instances,
parameters=parameters,
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(operationName=initial_response.name),
poll_interval=9.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class Veo3FirstLastFrameNode(IO.ComfyNode): class Veo3FirstLastFrameNode(IO.ComfyNode):
@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
default="", default="",
tooltip="Negative text prompt to guide what to avoid in the video", tooltip="Negative text prompt to guide what to avoid in the video",
), ),
IO.Combo.Input("resolution", options=["720p", "1080p"]), IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=["16:9", "9:16"], options=["16:9", "9:16"],
@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
IO.Image.Input("last_frame", tooltip="End frame"), IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["veo-3.1-generate", "veo-3.1-fast-generate"], options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
default="veo-3.1-fast-generate",
), ),
IO.Boolean.Input( IO.Boolean.Input(
"generate_audio", "generate_audio",
@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
expr=""" expr="""
( (
$prices := {
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
};
$m := widgets.model; $m := widgets.model;
$ga := (widgets.generate_audio = "true"); $r := widgets.resolution;
$ga := widgets.generate_audio;
$seconds := widgets.duration; $seconds := widgets.duration;
$modelKey := $pps :=
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : $contains($m, "lite")
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : ? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
""; : $contains($m, "fast")
$audioKey := $ga ? "audio" : "no_audio"; ? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
$modelPrices := $lookup($prices, $modelKey); : ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
$pps := $lookup($modelPrices, $audioKey); {"type":"usd","usd": $pps * $seconds}
($pps != null)
? {"type":"usd","usd": $pps * $seconds}
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
) )
""", """,
), ),
@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
model: str, model: str,
generate_audio: bool, generate_audio: bool,
): ):
if "lite" in model and resolution == "4k":
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
model = MODELS_MAP[model] model = MODELS_MAP[model]
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
data=VeoGenVidPollRequest( data=VeoGenVidPollRequest(
operationName=initial_response.name, operationName=initial_response.name,
), ),
poll_interval=5.0, poll_interval=9.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN, estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
) )

View File

@ -19,6 +19,7 @@ from .conversions import (
image_tensor_pair_to_batch, image_tensor_pair_to_batch,
pil_to_bytesio, pil_to_bytesio,
resize_mask_to_image, resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string, tensor_to_base64_string,
tensor_to_bytesio, tensor_to_bytesio,
tensor_to_pil, tensor_to_pil,
@ -90,6 +91,7 @@ __all__ = [
"image_tensor_pair_to_batch", "image_tensor_pair_to_batch",
"pil_to_bytesio", "pil_to_bytesio",
"resize_mask_to_image", "resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string", "tensor_to_base64_string",
"tensor_to_bytesio", "tensor_to_bytesio",
"tensor_to_pil", "tensor_to_pil",

View File

@ -156,6 +156,7 @@ async def poll_op(
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> M: ) -> M:
raw = await poll_op_raw( raw = await poll_op_raw(
cls, cls,
@ -176,6 +177,7 @@ async def poll_op(
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint, cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout, cancel_timeout=cancel_timeout,
extra_text=extra_text,
) )
if not isinstance(raw, dict): if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@ -260,6 +262,7 @@ async def poll_op_raw(
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
@ -299,6 +302,7 @@ async def poll_op_raw(
price=state.price, price=state.price,
is_queued=state.is_queued, is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed), processing_elapsed_seconds=int(proc_elapsed),
extra_text=extra_text,
) )
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
except Exception as exc: except Exception as exc:
@ -389,6 +393,7 @@ async def poll_op_raw(
price=state.price, price=state.price,
is_queued=False, is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed), processing_elapsed_seconds=int(state.base_processing_elapsed),
extra_text=extra_text,
) )
return resp_json return resp_json
@ -462,6 +467,7 @@ def _display_time_progress(
price: float | None = None, price: float | None = None,
is_queued: bool | None = None, is_queued: bool | None = None,
processing_elapsed_seconds: int | None = None, processing_elapsed_seconds: int | None = None,
extra_text: str | None = None,
) -> None: ) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False: if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -469,7 +475,8 @@ def _display_time_progress(
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else: else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s" time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price) text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
_display_text(node_cls, text, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]: async def _diagnose_connectivity() -> dict[str, bool]:

View File

@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr return img_byte_arr
def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded down to even values (many codecs require divisible-by-2).
"""
pixels = src_w * src_h
if pixels <= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = max(2, int(src_w * scale))
new_h = max(2, int(src_h * scale))
new_w -= new_w % 2
new_h -= new_h % 2
return new_w, new_h
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels.""" """Downscale input image tensor to roughly the specified total pixels.
Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
and is compatible with codecs that require even dimensions (e.g. yuv420p).
"""
samples = image.movedim(-1, 1) samples = image.movedim(-1, 1)
total = int(total_pixels) dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) if dims is None:
if scale_by >= 1:
return image return image
width = round(samples.shape[3] * scale_by) new_w, new_h = dims
height = round(samples.shape[2] * scale_by) return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
"""Downscale input image tensor so the largest dimension is at most max_side pixels.""" """Downscale input image tensor so the largest dimension is at most max_side pixels."""
samples = image.movedim(-1, 1) samples = image.movedim(-1, 1)
height, width = samples.shape[2], samples.shape[3] height, width = samples.shape[2], samples.shape[3]
@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
    """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.

    Returns the original video object untouched when it already fits. Preserves frame
    rate, duration, and audio. Aspect ratio is preserved up to a fraction of a percent
    (even-dim rounding).
    """
    dims = _compute_downscale_dims(*video.get_dimensions(), total_pixels)
    # None means the video is already within budget -- no re-encode needed.
    return video if dims is None else _apply_video_scale(video, dims)
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
    """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
    out_w, out_h = scale_dims
    output_buffer = BytesIO()
    input_container = None
    output_container = None
    try:
        input_source = video.get_stream_source()
        input_container = av.open(input_source, mode="r")
        # Output is always mp4 / h264 / yuv420p at the source frame rate.
        output_container = av.open(output_buffer, mode="w", format="mp4")
        video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
        video_stream.width = out_w
        video_stream.height = out_h
        video_stream.pix_fmt = "yuv420p"
        # Mirror the first audio stream (if any) as AAC with the same rate/layout.
        audio_stream = None
        for stream in input_container.streams:
            if isinstance(stream, av.AudioStream):
                audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
                audio_stream.sample_rate = stream.sample_rate
                audio_stream.layout = stream.layout
                break
        # Decode, scale, and re-encode every video frame.
        for frame in input_container.decode(video=0):
            frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
            for packet in video_stream.encode(frame):
                output_container.mux(packet)
        # Flush the video encoder's buffered packets.
        for packet in video_stream.encode():
            output_container.mux(packet)
        if audio_stream is not None:
            # Rewind the source and copy audio in a second decode pass.
            input_container.seek(0)
            for audio_frame in input_container.decode(audio=0):
                for packet in audio_stream.encode(audio_frame):
                    output_container.mux(packet)
            # Flush the audio encoder.
            for packet in audio_stream.encode():
                output_container.mux(packet)
        output_container.close()
        input_container.close()
        output_buffer.seek(0)
        return InputImpl.VideoFromFile(output_buffer)
    except Exception as e:
        # Best-effort cleanup before surfacing a uniform error to the caller.
        if input_container is not None:
            input_container.close()
        if output_container is not None:
            output_container.close()
        raise RuntimeError(f"Failed to resize video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point: if wav.dtype.is_floating_point:

View File

@ -0,0 +1,258 @@
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
class FilmConv2d(nn.Module):
    """Conv2d with optional LeakyReLU and FILM-style padding."""

    def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
        super().__init__()
        # Even kernels get asymmetric (right/bottom) padding at forward time;
        # odd kernels use the usual symmetric padding inside the conv itself.
        self.even_pad = size % 2 == 0
        pad = size // 2 if size % 2 else 0
        self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=pad, device=device, dtype=dtype)
        self.activation = nn.LeakyReLU(0.2) if activation else None

    def forward(self, x):
        if self.even_pad:
            x = F.pad(x, (0, 1, 0, 1))
        out = self.conv(x)
        return out if self.activation is None else self.activation(out)
def _warp_core(image, flow, grid_x, grid_y):
dtype = image.dtype
H, W = flow.shape[2], flow.shape[3]
dx = flow[:, 0].float() / (W * 0.5)
dy = flow[:, 1].float() / (H * 0.5)
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
def build_image_pyramid(image, pyramid_levels):
    """Return ``pyramid_levels`` images, each 2x average-pooled from the previous."""
    levels = [image]
    current = image
    while len(levels) < pyramid_levels:
        current = F.avg_pool2d(current, 2, 2)
        levels.append(current)
    return levels
def flow_pyramid_synthesis(residual_pyramid):
    """Accumulate residual flows (coarsest last) into absolute flows, finest first."""
    accumulated = residual_pyramid[-1]
    flows = [accumulated]
    # Walk from the second-coarsest level up to the finest, doubling the flow
    # magnitude at each upsampling step and adding the level's residual.
    for level in reversed(residual_pyramid[:-1]):
        upsampled = F.interpolate(accumulated, size=level.shape[2:4], mode="bilinear", scale_factor=None)
        accumulated = upsampled.mul_(2).add_(level)
        flows.append(accumulated)
    return flows[::-1]
def multiply_pyramid(pyramid, scalar):
    """Scale every level of ``pyramid`` by a per-batch-element ``scalar`` (shape [B])."""
    weights = scalar[:, None, None, None]
    scaled = []
    for level in pyramid:
        scaled.append(level * weights)
    return scaled
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
    """Warp each feature level by the matching flow level via ``warp_fn``."""
    return list(map(warp_fn, feature_pyramid, flow_pyramid))
def concatenate_pyramids(pyramid1, pyramid2):
    """Channel-concatenate matching levels of two pyramids."""
    merged = []
    for lvl_a, lvl_b in zip(pyramid1, pyramid2):
        merged.append(torch.cat([lvl_a, lvl_b], dim=1))
    return merged
class SubTreeExtractor(nn.Module):
    """Stack of conv pairs producing a feature pyramid for a single image scale."""

    def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
        super().__init__()
        layers = []
        ch_in = in_channels
        for depth in range(n_layers):
            ch_out = channels << depth  # double the width at every depth
            layers.append(nn.Sequential(
                FilmConv2d(ch_in, ch_out, 3, device=device, dtype=dtype, operations=operations),
                FilmConv2d(ch_out, ch_out, 3, device=device, dtype=dtype, operations=operations)))
            ch_in = ch_out
        self.convs = nn.ModuleList(layers)

    def forward(self, image, n):
        # All layers run; only the first n-1 transitions downsample the head.
        pyramid = []
        head = image
        for depth, layer in enumerate(self.convs):
            head = layer(head)
            pyramid.append(head)
            if depth < n - 1:
                head = F.avg_pool2d(head, 2, 2)
        return pyramid
class FeatureExtractor(nn.Module):
    """Builds a feature pyramid where each level is augmented with features
    computed from coarser input scales via a shared sub-tree extractor."""
    def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
        super().__init__()
        # One shared extractor is applied to every image scale.
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
        self.sub_levels = sub_levels
    def forward(self, image_pyramid):
        # One sub-pyramid per input scale; deeper scales request fewer levels.
        sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels))
                        for i in range(len(image_pyramid))]
        feature_pyramid = []
        for i in range(len(image_pyramid)):
            features = sub_pyramids[i][0]
            # Concatenate depth-j features taken from the input scale that is
            # j steps coarser, so every level sees multi-scale context.
            for j in range(1, self.sub_levels):
                if j <= i:
                    features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
            feature_pyramid.append(features)
            # Free sub-pyramids no longer needed by future levels
            if i >= self.sub_levels - 1:
                sub_pyramids[i - self.sub_levels + 1] = None
        return feature_pyramid
class FlowEstimator(nn.Module):
    """Predicts a 2-channel residual flow from a pair of feature maps."""

    def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
        super().__init__()
        self._convs = nn.ModuleList()
        ch = in_channels
        for _ in range(num_convs):
            self._convs.append(FilmConv2d(ch, num_filters, 3, device=device, dtype=dtype, operations=operations))
            ch = num_filters
        # 1x1 bottleneck, then a linear (no activation) 1x1 projection to flow.
        self._convs.append(FilmConv2d(ch, num_filters // 2, 1, device=device, dtype=dtype, operations=operations))
        self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations))

    def forward(self, features_a, features_b):
        out = torch.cat([features_a, features_b], dim=1)
        for conv in self._convs:
            out = conv(out)
        return out
class PyramidFlowEstimator(nn.Module):
    """Coarse-to-fine residual flow estimation over a feature pyramid."""
    def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        in_channels = filters << 1
        predictors = []
        for i in range(len(flow_convs)):
            predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations))
            # Finer levels carry concatenated multi-scale features, so the
            # predictor input width grows with the level index.
            in_channels += filters << (i + 2)
        # The last predictor is shared by all levels deeper than the specialized ones.
        self._predictor = predictors[-1]
        # Specialized fine-level predictors, stored in traversal (coarse-to-fine) order.
        self._predictors = nn.ModuleList(predictors[:-1][::-1])
    def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
        levels = len(feature_pyramid_a)
        # Initial estimate at the coarsest level, from unwarped features.
        v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [v]
        # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
        steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)]
        steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)]
        for i, predictor in steps:
            # Upsample the running flow to this level; magnitudes double with resolution.
            v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2)
            # Predict the residual against B-features warped by the current estimate.
            v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v))
            residuals.append(v_residual)
            v = v.add_(v_residual)
        residuals.reverse()
        return residuals
def _get_fusion_channels(level, filters):
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
class Fusion(nn.Module):
    """Decoder that fuses the warped multi-scale pyramid into an RGB frame."""
    def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
        super().__init__()
        # Final 1x1 projection from the finest decoder features to RGB.
        self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
        self.convs = nn.ModuleList()
        # Channel count of the coarsest aligned input (both warp directions).
        in_channels = _get_fusion_channels(n_layers, filters)
        increase = 0
        for i in range(n_layers)[::-1]:
            # Levels below `specialized_layers` get their own width; deeper levels share one.
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            self.convs.append(nn.ModuleList([
                # 2x2 conv (no activation) applied after nearest upsampling in forward().
                FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
                # 3x3 conv over [skip pyramid level, upsampled decoder state].
                FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations),
                FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
            in_channels = num_filters
            # NOTE(review): `increase` pre-computes the extra input width the next
            # level's fusion conv needs for its skip connection; looks derived from
            # the per-level channel formula -- confirm against checkpoint shapes.
            increase = _get_fusion_channels(i, filters) - num_filters // 2
    def forward(self, pyramid):
        # Decode coarse-to-fine: upsample, 2x2 conv, then fuse with the skip level.
        net = pyramid[-1]
        for k, layers in enumerate(self.convs):
            i = len(self.convs) - 1 - k
            net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
            net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
        return self.output_conv(net)
class FILMNet(nn.Module):
    """FILM frame interpolation: shared feature extraction, bidirectional pyramid
    flow estimation, and multi-scale fusion into intermediate frames."""
    def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
                 filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        # Only the finest `fusion_pyramid_levels` levels are warped and fused.
        self.fusion_pyramid_levels = fusion_pyramid_levels
        self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
        self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
        # Per-resolution 1-D sampling grids, keyed by (H, W); see _build_warp_grids.
        self._warp_grids = {}
    def get_dtype(self):
        # Model dtype, read from the first conv weight of the feature extractor.
        return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
    def _build_warp_grids(self, H, W, device):
        """Pre-compute warp grids for all pyramid levels."""
        if (H, W) in self._warp_grids:
            return
        self._warp_grids = {}  # clear old resolution grids to prevent memory leaks
        # One (grid_x, grid_y) pair per pyramid level, keyed by that level's (H, W).
        for _ in range(self.pyramid_levels):
            self._warp_grids[(H, W)] = (
                torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
                torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
            )
            H, W = H // 2, W // 2
    def warp(self, image, flow):
        # Look up the cached grids matching this flow's resolution.
        grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
        return _warp_core(image, flow, grid_x, grid_y)
    def extract_features(self, img):
        """Extract image and feature pyramids for a single frame. Can be cached across pairs."""
        image_pyramid = build_image_pyramid(img, self.pyramid_levels)
        feature_pyramid = self.extract(image_pyramid)
        return image_pyramid, feature_pyramid
    def forward(self, img0, img1, timestep=0.5, cache=None):
        # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
        t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
        return self.forward_multi_timestep(img0, img1, [t], cache=cache)
    def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
        """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs."""
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
        # `cache` may carry precomputed pyramids to avoid recomputation across pairs.
        image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
        image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)
        # Bidirectional flows, truncated to the levels the fusion stage consumes.
        fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
        bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
        # Build warp targets and free full pyramids (only first fpl levels needed from here)
        fpl = self.fusion_pyramid_levels
        p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
               concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
        del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
        results = []
        dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
        for idx in range(len(timesteps)):
            batch_dt = dt_tensors[idx:idx + 1]
            # Time-scale each direction's flow, then warp both (image+feature) pyramids toward t.
            bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
            fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
            fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
            bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
            # Fusion input per level: both warped stacks plus both scaled flows.
            aligned = [torch.cat([fw, bw, bf, ff], dim=1)
                       for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
            del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
            results.append(self.fuse(aligned))
            del aligned
        return torch.cat(results, dim=0)

View File

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
def _warp(img, flow, warp_grids):
B, _, H, W = img.shape
base_grid, flow_div = warp_grids[(H, W)]
flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
class Head(nn.Module):
    """Shallow per-frame encoder producing context features at input resolution."""

    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
        self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        # Transposed conv undoes cnn0's stride-2 downsampling.
        self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        feat = self.relu(self.cnn0(x))
        feat = self.relu(self.cnn1(feat))
        feat = self.relu(self.cnn2(feat))
        return self.cnn3(feat)
class ResConv(nn.Module):
    """Residual 3x3 conv block with a learned per-channel residual scale."""

    def __init__(self, c, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        # x + beta * conv(x), fused via addcmul, followed by leaky-ReLU.
        residual = self.conv(x)
        return self.relu(torch.addcmul(x, residual, self.beta))
class IFBlock(nn.Module):
    """One refinement stage: predicts flow, mask, and feature updates at a given scale."""

    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)),
            nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)))
        self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)))
        self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        inv = 1.0 / scale
        x = F.interpolate(x, scale_factor=inv, mode="bilinear")
        if flow is not None:
            # Flow shrinks with the image, so its magnitudes must shrink too.
            flow = F.interpolate(flow, scale_factor=inv, mode="bilinear").div_(scale)
            x = torch.cat((x, flow), 1)
        feat = self.convblock(self.conv0(x))
        tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear")
        # Output channels: 0-3 flow (rescaled back up), 4 fusion mask, 5+ carry-over features.
        return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:]
class IFNet(nn.Module):
    """RIFE intermediate-flow network: iteratively refines bidirectional flow and
    a fusion mask over a fixed coarse-to-fine scale schedule."""
    def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
        super().__init__()
        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
        # Stage 0 input: img0(3) + img1(3) + timestep(1) + two head features.
        # Later stages: warped imgs(6) + timestep(1) + mask(1) = 8, plus flow(4)
        # appended inside IFBlock, carry-over feat(8), and two warped head features.
        block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
        self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
        # Processing scale per refinement stage, coarse to fine.
        self.scale_list = [16, 8, 4, 2, 1]
        # Callers are expected to pad inputs to a multiple of this before forward().
        self.pad_align = 64
        # Cache of base sampling grids keyed by (H, W).
        self._warp_grids = {}
    def get_dtype(self):
        # Model dtype, taken from the first encoder conv's weights.
        return self.encode.cnn0.weight.dtype
    def _build_warp_grids(self, H, W, device):
        """Pre-compute the normalized base grid and pixel->[-1,1] divisors for (H, W)."""
        if (H, W) in self._warp_grids:
            return
        self._warp_grids = {}  # clear old resolution grids to prevent memory leaks
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
            torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
        self._warp_grids[(H, W)] = (
            torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
            torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
    def warp(self, img, flow):
        # Backward-warp using the grids built for the current input resolution.
        return _warp(img, flow, self._warp_grids)
    def extract_features(self, img):
        """Extract head features for a single frame. Can be cached across pairs."""
        return self.encode(img)
    def forward(self, img0, img1, timestep=0.5, cache=None):
        # Broadcast a scalar timestep to a (B, 1, H, W) map if needed.
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
        B = img0.shape[0]
        # Reuse cached head features when supplied (e.g. chained interpolation).
        f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
        f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
        flow = mask = feat = None
        warped_img0, warped_img1 = img0, img1
        for i, block in enumerate(self.blocks):
            if flow is None:
                # First stage estimates the flow from scratch.
                flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
            else:
                # Later stages predict a residual (fd) on top of the current flow.
                fd, mask, feat = block(
                    torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
                    flow, scale=self.scale_list[i])
                flow = flow.add_(fd)
            # Channels 0-1 warp img0 forward; channels 2-3 warp img1 backward.
            warped_img0 = self.warp(img0, flow[:, :2])
            warped_img1 = self.warp(img1, flow[:, 2:4])
        # Blend the two warped frames with the predicted (sigmoid) mask.
        return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
def detect_rife_config(state_dict):
    """Infer (head_ch, per-block channel widths) from a RIFE checkpoint's state dict."""
    # ConvTranspose2d weights are laid out (in_ch, out_ch, kH, kW).
    head_ch = state_dict["encode.cnn3.weight"].shape[1]
    block_keys = [f"blocks.{i}.conv0.1.0.weight" for i in range(5)]
    channels = [state_dict[key].shape[0] for key in block_keys if key in state_dict]
    if len(channels) != 5:
        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
    return head_ch, channels

View File

@ -3,136 +3,136 @@ from typing_extensions import override
import comfy.model_management import comfy.model_management
import node_helpers import node_helpers
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, IO
class TextEncodeAceStepAudio(io.ComfyNode): class TextEncodeAceStepAudio(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return IO.Schema(
node_id="TextEncodeAceStepAudio", node_id="TextEncodeAceStepAudio",
category="conditioning", category="conditioning",
inputs=[ inputs=[
io.Clip.Input("clip"), IO.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True), IO.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True), IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
], ],
outputs=[io.Conditioning.Output()], outputs=[IO.Conditioning.Output()],
) )
@classmethod @classmethod
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics) tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return io.NodeOutput(conditioning) return IO.NodeOutput(conditioning)
class TextEncodeAceStepAudio15(io.ComfyNode): class TextEncodeAceStepAudio15(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return IO.Schema(
node_id="TextEncodeAceStepAudio1.5", node_id="TextEncodeAceStepAudio1.5",
category="conditioning", category="conditioning",
inputs=[ inputs=[
io.Clip.Input("clip"), IO.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True), IO.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True), IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Int.Input("bpm", default=120, min=10, max=300), IO.Int.Input("bpm", default=120, min=10, max=300),
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
io.Combo.Input("timesignature", options=['2', '3', '4', '6']), IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
], ],
outputs=[io.Conditioning.Output()], outputs=[IO.Conditioning.Output()],
) )
@classmethod @classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput: def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning) return IO.NodeOutput(conditioning)
class EmptyAceStepLatentAudio(io.ComfyNode): class EmptyAceStepLatentAudio(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return IO.Schema(
node_id="EmptyAceStepLatentAudio", node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio", display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio", category="latent/audio",
inputs=[ inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input( IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
), ),
], ],
outputs=[io.Latent.Output()], outputs=[IO.Latent.Output()],
) )
@classmethod @classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput: def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = int(seconds * 44100 / 512 / 8) length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"}) return IO.NodeOutput({"samples": latent, "type": "audio"})
class EmptyAceStep15LatentAudio(io.ComfyNode): class EmptyAceStep15LatentAudio(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return IO.Schema(
node_id="EmptyAceStep1.5LatentAudio", node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio", display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio", category="latent/audio",
inputs=[ inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
io.Int.Input( IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
), ),
], ],
outputs=[io.Latent.Output()], outputs=[IO.Latent.Output()],
) )
@classmethod @classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput: def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 48000 / 1920)) length = round((seconds * 48000 / 1920))
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"}) return IO.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceAudio(io.ComfyNode): class ReferenceAudio(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return IO.Schema(
node_id="ReferenceTimbreAudio", node_id="ReferenceTimbreAudio",
display_name="Reference Audio", display_name="Reference Audio",
category="advanced/conditioning/audio", category="advanced/conditioning/audio",
is_experimental=True, is_experimental=True,
description="This node sets the reference audio for ace step 1.5", description="This node sets the reference audio for ace step 1.5",
inputs=[ inputs=[
io.Conditioning.Input("conditioning"), IO.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True), IO.Latent.Input("latent", optional=True),
], ],
outputs=[ outputs=[
io.Conditioning.Output(), IO.Conditioning.Output(),
] ]
) )
@classmethod @classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput: def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
if latent is not None: if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
return io.NodeOutput(conditioning) return IO.NodeOutput(conditioning)
class AceExtension(ComfyExtension): class AceExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
TextEncodeAceStepAudio, TextEncodeAceStepAudio,
EmptyAceStepLatentAudio, EmptyAceStepLatentAudio,

View File

@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0 std[std < 1.0] = 1.0
audio /= std audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}

View File

@ -0,0 +1,211 @@
import torch
from tqdm import tqdm
from typing_extensions import override
import comfy.model_patcher
import comfy.utils
import folder_paths
from comfy import model_management
from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config
from comfy_extras.frame_interpolation_models.film_net import FILMNet
from comfy_api.latest import ComfyExtension, io
# Custom IO datatype used to pass the loaded interpolation model (wrapped in a
# ModelPatcher by the loader node below) between nodes.
FrameInterpolationModel = io.Custom("INTERP_MODEL")
class FrameInterpolationModelLoader(io.ComfyNode):
    """Load a frame interpolation checkpoint (FILM or RIFE) from disk.

    The architecture is auto-detected from the state dict keys, and the
    resulting model is wrapped in a ModelPatcher so ComfyUI's memory
    management can move/offload it like any other model.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FrameInterpolationModelLoader",
            display_name="Load Frame Interpolation Model",
            category="loaders",
            inputs=[
                io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
                               tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),
            ],
            outputs=[
                FrameInterpolationModel.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        # Resolve the checkpoint path and load it with safe (weights-only) loading.
        model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        model = cls._detect_and_load(sd)
        # Use fp16 only when the target device supports it; otherwise fp32.
        dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
        model.eval().to(dtype)
        patcher = comfy.model_patcher.ModelPatcher(
            model,
            load_device=model_management.get_torch_device(),
            offload_device=model_management.unet_offload_device(),
        )
        return io.NodeOutput(patcher)

    @classmethod
    def _detect_and_load(cls, sd):
        """Instantiate the matching architecture for state dict ``sd``.

        Returns a FILMNet or IFNet with weights loaded.
        Raises ValueError if the checkpoint matches neither format.
        """
        # Try FILM: this key is specific to FILM checkpoints.
        if "extract.extract_sublevels.convs.0.0.conv.weight" in sd:
            model = FILMNet()
            model.load_state_dict(sd)
            return model
        # Try RIFE (needs key remapping for raw checkpoints)
        sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
        # Raw RIFE checkpoints name blocks "block0." .. "block4."; IFNet expects "blocks.N.".
        key_map = {}
        for k in sd:
            for i in range(5):
                if k.startswith(f"block{i}."):
                    key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}"
        if key_map:
            sd = {key_map.get(k, k): v for k, v in sd.items()}
        # Drop "teacher."/"caltime." submodules — presumably training-only (distillation)
        # weights not used at inference; TODO confirm against upstream RIFE.
        sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))}
        try:
            head_ch, channels = detect_rife_config(sd)
        except (KeyError, ValueError):
            raise ValueError("Unrecognized frame interpolation model format")
        model = IFNet(head_ch=head_ch, channels=channels)
        model.load_state_dict(sd)
        return model
class FrameInterpolate(io.ComfyNode):
    """Interpolate intermediate frames between consecutive frames of an image batch.

    For each adjacent pair of input frames, ``multiplier - 1`` intermediate
    frames are generated with the loaded model, yielding
    ``(num_frames - 1) * multiplier + 1`` output frames. Handles OOM by
    halving the per-pair inference batch, and reuses extracted features
    between neighboring pairs.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FrameInterpolate",
            display_name="Frame Interpolate",
            category="image/video",
            search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
            inputs=[
                FrameInterpolationModel.Input("interp_model"),
                io.Image.Input("images"),
                io.Int.Input("multiplier", default=2, min=2, max=16),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
        offload_device = model_management.intermediate_device()
        num_frames = images.shape[0]
        # Nothing to interpolate with fewer than two frames (or multiplier < 2).
        if num_frames < 2 or multiplier < 2:
            return io.NodeOutput(images)
        model_management.load_model_gpu(interp_model)
        device = interp_model.load_device
        dtype = interp_model.model_dtype()
        inference_model = interp_model.model
        # Free VRAM for inference activations (model weights + ~20x a single frame's worth)
        H, W = images.shape[1], images.shape[2]
        activation_mem = H * W * 3 * images.element_size() * 20
        model_management.free_memory(activation_mem, device)
        # Some models require spatial dims padded to a multiple of pad_align.
        align = getattr(inference_model, "pad_align", 1)
        # Prepare a single padded frame on device for determining output dimensions
        def prepare_frame(idx):
            # BHWC -> BCHW, cast/move to the inference device, then reflect-pad.
            frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
            if align > 1:
                from comfy.ldm.common_dit import pad_to_patch_size
                frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
            return frame
        # Count total interpolation passes for progress bar
        total_pairs = num_frames - 1
        num_interp = multiplier - 1
        total_steps = total_pairs * num_interp
        pbar = comfy.utils.ProgressBar(total_steps)
        tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
        batch = num_interp  # reduced on OOM and persists across pairs (same resolution = same limit)
        # Fractional timesteps between the two frames, e.g. [1/3, 2/3] for multiplier=3.
        t_values = [t / multiplier for t in range(1, multiplier)]
        out_dtype = model_management.intermediate_dtype()
        # Output buffer holds: first frame + per pair (num_interp mids + the real next frame).
        total_out_frames = total_pairs * multiplier + 1
        result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
        result[0] = images[0].movedim(-1, 0).to(out_dtype)
        out_idx = 1
        # Pre-compute timestep tensor on device (padded dimensions needed)
        sample = prepare_frame(0)
        pH, pW = sample.shape[2], sample.shape[3]
        ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
        ts_full = ts_full.expand(-1, 1, pH, pW)
        del sample
        # Optional fast path: models exposing forward_multi_timestep can produce
        # all intermediate frames for a pair in a single call.
        multi_fn = getattr(inference_model, "forward_multi_timestep", None)
        feat_cache = {}
        prev_frame = None
        try:
            for i in range(total_pairs):
                img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
                img1_single = prepare_frame(i + 1)
                prev_frame = img1_single
                # Cache features: img1 of pair N becomes img0 of pair N+1
                feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
                feat_cache["img1"] = inference_model.extract_features(img1_single)
                feat_cache["next"] = feat_cache["img1"]
                used_multi = False
                if multi_fn is not None:
                    # Models with timestep-independent flow can compute it once for all timesteps
                    try:
                        mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
                        # Crop padding back to the original H x W before storing.
                        result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
                        out_idx += num_interp
                        pbar.update(num_interp)
                        tqdm_bar.update(num_interp)
                        used_multi = True
                    except model_management.OOM_EXCEPTION:
                        model_management.soft_empty_cache()
                        multi_fn = None  # fall through to single-timestep path
                if not used_multi:
                    # Single-timestep path: batch the timesteps, shrinking the batch on OOM.
                    j = 0
                    while j < num_interp:
                        b = min(batch, num_interp - j)
                        try:
                            img0 = img0_single.expand(b, -1, -1, -1)
                            img1 = img1_single.expand(b, -1, -1, -1)
                            mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
                            result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
                            out_idx += b
                            pbar.update(b)
                            tqdm_bar.update(b)
                            j += b
                        except model_management.OOM_EXCEPTION:
                            # Halve the batch and retry; give up only at batch size 1.
                            if batch <= 1:
                                raise
                            batch = max(1, batch // 2)
                            model_management.soft_empty_cache()
                # Append the real (non-interpolated) next frame after the mids.
                result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
                out_idx += 1
        finally:
            tqdm_bar.close()
        # BCHW -> BHWC
        result = result.movedim(1, -1).clamp_(0.0, 1.0)
        return io.NodeOutput(result)
class FrameInterpolationExtension(ComfyExtension):
    """Extension that registers the frame interpolation loader and runner nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes: list[type[io.ComfyNode]] = [
            FrameInterpolationModelLoader,
            FrameInterpolate,
        ]
        return node_classes
async def comfy_entrypoint() -> FrameInterpolationExtension:
    """Entry point ComfyUI calls to obtain this extension instance."""
    extension = FrameInterpolationExtension()
    return extension

View File

@ -1,6 +1,7 @@
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
import torchaudio
import comfy.model_management import comfy.model_management
import comfy.model_sampling import comfy.model_sampling
import comfy.samplers import comfy.samplers
@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify # Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio) sample_rate = reference_audio["sample_rate"]
vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = reference_audio["waveform"]
audio_latents = audio_vae.encode(waveform.movedim(1, -1))
b, c, t, f = audio_latents.shape b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens} ref_audio = {"tokens": ref_tokens}

View File

@ -3,9 +3,8 @@ import comfy.utils
import comfy.model_management import comfy.model_management
import torch import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_audio import VAEEncodeAudio
class LTXVAudioVAELoader(io.ComfyNode): class LTXVAudioVAELoader(io.ComfyNode):
@classmethod @classmethod
@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
def execute(cls, ckpt_name: str) -> io.NodeOutput: def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
return io.NodeOutput(AudioVAE(sd, metadata)) sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return io.NodeOutput(vae)
class LTXVAudioVAEEncode(io.ComfyNode): class LTXVAudioVAEEncode(VAEEncodeAudio):
@classmethod @classmethod
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: def execute(cls, audio, audio_vae) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio) return super().execute(audio_vae, audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
class LTXVAudioVAEDecode(io.ComfyNode): class LTXVAudioVAEDecode(io.ComfyNode):
@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: def execute(cls, samples, audio_vae) -> io.NodeOutput:
audio_latent = samples["samples"] audio_latent = samples["samples"]
if audio_latent.is_nested: if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1] audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).to(audio_latent.device) audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
return io.NodeOutput( return io.NodeOutput(
{ {
"waveform": audio, "waveform": audio,
@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
frames_number: int, frames_number: int,
frame_rate: int, frame_rate: int,
batch_size: int, batch_size: int,
audio_vae: AudioVAE, audio_vae,
) -> io.NodeOutput: ) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure.""" """Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required" assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels z_channels = audio_vae.latent_channels
audio_freq = audio_vae.latent_frequency_bins audio_freq = audio_vae.first_stage_model.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate) sampling_rate = int(audio_vae.first_stage_model.sample_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros( audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq), (batch_size, z_channels, num_audio_latents, audio_freq),

View File

@ -7,7 +7,10 @@ import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.latent_formats import comfy.latent_formats
import comfy.ldm.lumina.controlnet import comfy.ldm.lumina.controlnet
import comfy.ldm.supir.supir_modules
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
from comfy_api.latest import io
from comfy.ldm.supir.supir_patch import SUPIRPatch
class BlockWiseControlBlock(torch.nn.Module): class BlockWiseControlBlock(torch.nn.Module):
@ -266,6 +269,27 @@ class ModelPatchLoader:
out_dim=sd["audio_proj.norm.weight"].shape[0], out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(), device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast) operations=comfy.ops.manual_cast)
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
prefix_replace = {}
if 'model.control_model.input_hint_block.0.weight' in sd:
prefix_replace["model.control_model."] = "control_model."
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
else:
prefix_replace["control_model."] = "control_model."
prefix_replace["project_modules."] = "project_modules."
# Extract denoise_encoder weights before filter_keys discards them
de_prefix = "first_stage_model.denoise_encoder."
denoise_encoder_sd = {}
for k in list(sd.keys()):
if k.startswith(de_prefix):
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
sd.pop("control_model.mask_LQ", None)
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if denoise_encoder_sd:
model.denoise_encoder_sd = denoise_encoder_sd
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model.load_state_dict(sd, assign=model_patcher.is_dynamic()) model.load_state_dict(sd, assign=model_patcher.is_dynamic())
@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
) )
class SUPIRApply(io.ComfyNode):
    """Apply a SUPIR restoration control patch to a diffusion model.

    Encodes the input image into a hint latent (optionally through the SUPIR
    checkpoint's own denoise encoder), registers a SUPIRPatch on a clone of
    the model, and optionally adds a post-CFG function that pulls the
    denoised output toward the input latent ("restore cfg").
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="SUPIRApply",
            category="model_patches/supir",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.ModelPatch.Input("model_patch"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the start of sampling (high sigma)."),
                io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
                io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
                               tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
                io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
                               tooltip="Sigma threshold below which restore_cfg is disabled."),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def _encode_with_denoise_encoder(cls, vae, model_patch, image):
        """Encode using denoise_encoder weights from SUPIR checkpoint if available."""
        denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
        if not denoise_sd:
            return vae.encode(image)
        # Clone VAE patcher, apply denoise_encoder weights to clone, encode
        orig_patcher = vae.patcher
        vae.patcher = orig_patcher.clone()
        # strength_model=0.0 fully replaces the encoder weights with the SUPIR ones.
        patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
        vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
        try:
            return vae.encode(image)
        finally:
            # Always restore the original patcher, even if encoding raises.
            vae.patcher = orig_patcher

    @classmethod
    def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
                strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
        model_patched = model.clone()
        # Drop any alpha channel ([..., :3]) before VAE encoding.
        hint_latent = model.get_model_object("latent_format").process_in(
            cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
        patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
        patch.register(model_patched)
        if restore_cfg > 0.0:
            # Round-trip to match original pipeline: decode hint, re-encode with regular VAE
            latent_format = model.get_model_object("latent_format")
            decoded = vae.decode(latent_format.process_out(hint_latent))
            x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
            # NOTE(review): 14.6146 looks like the SDXL/EDM default sigma_max — confirm.
            sigma_max = 14.6146
            def restore_cfg_function(args):
                denoised = args["denoised"]
                sigma = args["sigma"]
                # sigma may be a scalar or a per-batch tensor; use the first entry
                # for the threshold test.
                if sigma.dim() > 0:
                    s = sigma[0].item()
                else:
                    s = sigma.item()
                if s > restore_cfg_s_tmin:
                    ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
                    b = denoised.shape[0]
                    # Broadcast/tile the reference latent to the denoised batch size.
                    if ref.shape[0] != b:
                        ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
                    sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
                    # Pull denoised toward the reference; pull strength decays as
                    # sigma falls ((sigma/sigma_max) ** restore_cfg).
                    d_center = denoised - ref
                    denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
                return denoised
            model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
        return io.NodeOutput(model_patched)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader, "ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet, "ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference, "USOStyleReference": USOStyleReference,
"SUPIRApply": SUPIRApply,
} }

View File

@ -6,6 +6,7 @@ from PIL import Image
import math import math
from enum import Enum from enum import Enum
from typing import TypedDict, Literal from typing import TypedDict, Literal
import kornia
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
return io.NodeOutput(batched) return io.NodeOutput(batched)
class ColorTransfer(io.ComfyNode):
    """Match the colors of image_target to image_ref.

    Three algorithms: Reinhard (per-channel mean/std matching in Lab),
    MKL (Monge-Kantorovich linear covariance matching in Lab), and
    per-channel histogram matching in RGB. The ``source_stats`` mode
    controls whether the source-side statistics are computed per frame,
    pooled over all frames, or taken from one chosen frame.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ColorTransfer",
            category="image/postprocessing",
            description="Match the colors of one image to another using various algorithms.",
            search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
            inputs=[
                io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
                io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
                io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
                io.DynamicCombo.Input("source_stats",
                                      tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
                                      options=[
                                          io.DynamicCombo.Option("per_frame", []),
                                          io.DynamicCombo.Option("uniform", []),
                                          io.DynamicCombo.Option("target_frame", [
                                              io.Int.Input("target_index", default=0, min=0, max=10000,
                                                           tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
                                          ]),
                                      ]),
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
            ],
        )

    @staticmethod
    def _to_lab(images, i, device):
        """Convert frame ``i`` of BHWC ``images`` to a (1, C, H, W) Lab tensor on ``device``."""
        return kornia.color.rgb_to_lab(
            images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))

    @staticmethod
    def _pool_stats(images, device, is_reinhard, eps):
        """Two-pass pooled mean + std/cov across all frames."""
        # Returns (mean (C,1), std (C,1)) for Reinhard, or (mean (C,1), cov (C,C)) for MKL.
        N, C = images.shape[0], images.shape[3]
        HW = images.shape[1] * images.shape[2]
        # Pass 1: pooled per-channel mean over all frames.
        mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
        for i in range(N):
            mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
        mean /= N
        # Pass 2: pooled variance (Reinhard) or covariance (MKL) about that mean.
        acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
        for i in range(N):
            centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
            if is_reinhard:
                acc += (centered * centered).mean(dim=-1, keepdim=True)
            else:
                acc += centered @ centered.T / HW
        if is_reinhard:
            # clamp_min eps keeps the later std division safe.
            return mean, torch.sqrt(acc / N).clamp_min_(eps)
        return mean, acc / N

    @staticmethod
    def _frame_stats(lab_flat, hw, is_reinhard, eps):
        """Per-frame mean + std/cov."""
        # lab_flat: (C, HW) flattened Lab pixels of a single frame.
        mean = lab_flat.mean(dim=-1, keepdim=True)
        if is_reinhard:
            return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
        centered = lab_flat - mean
        return mean, centered @ centered.T / hw

    @staticmethod
    def _mkl_matrix(cov_s, cov_r, eps):
        """Compute MKL 3x3 transform matrix from source and ref covariances."""
        # Monge-Kantorovich linear map: T = S^{-1/2} (S^{1/2} R S^{1/2})^{1/2} S^{-1/2},
        # computed via eigendecompositions of the symmetric covariances.
        eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
        sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
        scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
        mid = scaled_V.T @ cov_r @ scaled_V
        eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
        sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
        inv_sqrt_s = 1.0 / sqrt_val_s
        inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
        M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
        return inv_scaled_V @ M_half @ inv_scaled_V.T

    @staticmethod
    def _histogram_lut(src, ref, bins=256):
        """Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
        s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
        r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
        s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
        r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
        ones_s = torch.ones_like(src)
        ones_r = torch.ones_like(ref)
        s_hist.scatter_add_(1, s_bins, ones_s)
        r_hist.scatter_add_(1, r_bins, ones_r)
        # Normalized CDFs for classic histogram matching via inverse-CDF lookup.
        s_cdf = s_hist.cumsum(1)
        s_cdf = s_cdf / s_cdf[:, -1:]
        r_cdf = r_hist.cumsum(1)
        r_cdf = r_cdf / r_cdf[:, -1:]
        # LUT entry k maps source bin k to the ref intensity with matching CDF.
        return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)

    @classmethod
    def _pooled_cdf(cls, images, device, num_bins=256):
        """Build pooled CDF across all frames, one frame at a time."""
        C = images.shape[3]
        hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
        for i in range(images.shape[0]):
            frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
            bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
            hist.scatter_add_(1, bins, torch.ones_like(frame))
        cdf = hist.cumsum(1)
        return cdf / cdf[:, -1:]

    @classmethod
    def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
        """Build per-frame or uniform LUT transform for histogram mode."""
        if stats_mode == 'per_frame':
            return None  # LUT computed per-frame in the apply loop
        r_cdf = cls._pooled_cdf(image_ref, device)
        if stats_mode == 'target_frame':
            # Clamp target_index so an out-of-range index falls back to the last frame.
            ti = min(target_index, B - 1)
            s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
        else:
            s_cdf = cls._pooled_cdf(image_target, device)
        return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0

    @classmethod
    def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
        """Build transform parameters for Lab-based methods. Returns a transform function."""
        eps = 1e-6
        B, H, W, C = image_target.shape
        B_ref = image_ref.shape[0]
        single_ref = B_ref == 1
        HW = H * W
        HW_ref = image_ref.shape[1] * image_ref.shape[2]
        # Precompute ref stats
        if single_ref or stats_mode in ('uniform', 'target_frame'):
            ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
        # Uniform/target_frame: precompute single affine transform
        if stats_mode in ('uniform', 'target_frame'):
            if stats_mode == 'target_frame':
                ti = min(target_index, B - 1)
                s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
                s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
            else:
                s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
            if is_reinhard:
                # Reinhard: per-channel rescale so source stats match ref stats.
                scale = ref_sc / s_sc
                offset = ref_mean - scale * s_mean
                return lambda src_flat, **_: src_flat * scale + offset
            # MKL: one fixed affine map (T @ x + offset) shared by all frames.
            T = cls._mkl_matrix(s_sc, ref_sc, eps)
            offset = ref_mean - T @ s_mean
            return lambda src_flat, **_: T @ src_flat + offset
        # per_frame
        def per_frame_transform(src_flat, frame_idx):
            s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
            if single_ref:
                r_mean, r_sc = ref_mean, ref_sc
            else:
                # Pair frame i with ref frame i; last ref frame repeats if ref is shorter.
                ri = min(frame_idx, B_ref - 1)
                r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
            centered = src_flat - s_mean
            if is_reinhard:
                return centered * (r_sc / s_sc) + r_mean
            T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
            return T @ centered + r_mean
        return per_frame_transform

    @classmethod
    def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
        # source_stats is a DynamicCombo dict: the chosen mode plus its sub-inputs.
        stats_mode = source_stats["source_stats"]
        target_index = source_stats.get("target_index", 0)
        # No-op cases: zero strength or no reference image provided.
        if strength == 0 or image_ref is None:
            return io.NodeOutput(image_target)
        device = comfy.model_management.get_torch_device()
        intermediate_device = comfy.model_management.intermediate_device()
        intermediate_dtype = comfy.model_management.intermediate_dtype()
        B, H, W, C = image_target.shape
        B_ref = image_ref.shape[0]
        pbar = comfy.utils.ProgressBar(B)
        # Output accumulated frame-by-frame on the intermediate device to limit VRAM use.
        out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype)
        if method == 'histogram':
            uniform_lut = cls._build_histogram_transform(
                image_target, image_ref, device, stats_mode, target_index, B)
            for i in range(B):
                src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1)
                src_flat = src.reshape(C, -1)
                if uniform_lut is not None:
                    lut = uniform_lut
                else:
                    ri = min(i, B_ref - 1)
                    ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
                    lut = cls._histogram_lut(src_flat, ref)
                bin_idx = (src_flat * 255).long().clamp(0, 255)
                matched = lut.gather(1, bin_idx).view(C, H, W)
                # strength < 1 blends toward the original; > 1 extrapolates past the match.
                result = matched if strength == 1.0 else torch.lerp(src, matched, strength)
                out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
                pbar.update(1)
        else:
            transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")
            for i in range(B):
                src_frame = cls._to_lab(image_target, i, device)
                corrected = transform(src_frame.view(C, -1), frame_idx=i)
                if strength == 1.0:
                    result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W))
                else:
                    # Blend in Lab space, then convert back to RGB.
                    result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength))
                out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
                pbar.update(1)
        return io.NodeOutput(out)
class PostProcessingExtension(ComfyExtension): class PostProcessingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
BatchImagesNode, BatchImagesNode,
BatchMasksNode, BatchMasksNode,
BatchLatentsNode, BatchLatentsNode,
ColorTransfer,
# BatchImagesMasksLatentsNode, # BatchImagesMasksLatentsNode,
] ]

View File

@ -11,7 +11,7 @@ class PreviewAny():
"required": {"source": (IO.ANY, {})}, "required": {"source": (IO.ANY, {})},
} }
RETURN_TYPES = () RETURN_TYPES = (IO.STRING,)
FUNCTION = "main" FUNCTION = "main"
OUTPUT_NODE = True OUTPUT_NODE = True
@ -33,7 +33,7 @@ class PreviewAny():
except Exception: except Exception:
value = 'source exists, but could not be serialized.' value = 'source exists, but could not be serialized.'
return {"ui": {"text": (value,)}} return {"ui": {"text": (value,)}, "result": (value,)}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"PreviewAny": PreviewAny, "PreviewAny": PreviewAny,

View File

@ -32,10 +32,12 @@ class RTDETR_detect(io.ComfyNode):
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput: def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
B, H, W, C = image.shape B, H, W, C = image.shape
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
comfy.model_management.load_model_gpu(model) comfy.model_management.load_model_gpu(model)
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts results = []
for i in range(0, B, 32):
batch = image[i:i + 32]
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
results.extend(model.model.diffusion_model(image_in, (W, H)))
all_bbox_dicts = [] all_bbox_dicts = []

529
comfy_extras/nodes_sam3.py Normal file
View File

@ -0,0 +1,529 @@
"""
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
"""
from typing_extensions import override
import json
import os
import torch
import torch.nn.functional as F
import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io, ui
import av
from fractions import Fraction
def _extract_text_prompts(conditioning, device, dtype):
"""Extract list of (text_embeddings, text_mask) from conditioning."""
cond_meta = conditioning[0][1]
multi = cond_meta.get("sam3_multi_cond")
prompts = []
if multi is not None:
for entry in multi:
emb = entry["cond"].to(device=device, dtype=dtype)
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
if mask is None:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, entry.get("max_detections", 1)))
else:
emb = conditioning[0][0].to(device=device, dtype=dtype)
mask = cond_meta.get("attention_mask")
if mask is not None:
mask = mask.to(device)
else:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, 1))
return prompts
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
    """Refine a coarse detector mask via SAM decoder, cropping to the detection box.
    Returns: [1, H, W] binary mask
    """
    def _upsampled_coarse():
        # Fallback: bilinear upsample of the raw detector logits, thresholded at 0.
        up = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
                           mode="bilinear", align_corners=False)
        return (up[0] > 0).float()

    if iterations <= 0:
        return _upsampled_coarse()
    # Expand the detection box by 10% of its size on each side, clamped to the image.
    box_x1, box_y1, box_x2, box_y2 = box_xyxy.tolist()
    margin_x = (box_x2 - box_x1) * 0.1
    margin_y = (box_y2 - box_y1) * 0.1
    crop_x1 = max(0, int(box_x1 - margin_x))
    crop_y1 = max(0, int(box_y1 - margin_y))
    crop_x2 = min(W, int(box_x2 + margin_x))
    crop_y2 = min(H, int(box_y2 + margin_y))
    if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
        return _upsampled_coarse()
    # Resize the RGB crop to the SAM input resolution.
    rgb_crop = orig_image_hwc[crop_y1:crop_y2, crop_x1:crop_x2, :3]
    crop_frame = comfy.utils.common_upscale(
        rgb_crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled"
    ).to(device=device, dtype=dtype)
    crop_h = crop_y2 - crop_y1
    crop_w = crop_x2 - crop_x1
    # Map the crop window into the (lower-resolution) coarse-mask grid.
    mask_h, mask_w = coarse_mask.shape[-2:]
    m_x1, m_y1 = int(crop_x1 / W * mask_w), int(crop_y1 / H * mask_h)
    m_x2, m_y2 = int(crop_x2 / W * mask_w), int(crop_y2 / H * mask_h)
    if m_x2 <= m_x1 or m_y2 <= m_y1:
        return _upsampled_coarse()
    logits = coarse_mask[..., m_y1:m_y2, m_x1:m_x2].unsqueeze(0).unsqueeze(0)
    # Iteratively feed the upsampled mask back through the SAM decoder.
    for _ in range(iterations):
        seed = F.interpolate(logits, size=(1008, 1008), mode="bilinear", align_corners=False)
        logits = sam3_model.forward_segment(crop_frame, mask_inputs=seed)
    refined = F.interpolate(logits, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
    # Paste the refined crop into a full-resolution canvas and union with the coarse mask.
    canvas = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
    canvas[:, :, crop_y1:crop_y2, crop_x1:crop_x2] = refined
    coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
    return ((canvas[0] > 0) | (coarse_full[0] > 0)).float()
class SAM3_Detect(io.ComfyNode):
    """Open-vocabulary detection and segmentation using text, box, or point prompts."""
    @classmethod
    def define_schema(cls):
        # Node schema: all prompt inputs are optional; any combination may be wired.
        return io.Schema(
            node_id="SAM3_Detect",
            display_name="SAM3 Detect",
            category="detection/",
            search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
            inputs=[
                io.Model.Input("model", display_name="model"),
                io.Image.Input("image", display_name="image"),
                io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
                io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
                io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
                io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
                io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
                io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
                io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
            ],
            outputs=[
                io.Mask.Output("masks"),
                io.BoundingBox.Output("bboxes"),
            ],
        )
    @classmethod
    def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
        """Run SAM3 on each frame of `image` ([B, H, W, C]).

        Prompt routes (combinable per frame):
          - points (positive/negative JSON coords) -> SAM decoder path,
          - boxes without text -> SAM decoder per box,
          - text conditioning -> open-vocabulary detector, optionally constrained by boxes.
        Returns masks (per-object when individual_masks, else per-frame union) and a
        per-frame list of bbox dicts {x, y, width, height, score} in pixel coordinates.
        """
        B, H, W, C = image.shape
        # Model input is a fixed 1008x1008 resize of the RGB channels.
        image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
        # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
        # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
        def _boxes_to_tensor(box_list):
            coords = []
            for d in box_list:
                cx = (d["x"] + d["width"] / 2) / W
                cy = (d["y"] + d["height"] / 2) / H
                coords.append([cx, cy, d["width"] / W, d["height"] / H])
            return torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
        per_frame_boxes = None
        if bboxes is not None:
            if isinstance(bboxes, dict):
                # Single box → same for all frames
                shared = _boxes_to_tensor([bboxes])
                per_frame_boxes = [shared] * B
            elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
                # list[list[dict]] → per-frame boxes
                per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
                # Pad to B if fewer frames provided
                while len(per_frame_boxes) < B:
                    per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
            elif isinstance(bboxes, list) and len(bboxes) > 0:
                # list[dict] → same boxes for all frames
                shared = _boxes_to_tensor(bboxes)
                per_frame_boxes = [shared] * B
        # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
        pos_pts = json.loads(positive_coords) if positive_coords else []
        neg_pts = json.loads(negative_coords) if negative_coords else []
        has_points = len(pos_pts) > 0 or len(neg_pts) > 0
        comfy.model_management.load_model_gpu(model)
        device = comfy.model_management.get_torch_device()
        dtype = model.model.get_dtype()
        sam3_model = model.model.diffusion_model
        # Build point inputs for tracker SAM decoder path
        point_inputs = None
        if has_points:
            # Points are rescaled from pixel coords into the 1008x1008 model space;
            # positives and negatives are concatenated with labels 1/0.
            all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
                         [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
            all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
            point_inputs = {
                "point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
                "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
            }
        cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
        has_text = len(cond_list) > 0
        # Run per-image through detector (text/boxes) and/or tracker (points)
        all_bbox_dicts = []
        all_masks = []
        pbar = comfy.utils.ProgressBar(B)
        for b in range(B):
            frame = image_in[b:b+1].to(device=device, dtype=dtype)
            b_boxes = None
            if per_frame_boxes is not None and per_frame_boxes[b] is not None:
                b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
            frame_bbox_dicts = []
            frame_masks = []
            # Point prompts: tracker SAM decoder path with iterative refinement
            if point_inputs is not None:
                mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
                # The first forward counts as one iteration; remaining passes feed
                # the logits back in as mask_inputs.
                for _ in range(max(0, refine_iterations - 1)):
                    mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
                mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
                frame_masks.append((mask[0] > 0).float())
            # Box prompts: SAM decoder path (segment inside each box)
            if b_boxes is not None and not has_text:
                for box_cxcywh in b_boxes[0]:
                    cx, cy, bw, bh = box_cxcywh.tolist()
                    # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
                    sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
                                             [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
                                           device=device, dtype=dtype)
                    mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
                    for _ in range(max(0, refine_iterations - 1)):
                        mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
                    mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
                    frame_masks.append((mask[0] > 0).float())
            # Text prompts: run detector per text prompt (each detects one category)
            for text_embeddings, text_mask, max_det in cond_list:
                results = sam3_model(
                    frame, text_embeddings=text_embeddings, text_mask=text_mask,
                    boxes=b_boxes, threshold=threshold, orig_size=(H, W))
                # NOTE(review): assumes results["boxes"] are xyxy in original pixel
                # coords (orig_size is passed in) — confirm against the model.
                pred_boxes = results["boxes"][0]
                scores = results["scores"][0]
                masks = results["masks"][0]
                probs = scores.sigmoid()
                keep = probs > threshold
                kept_boxes = pred_boxes[keep].cpu()
                kept_scores = probs[keep].cpu()
                kept_masks = masks[keep]
                # Keep the top max_det detections by confidence.
                order = kept_scores.argsort(descending=True)[:max_det]
                kept_boxes = kept_boxes[order]
                kept_scores = kept_scores[order]
                kept_masks = kept_masks[order]
                for box, score in zip(kept_boxes, kept_scores):
                    frame_bbox_dicts.append({
                        "x": float(box[0]), "y": float(box[1]),
                        "width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
                        "score": float(score),
                    })
                for m, box in zip(kept_masks, kept_boxes):
                    frame_masks.append(_refine_mask(
                        sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
            all_bbox_dicts.append(frame_bbox_dicts)
            if len(frame_masks) > 0:
                combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
                if individual_masks:
                    all_masks.append(combined)
                else:
                    # Union of all per-object masks for this frame.
                    all_masks.append((combined > 0).any(dim=0).float())
            else:
                # No detections: emit an empty (0-object) or all-zero mask.
                if individual_masks:
                    all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
                else:
                    all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
            pbar.update(1)
        idev = comfy.model_management.intermediate_device()
        all_masks = [m.to(idev) for m in all_masks]
        # individual_masks: flat [sum(N_obj), H, W]; otherwise stacked [B, H, W].
        mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
        return io.NodeOutput(mask_out, all_bbox_dicts)
# Opaque custom datatype used to pass SAM3 video-tracking results between nodes.
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
class SAM3_VideoTrack(io.ComfyNode):
    """Track objects across video frames using SAM3's memory-based tracker."""
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3_VideoTrack",
            display_name="SAM3 Video Track",
            category="detection/",
            search_aliases=["sam3", "video", "track", "propagate"],
            inputs=[
                io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
                io.Model.Input("model", display_name="model"),
                io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
                io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
                io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
                io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
                io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
            ],
            outputs=[
                SAM3TrackData.Output("track_data", display_name="track_data"),
            ],
        )
    @classmethod
    def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
        """Run SAM3's video tracker over `images` ([N, H, W, C]).

        Requires at least one of initial_mask / conditioning; raises ValueError
        otherwise. Returns the tracker's result dict (SAM3_TRACK_DATA) with the
        original frame size added under "orig_size".
        """
        N, H, W, C = images.shape
        comfy.model_management.load_model_gpu(model)
        device = comfy.model_management.get_torch_device()
        dtype = model.model.get_dtype()
        sam3_model = model.model.diffusion_model
        # Tracker input: RGB channels, resized to the fixed 1008x1008 model resolution.
        frames = images[..., :3].movedim(-1, 1)
        frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
        init_masks = None
        if initial_mask is not None:
            # [N_obj, H, W] → [N_obj, 1, H, W], one channel per tracked object.
            init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
        pbar = comfy.utils.ProgressBar(N)
        text_prompts = None
        if conditioning is not None and len(conditioning) > 0:
            # Tracker only needs (embeddings, mask); per-prompt max_detections is dropped.
            text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
        elif initial_mask is None:
            raise ValueError("Either initial_mask or conditioning must be provided")
        result = sam3_model.forward_video(
            images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
            new_det_thresh=detection_threshold, max_objects=max_objects,
            detect_interval=detect_interval)
        # Downstream nodes need the pre-resize frame size to restore masks.
        result["orig_size"] = (H, W)
        return io.NodeOutput(result)
class SAM3_TrackPreview(io.ComfyNode):
    """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3_TrackPreview",
            display_name="SAM3 Track Preview",
            category="detection/",
            inputs=[
                SAM3TrackData.Input("track_data", display_name="track_data"),
                io.Image.Input("images", display_name="images", optional=True),
                io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
                io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
            ],
            is_output_node=True,
        )
    # Per-object overlay palette (RGB, 0-1); reused cyclically beyond 12 objects.
    COLORS = [
        (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
        (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
        (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
    ]
    # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
    _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
    @staticmethod
    def _get_glyphs(device, scale=3):
        """Return cached (glyphs, outlines, glyph_h, glyph_w, outline_h, outline_w) for a device/scale."""
        key = (device, scale)
        if key in SAM3_TrackPreview._glyph_cache:
            return SAM3_TrackPreview._glyph_cache[key]
        # Hand-drawn 5x3 digit bitmaps, one row per digit 0-9.
        atlas = torch.tensor([
            [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
            [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
            [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
            [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
            [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
            [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
            [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
            [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
            [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
            [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
        ], dtype=torch.bool)
        glyphs, outlines = [], []
        for d in range(10):
            # Scale up the bitmap, then derive a 1px-dilated outline via max-pool.
            g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
            padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
            o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
            glyphs.append(g.to(device))
            outlines.append(o.to(device))
        gh, gw = glyphs[0].shape
        oh, ow = outlines[0].shape
        SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
        return SAM3_TrackPreview._glyph_cache[key]
    @staticmethod
    def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
        """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
        H, W = frame.shape[:2]
        device = frame.device
        glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
        color_t = torch.tensor(color, device=device, dtype=frame.dtype)
        digs = [int(d) for d in str(number)]
        # Center the digit run horizontally/vertically on (cx, cy).
        total_w = len(digs) * (gw + scale) - scale
        x0 = cx - total_w // 2
        y0 = cy - gh // 2
        for i, d in enumerate(digs):
            dx = x0 + i * (gw + scale)
            # Black outline
            # Outline is offset by -1 (it is 2px larger than the glyph); the
            # min/max arithmetic clips the glyph rectangle to the frame bounds.
            oy0, ox0 = y0 - 1, dx - 1
            osy1, osx1 = max(0, -oy0), max(0, -ox0)
            osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
            if osy2 > osy1 and osx2 > osx1:
                fy1, fx1 = oy0 + osy1, ox0 + osx1
                frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
            # Colored fill
            sy1, sx1 = max(0, -y0), max(0, -dx)
            sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
            if sy2 > sy1 and sx2 > sx1:
                fy1, fx1 = y0 + sy1, dx + sx1
                frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
    @classmethod
    def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
        """Render tracked masks over `images` (or black frames) into a temp h264 mp4.

        Each object gets a palette color, its index drawn at the mask centroid, and
        its score (when < 1.0) drawn below. Returns only a UI video preview.
        """
        from comfy.ldm.sam3.tracker import unpack_masks
        packed = track_data["packed_masks"]
        H, W = track_data["orig_size"]
        if images is not None:
            # Prefer the supplied frames' size over the recorded tracking size.
            H, W = images.shape[1], images.shape[2]
        if packed is None:
            # Tracker found nothing: emit plain (un-overlaid) frames.
            N, N_obj = track_data["n_frames"], 0
        else:
            N, N_obj = packed.shape[0], packed.shape[1]
        import uuid
        gpu = comfy.model_management.get_torch_device()
        temp_dir = folder_paths.get_temp_directory()
        filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4"
        filepath = os.path.join(temp_dir, filename)
        with av.open(filepath, mode='w') as output:
            stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
            stream.width = W
            stream.height = H
            stream.pix_fmt = 'yuv420p'
            # Reusable CPU staging buffer; frame_np aliases its memory for PyAV.
            frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
            frame_np = frame_cpu.numpy()
            if N_obj > 0:
                colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
                                        device=gpu, dtype=torch.float32)
                # Coordinate grids for computing per-object mask centroids.
                grid_y = torch.arange(H, device=gpu).view(1, H, 1)
                grid_x = torch.arange(W, device=gpu).view(1, 1, W)
            for t in range(N):
                if images is not None and t < images.shape[0]:
                    frame = images[t].clone()
                else:
                    # No backing frame for this index: draw on black.
                    frame = torch.zeros(H, W, 3)
                if N_obj > 0:
                    frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
                    frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
                    frame_gpu = frame.to(gpu)
                    bool_masks = frame_masks > 0.5
                    any_mask = bool_masks.any(dim=0)
                    if any_mask.any():
                        # Where masks overlap, argmax picks the lowest object index.
                        obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
                        color_overlay = colors_t[obj_idx_map]
                        mask_3d = any_mask.unsqueeze(-1)
                        frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
                    # Centroid per object; clamp avoids div-by-zero for empty masks.
                    area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
                    cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
                    cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
                    has = area > 1
                    scores = track_data.get("scores", [])
                    for obj_idx in range(N_obj):
                        if has[obj_idx]:
                            _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
                            color = cls.COLORS[obj_idx % len(cls.COLORS)]
                            SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
                            if obj_idx < len(scores) and scores[obj_idx] < 1.0:
                                # Score as integer percent, just below the index label.
                                SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
                                                                   _cx, _cy + 5 * 3 + 3, color, scale=2)
                    frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
                else:
                    frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
                vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
                output.mux(stream.encode(vframe.reformat(format='yuv420p')))
            # Flush the encoder's buffered frames.
            output.mux(stream.encode(None))
        return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
class SAM3_TrackToMask(io.ComfyNode):
    """Select tracked objects by index and output as mask."""
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3_TrackToMask",
            display_name="SAM3 Track to Mask",
            category="detection/",
            inputs=[
                SAM3TrackData.Input("track_data", display_name="track_data"),
                io.String.Input("object_indices", display_name="object_indices", default="",
                                tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
            ],
            outputs=[
                io.Mask.Output("masks", display_name="masks"),
            ],
        )
    @classmethod
    def execute(cls, track_data, object_indices="") -> io.NodeOutput:
        """Union the selected objects' tracked masks into one [N, H, W] mask batch.

        Non-numeric or out-of-range indices are silently dropped; an empty
        selection (or a track with no masks) yields all-zero masks.
        """
        from comfy.ldm.sam3.tracker import unpack_masks
        packed = track_data["packed_masks"]
        H, W = track_data["orig_size"]
        if packed is None:
            # Tracker produced no objects at all.
            N = track_data["n_frames"]
            return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
        N, N_obj = packed.shape[0], packed.shape[1]
        if object_indices.strip():
            # Parse "0,2,3"; isdigit() drops negatives and junk tokens.
            indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
            indices = [i for i in indices if 0 <= i < N_obj]
        else:
            indices = list(range(N_obj))
        if not indices:
            return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
        selected = packed[:, indices]
        binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
        union = binary.any(dim=1, keepdim=True).float()
        # Upsample the union from the stored mask resolution to the original size.
        mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
        return io.NodeOutput(mask_out)
class SAM3Extension(ComfyExtension):
    """Extension object that registers the SAM3 nodes with ComfyUI."""
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Node ids come from each class's schema; this list only controls registration.
        return [
            SAM3_Detect,
            SAM3_VideoTrack,
            SAM3_TrackPreview,
            SAM3_TrackToMask,
        ]
async def comfy_entrypoint() -> SAM3Extension:
    """Module entry point called by ComfyUI's extension loader."""
    extension = SAM3Extension()
    return extension

View File

@ -1,5 +1,6 @@
import torch import torch
import comfy.utils import comfy.utils
import comfy.model_management
import numpy as np import numpy as np
import math import math
import colorsys import colorsys
@ -410,7 +411,9 @@ class SDPoseDrawKeypoints(io.ComfyNode):
pose_outputs.append(canvas) pose_outputs.append(canvas)
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0) pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 final_pose_output = torch.from_numpy(pose_outputs_np).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype()) / 255.0
return io.NodeOutput(final_pose_output) return io.NodeOutput(final_pose_output)
class SDPoseKeypointExtractor(io.ComfyNode): class SDPoseKeypointExtractor(io.ComfyNode):
@ -459,6 +462,27 @@ class SDPoseKeypointExtractor(io.ComfyNode):
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
def _resize_to_model(imgs):
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
h, w = imgs.shape[-3], imgs.shape[-2]
scale = min(model_h / h, model_w / w)
sh, sw = int(round(h * scale)), int(round(w * scale))
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
chw = imgs.permute(0, 3, 1, 2).float()
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
return padded.permute(0, 2, 3, 1), scale, pt, pl
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
"""Remap keypoints from model space back to original image space."""
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
invalid = kp[..., 0] < 0
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
kp[invalid] = -1
return kp
def _run_on_latent(latent_batch): def _run_on_latent(latent_batch):
"""Run one forward pass and return (keypoints_list, scores_list) for the batch.""" """Run one forward pass and return (keypoints_list, scores_list) for the batch."""
nonlocal captured_feat nonlocal captured_feat
@ -504,36 +528,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
if x2 <= x1 or y2 <= y1: if x2 <= x1 or y2 <= y1:
continue continue
crop_h_px, crop_w_px = y2 - y1, x2 - x1
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
scale = min(model_h / crop_h_px, model_w / crop_w_px)
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
latent_crop = vae.encode(crop_resized) latent_crop = vae.encode(crop_resized)
kp_batch, sc_batch = _run_on_latent(latent_crop) kp_batch, sc_batch = _run_on_latent(latent_crop)
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
# remove padding offset, undo scale, offset to full-image coordinates.
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
img_keypoints.append(kp) img_keypoints.append(kp)
img_scores.append(sc) img_scores.append(sc_batch[0])
else: else:
# No bboxes for this image run on the full image img_resized, scale, pad_top, pad_left = _resize_to_model(img)
latent_img = vae.encode(img) latent_img = vae.encode(img_resized)
kp_batch, sc_batch = _run_on_latent(latent_img) kp_batch, sc_batch = _run_on_latent(latent_img)
img_keypoints.append(kp_batch[0]) img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
img_scores.append(sc_batch[0]) img_scores.append(sc_batch[0])
all_keypoints.append(img_keypoints) all_keypoints.append(img_keypoints)
@ -541,19 +548,16 @@ class SDPoseKeypointExtractor(io.ComfyNode):
pbar.update(1) pbar.update(1)
else: # full-image mode, batched else: # full-image mode, batched
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
for batch_start in range(0, total_images, batch_size): batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
batch_end = min(batch_start + batch_size, total_images) latent_batch = vae.encode(batch_resized)
latent_batch = vae.encode(image[batch_start:batch_end])
kp_batch, sc_batch = _run_on_latent(latent_batch) kp_batch, sc_batch = _run_on_latent(latent_batch)
for kp, sc in zip(kp_batch, sc_batch): for kp, sc in zip(kp_batch, sc_batch):
all_keypoints.append([kp]) all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
all_scores.append([sc]) all_scores.append([sc])
tqdm_pbar.update(1)
pbar.update(batch_end - batch_start) pbar.update(len(kp_batch))
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
return io.NodeOutput(openpose_frames) return io.NodeOutput(openpose_frames)

View File

@ -1,4 +1,5 @@
import re import re
import json
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode):
return io.NodeOutput(result) return io.NodeOutput(result)
class JsonExtractString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="JsonExtractString",
display_name="Extract String from JSON",
category="utils/string",
search_aliases=["json", "extract json", "parse json", "json value", "read json"],
inputs=[
io.String.Input("json_string", multiline=True),
io.String.Input("key", multiline=False),
],
outputs=[
io.String.Output(),
]
)
@classmethod
def execute(cls, json_string, key):
try:
data = json.loads(json_string)
if isinstance(data, dict) and key in data:
value = data[key]
if value is None:
return io.NodeOutput("")
return io.NodeOutput(str(value))
return io.NodeOutput("")
except (json.JSONDecodeError, TypeError):
return io.NodeOutput("")
class StringExtension(ComfyExtension): class StringExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -390,6 +424,7 @@ class StringExtension(ComfyExtension):
RegexMatch, RegexMatch,
RegexExtract, RegexExtract,
RegexReplace, RegexReplace,
JsonExtractString,
] ]
async def comfy_entrypoint() -> StringExtension: async def comfy_entrypoint() -> StringExtension:

View File

@ -35,6 +35,7 @@ class TextGenerate(io.ComfyNode):
io.Int.Input("max_length", default=256, min=1, max=2048), io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),
], ],
outputs=[ outputs=[
io.String.Output(display_name="generated_text"), io.String.Output(display_name="generated_text"),
@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking) tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
# Get sampling parameters from dynamic combo # Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on" do_sample = sampling_mode.get("sampling_mode") == "on"
@ -160,12 +161,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
if image is None: if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else: else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking) return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
class TextgenExtension(ComfyExtension): class TextgenExtension(ComfyExtension):

View File

@ -6,6 +6,7 @@ import comfy.utils
import folder_paths import folder_paths
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import comfy.model_management
try: try:
from spandrel_extra_arches import EXTRA_REGISTRY from spandrel_extra_arches import EXTRA_REGISTRY
@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode):
tile = 512 tile = 512
overlap = 32 overlap = 32
output_device = comfy.model_management.intermediate_device()
oom = True oom = True
try: try:
while oom: while oom:
try: try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
oom = False oom = False
except Exception as e: except Exception as e:
model_management.raise_non_oom(e) model_management.raise_non_oom(e)
@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
finally: finally:
upscale_model.to("cpu") upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
return io.NodeOutput(s) return io.NodeOutput(s)
upscale = execute # TODO: remove upscale = execute # TODO: remove

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.18.1" __version__ = "0.19.3"

View File

@ -811,11 +811,30 @@ class PromptExecutor:
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
if visiting is None:
visiting = []
unique_id = item unique_id = item
if unique_id in validated: if unique_id in validated:
return validated[unique_id] return validated[unique_id]
if unique_id in visiting:
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
for node_id in cycle_nodes:
validated[node_id] = (False, [{
"type": "dependency_cycle",
"message": "Dependency cycle detected",
"details": cycle_path,
"extra_info": {
"node_id": node_id,
"cycle_nodes": cycle_nodes,
}
}], node_id)
return validated[unique_id]
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
try: try:
r = await validate_inputs(prompt_id, prompt, o_id, validated) visiting.append(unique_id)
try:
r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting)
finally:
visiting.pop()
if r[0] is False: if r[0] is False:
# `r` will be set in `validated[o_id]` already # `r` will be set in `validated[o_id]` already
valid = False valid = False
@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(errors) > 0 or valid is not True: ret = validated.get(unique_id, (True, [], unique_id))
ret = (False, errors, unique_id) # Recursive cycle detection may have already populated an error on us. Join it.
else: ret = (
ret = (True, [], unique_id) ret[0] and valid is True and not errors,
ret[1] + [error for error in errors if error not in ret[1]],
unique_id,
)
validated[unique_id] = ret validated[unique_id] = ret
return ret return ret

View File

@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")

View File

@ -9,6 +9,8 @@ import folder_paths
import time import time
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger from app.logger import setup_logger
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files from app.assets.services import register_output_files
import itertools import itertools
@ -27,8 +29,6 @@ if __name__ == "__main__":
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1' os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
faulthandler.enable(file=sys.stderr, all_threads=False) faulthandler.enable(file=sys.stderr, all_threads=False)
import comfy_aimdo.control import comfy_aimdo.control

View File

@ -1 +1 @@
comfyui_manager==4.1 comfyui_manager==4.2.1

View File

@ -2457,7 +2457,9 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py", "nodes_number_convert.py",
"nodes_painter.py", "nodes_painter.py",
"nodes_curve.py", "nodes_curve.py",
"nodes_rtdetr.py" "nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py"
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.18.1" version = "0.19.3"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.42.8 comfyui-frontend-package==1.42.14
comfyui-workflow-templates==0.9.44 comfyui-workflow-templates==0.9.61
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.4
torch torch
torchsde torchsde
torchvision torchvision
@ -19,11 +19,11 @@ scipy
tqdm tqdm
psutil psutil
alembic alembic
SQLAlchemy SQLAlchemy>=2.0
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.8 comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.12 comfy-aimdo==0.2.14
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3

View File

@ -39,7 +39,7 @@ def get_required_packages_versions():
if len(s) == 2: if len(s) == 2:
version_str = s[-1] version_str = s[-1]
if not is_valid_version(version_str): if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}") logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}")
continue continue
out[s[0]] = version_str out[s[0]] = version_str
return out.copy() return out.copy()