Merge branch 'master' into add-openapi-spec

pull/13397/head
guill 2026-04-23 20:52:18 -07:00 committed by GitHub
commit 9f88368030
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 4295 additions and 86 deletions

View File

@ -2,7 +2,6 @@
precision mediump float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint
uniform float u_float0; // Intensity
@ -75,7 +74,7 @@ void main() {
float t0 = threshold - 0.15;
float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution;
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);

View File

@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
@ -25,7 +24,7 @@ float gaussian(float x, float sigma) {
}
void main() {
vec2 texelSize = 1.0 / u_resolution;
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable

View File

@ -2,14 +2,13 @@
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
vec2 texel = 1.0 / u_resolution;
vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
// Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord);

View File

@ -2,7 +2,6 @@
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
@ -19,7 +18,7 @@ float getLuminance(vec3 color) {
}
void main() {
vec2 texel = 1.0 / u_resolution;
vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float1, 0.5);
float amount = u_float0;
float threshold = u_float2;

View File

@ -268,7 +268,7 @@
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
"from_input"
]
},

View File

@ -331,7 +331,7 @@
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
"from_input"
]
}

View File

@ -267,7 +267,7 @@
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
"from_input"
]
}

View File

@ -383,7 +383,7 @@
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
"from_input"
]
}

596
comfy/ldm/sam3/detector.py Normal file
View File

@ -0,0 +1,596 @@
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
from comfy.ops import cast_to_input
def box_cxcywh_to_xyxy(x):
cx, cy, w, h = x.unbind(-1)
return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
def gen_sineembed_for_position(pos_tensor, num_feats=256):
"""Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats)."""
assert num_feats % 2 == 0
hdim = num_feats // 2
freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim)
embeds = []
for c in range(pos_tensor.shape[-1]):
raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs
embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2))
return torch.cat(embeds, dim=-1).to(pos_tensor.dtype)
class SplitMHA(nn.Module):
"""Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def forward(self, q_input, k_input=None, v_input=None, mask=None):
q = self.q_proj(q_input)
if k_input is None:
k = self.k_proj(q_input)
v = self.v_proj(q_input)
else:
k = self.k_proj(k_input)
v = self.v_proj(v_input if v_input is not None else k_input)
if mask is not None and mask.ndim == 2:
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
dtype = q.dtype # manual_cast may produce mixed dtypes
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
return self.out_proj(out)
class MLPWithNorm(nn.Module):
"""MLP with residual connection and output LayerNorm."""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
for i in range(num_layers)
])
self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
self.residual = residual and (input_dim == output_dim)
def forward(self, x):
orig = x
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
x = F.relu(x)
if self.residual:
x = x + orig
return self.out_norm(x)
class EncoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x, pos, text_memory=None, text_mask=None):
normed = self.norm1(x)
q_k = normed + pos
x = x + self.self_attn(q_k, q_k, normed)
if text_memory is not None:
normed = self.norm2(x)
x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask)
normed = self.norm3(x)
x = x + self.linear2(F.relu(self.linear1(normed)))
return x
class TransformerEncoder(nn.Module):
"""Checkpoint: transformer.encoder.layers.N.*"""
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, x, pos, text_memory=None, text_mask=None):
for layer in self.layers:
x = layer(x, pos, text_memory, text_mask)
return x
class DecoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
q_k = x + x_pos
x = self.norm2(x + self.self_attn(q_k, q_k, x))
if text_memory is not None:
x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
x = self.norm3(x + self.linear2(F.relu(self.linear1(x))))
return x
class TransformerDecoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.num_queries = num_queries
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes
self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each)
self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
@staticmethod
def _inverse_sigmoid(x):
return torch.log(x / (1 - x + 1e-6) + 1e-6)
def _compute_box_rpb(self, ref_points, H, W):
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
B, Q, _ = boxes_xyxy.shape
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
log2_8 = float(math.log2(8))
def log_scale(d):
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
return torch.cat([pres_bias, bias], dim=2)
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
B = memory.shape[0]
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
for layer_idx, layer in enumerate(self.layers):
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
if layer_idx < len(self.layers) - 1:
ref_inv = self._inverse_sigmoid(ref_points)
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
query_out = self.norm(tgt)
ref_inv = self._inverse_sigmoid(ref_points)
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
class Transformer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations)
self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations)
class GeometryEncoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.roi_size = roi_size
self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.encode = nn.ModuleList([
EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def _encode_points(self, coords, labels, img_feat_2d):
"""Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
B, N, _ = coords.shape
embed = self.points_direct_project(coords)
# Pool features from backbone at point locations via grid_sample
grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1]
sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1]
embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C]
# Positional encoding of coordinates
x, y = coords[:, :, 0], coords[:, :, 1] # [B, N]
pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def _encode_boxes(self, boxes, labels, img_feat_2d):
"""Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
B, N, _ = boxes.shape
embed = self.boxes_direct_project(boxes)
# ROI align from backbone at box regions
H, W = img_feat_2d.shape[-2:]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
boxes_scaled = boxes_xyxy * scale
sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
embed = embed + proj
# Positional encoding of box center + size
cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
enc = enc.view(B, N, -1)
embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def forward(self, points=None, boxes=None, image_features=None):
"""Encode geometry prompts. image_features: [B, HW, C] flattened backbone features."""
# Prepare 2D image features for pooling
img_feat_2d = None
if image_features is not None:
B = image_features.shape[0]
HW, C = image_features.shape[1], image_features.shape[2]
hw = int(math.sqrt(HW))
img_normed = self.img_pre_norm(image_features)
img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
embeddings = []
if points is not None:
coords, labels = points
embeddings.append(self._encode_points(coords, labels, img_feat_2d))
if boxes is not None:
B = boxes.shape[0]
box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
if not embeddings:
return None
geo = torch.cat(embeddings, dim=1)
geo = self.norm(geo)
if image_features is not None:
for layer in self.encode:
geo = layer(geo, torch.zeros_like(geo), image_features)
geo = self.encode_norm(geo)
return self.final_proj(geo)
class PixelDecoder(nn.Module):
"""Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""
def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)])
self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)])
def forward(self, backbone_features):
prev = backbone_features[-1]
for i, feat in enumerate(backbone_features[:-1][::-1]):
prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest"))))
return prev
class MaskPredictor(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
def forward(self, query_embeddings, pixel_features):
mask_embed = self.mask_embed(query_embeddings)
return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features)
class SegmentationHead(nn.Module):
def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)
def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
if encoder_hidden_states is not None and prompt is not None:
enc_normed = self.cross_attn_norm(encoder_hidden_states)
enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
encoder_hidden_states = enc_cross + encoder_hidden_states
if encoder_hidden_states is not None:
B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
backbone_features = list(backbone_features)
backbone_features[-1] = encoder_visual
pixel_features = self.pixel_decoder(backbone_features)
instance_features = self.instance_seg_head(pixel_features)
masks = self.mask_predictor(query_embeddings, instance_features)
return masks
class DotProductScoring(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
self.scale = 1.0 / (d_model ** 0.5)
def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
prompt = self.prompt_mlp(prompt_embeddings)
if prompt_mask is not None:
weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1)
else:
pooled = prompt.mean(dim=1)
hs = self.hs_proj(query_embeddings)
pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
scores = torch.matmul(hs, pp)
return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1)
class SAM3Detector(nn.Module):
def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
image_model = kwargs.pop("image_model", "SAM3")
for k in ("num_heads", "num_head_channels"):
kwargs.pop(k, None)
multiplex = image_model == "SAM31"
# SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
self.scalp = 0 if multiplex else 1
self.backbone = nn.ModuleDict({
"vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
"language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
})
self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)
def _get_backbone_features(self, images):
"""Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
bb = self.backbone["vision_backbone"]
if bb.multiplex:
all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
else:
all_f, all_p, tf, tp = bb(images, need_tracker=True)
return all_f, all_p, tf, tp
@staticmethod
def _run_geo_layer(layer, x, memory, memory_pos):
x = x + layer.self_attn(layer.norm1(x))
x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
return x
def _detect(self, features, positions, text_embeddings=None, text_mask=None,
points=None, boxes=None):
"""Shared detection: geometry encoding, transformer, scoring, segmentation."""
B = features[0].shape[0]
# Scalp for encoder (use top-level feature), but keep all levels for segmentation head
seg_features = features
if self.scalp > 0:
features = features[:-self.scalp]
positions = positions[:-self.scalp]
enc_feat, enc_pos = features[-1], positions[-1]
_, _, H, W = enc_feat.shape
img_flat = enc_feat.flatten(2).permute(0, 2, 1)
pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
has_prompts = text_embeddings is not None or points is not None or boxes is not None
if has_prompts:
geo_enc = self.geometry_encoder
geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
for layer in geo_enc.encode:
geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
geo_cls = geo_enc.encode_norm(geo_cls)
if text_embeddings is not None and text_embeddings.shape[0] != B:
text_embeddings = text_embeddings.expand(B, -1, -1)
if text_mask is not None and text_mask.shape[0] != B:
text_mask = text_mask.expand(B, -1)
parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
text_embeddings = torch.cat(parts, dim=1)
n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
if text_mask is not None:
text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
else:
text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
if text_embeddings is not None:
scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
else:
scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out
def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
features, positions, _, _ = self._get_backbone_features(images)
if text_embeddings is not None:
text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, dec_out = self._detect(
features, positions, text_embeddings, text_mask, points, boxes)
if orig_size is not None:
oh, ow = orig_size
boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
return {
"boxes": boxes_xyxy,
"scores": scores,
"masks": masks,
"presence": dec_out.get("presence"),
}
def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
"""Run detection using a pre-computed ViTDet trunk output.
text_embeddings must already be resized through language_backbone.resizer.
Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
"""
bb = self.backbone["vision_backbone"]
features = [conv(trunk_out) for conv in bb.convs]
positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
class SAM3Model(nn.Module):
def __init__(self, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
image_model = kwargs.get("image_model", "SAM3")
tracker_cls = TRACKER_CLASSES[image_model]
self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)
def forward(self, images, **kwargs):
return self.detector(images, **kwargs)
def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
"""Interactive segmentation using SAM decoder with point/box/mask prompts.
Args:
images: [B, 3, 1008, 1008] preprocessed images
point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
mask_inputs: [B, 1, H, W] coarse mask logits to refine
Returns:
[B, 1, image_size, image_size] high-res mask logits
"""
bb = self.detector.backbone["vision_backbone"]
if bb.multiplex:
_, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
else:
_, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
if self.detector.scalp > 0:
tracker_features = tracker_features[:-self.detector.scalp]
tracker_positions = tracker_positions[:-self.detector.scalp]
high_res = list(tracker_features[:-1])
backbone_feat = tracker_features[-1]
B, C, H, W = backbone_feat.shape
# Add no-memory embedding (init frame path)
no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
if no_mem is None:
no_mem = getattr(self.tracker, 'no_mem_embed', None)
if no_mem is not None:
feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
_, high_res_masks, _, _ = self.tracker._forward_sam_heads(
backbone_features=backbone_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
box_inputs=box_inputs,
high_res_features=high_res,
multimask_output=(0 < num_pts <= 1),
)
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
def backbone_fn(frame, frame_idx=None):
trunk_out = bb.trunk(frame)
if bb.multiplex:
_, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
else:
_, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
return tf, tp, trunk_out
detect_fn = None
if text_prompts:
resizer = self.detector.backbone["language_backbone"]["resizer"]
resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
def detect_fn(trunk_out):
all_scores, all_masks = [], []
for emb, mask in resized:
det = self.detector.forward_from_trunk(trunk_out, emb, mask)
all_scores.append(det["scores"])
all_masks.append(det["masks"])
return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
if hasattr(self.tracker, 'track_video_with_detection'):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)

425
comfy/ldm/sam3/sam.py Normal file
View File

@ -0,0 +1,425 @@
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.layers import EmbedND
from comfy.ops import cast_to_input
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
return torch.sigmoid(x) if self.sigmoid_output else x
class SAMAttention(nn.Module):
def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
internal_dim = embedding_dim // downsample_rate
kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)
def forward(self, q, k, v):
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
class TwoWayAttentionBlock(nn.Module):
def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
super().__init__()
self.skip_first_layer_pe = skip_first_layer_pe
self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, queries, keys, query_pe, key_pe):
if self.skip_first_layer_pe:
queries = self.norm1(self.self_attn(queries, queries, queries))
else:
q = queries + query_pe
queries = self.norm1(queries + self.self_attn(q, q, queries))
q, k = queries + query_pe, keys + key_pe
queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
queries = self.norm3(queries + self.mlp(queries))
q, k = queries + query_pe, keys + key_pe
keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
return queries, keys
class TwoWayTransformer(nn.Module):
def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
for i in range(depth)
])
self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, image_embedding, image_pe, point_embedding):
queries, keys = point_embedding, image_embedding
for layer in self.layers:
queries, keys = layer(queries, keys, point_embedding, image_pe)
q, k = queries + point_embedding, keys + image_pe
queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
return queries, keys
class PositionEmbeddingRandom(nn.Module):
"""Fourier feature positional encoding with random gaussian projection."""
def __init__(self, num_pos_feats=64, scale=None):
super().__init__()
self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))
def _encode(self, normalized_coords):
"""Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
orig_dtype = normalized_coords.dtype
proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)
def forward(self, size, device=None):
h, w = size
dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
ones = torch.ones((h, w), device=dev, dtype=torch.float32)
norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)
def forward_with_coords(self, pixel_coords, image_size):
norm = pixel_coords.clone()
norm[:, :, 0] /= image_size[1]
norm[:, :, 1] /= image_size[0]
return self._encode(norm)
# ViTDet backbone + FPN neck
def window_partition(x: torch.Tensor, window_size: int):
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
"""Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
t = torch.arange(end_x * end_y, dtype=torch.float32)
ids = torch.stack([(t % end_x) * scale_pos,
torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))
class _ViTMLP(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
super().__init__()
hidden = int(dim * mlp_ratio)
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class Attention(nn.Module):
"""ViTDet multi-head attention with fused QKV projection."""
def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.use_rope = use_rope
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, freqs_cis=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
if self.use_rope and freqs_cis is not None:
q, k = apply_rope(q, k, freqs_cis)
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
def forward(self, x, freqs_cis=None):
shortcut = x
x = self.norm1(x)
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = x.view(x.shape[0], self.window_size * self.window_size, -1)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
else:
B, H, W, C = x.shape
x = x.view(B, H * W, C)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(B, H, W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchEmbed(nn.Module):
def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
super().__init__()
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.proj(x)
class ViTDet(nn.Module):
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.global_att_blocks = set(global_att_blocks)
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
grid_size = img_size // patch_size
pretrain_grid = pretrain_img_size // patch_size
self.blocks = nn.ModuleList()
for i in range(depth):
is_global = i in self.global_att_blocks
self.blocks.append(Block(
embed_dim, num_heads, mlp_ratio, qkv_bias,
window_size=0 if is_global else window_size,
use_rope=use_rope,
device=device, dtype=dtype, operations=operations,
))
if use_rope:
rope_scale = pretrain_grid / grid_size
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
else:
self.freqs_cis = None
self.freqs_cis_window = None
def _get_pos_embed(self, num_tokens):
pos = self.pos_embed
if pos.shape[1] == num_tokens:
return pos
cls_pos = pos[:, :1]
spatial_pos = pos[:, 1:]
old_size = int(math.sqrt(spatial_pos.shape[1]))
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
tiles_h = new_size // old_size + 1
tiles_w = new_size // old_size + 1
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
return torch.cat([cls_pos, tiled], dim=1)
def forward(self, x):
x = self.patch_embed(x)
B, C, Hp, Wp = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
x = x + pos[:, 1:Hp * Wp + 1]
x = x.view(B, Hp, Wp, C)
x = self.ln_pre(x)
freqs_cis_global = self.freqs_cis
freqs_cis_win = self.freqs_cis_window
if freqs_cis_global is not None:
freqs_cis_global = cast_to_input(freqs_cis_global, x)
if freqs_cis_win is not None:
freqs_cis_win = cast_to_input(freqs_cis_win, x)
for block in self.blocks:
fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
x = block(x, freqs_cis=fc)
return x.permute(0, 3, 1, 2)
class FPNScaleConv(nn.Module):
def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
super().__init__()
if scale == 4.0:
self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 4
elif scale == 2.0:
self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 2
elif scale == 1.0:
proj_in = in_dim
elif scale == 0.5:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
proj_in = in_dim
self.scale = scale
self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)
def forward(self, x):
if self.scale == 4.0:
x = F.gelu(self.dconv_2x2_0(x))
x = self.dconv_2x2_1(x)
elif self.scale == 2.0:
x = self.dconv_2x2(x)
elif self.scale == 0.5:
x = self.pool(x)
x = self.conv_1x1(x)
x = self.conv_3x3(x)
return x
class PositionEmbeddingSine(nn.Module):
"""2D sinusoidal position encoding (DETR-style) with result caching."""
def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
super().__init__()
assert num_pos_feats % 2 == 0
self.half_dim = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
self.scale = scale if scale is not None else 2 * math.pi
self._cache = {}
def _sincos(self, vals):
"""Encode 1D values to interleaved sin/cos features."""
freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
raw = vals[..., None] * self.scale / freqs
return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)
def _encode_xy(self, x, y):
"""Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
pos_x = x[:, None] * self.scale / dim_t
pos_y = y[:, None] * self.scale / dim_t
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y
def encode_boxes(self, cx, cy, w, h):
"""Encode box center + size to [N, d_model+2] features."""
pos_x, pos_y = self._encode_xy(cx, cy)
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
def forward(self, x):
B, C, H, W = x.shape
key = (H, W, x.device)
if key not in self._cache:
gy = torch.arange(H, dtype=torch.float32, device=x.device)
gx = torch.arange(W, dtype=torch.float32, device=x.device)
if self.normalize:
gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
yy, xx = torch.meshgrid(gy, gx, indexing="ij")
self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
return self._cache[key].expand(B, -1, -1, -1)
class SAM3VisionBackbone(nn.Module):
def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.multiplex = multiplex
fpn_args = dict(device=device, dtype=dtype, operations=operations)
if multiplex:
scales = [4.0, 2.0, 1.0]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
else:
scales = [4.0, 2.0, 1.0, 0.5]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
if tracker_only:
# Skip detector FPN when only tracker features are needed (video tracking)
if self.multiplex:
tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
else:
tracker_convs = self.sam2_convs
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return None, None, tracker_features, tracker_positions
features = [conv(backbone_out) for conv in self.convs]
positions = [cast_to_input(self.position_encoding(f), f) for f in features]
if self.multiplex:
if tracker_mode == "propagation":
tracker_convs = self.propagation_convs
elif tracker_mode == "interactive":
tracker_convs = self.interactive_convs
else:
return features, positions, None, None
elif need_tracker:
tracker_convs = self.sam2_convs
else:
return features, positions, None, None
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return features, positions, tracker_features, tracker_positions

1785
comfy/ldm/sam3/tracker.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -54,6 +54,7 @@ import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.model_management
import comfy.patcher_extension
@ -578,8 +579,8 @@ class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone())
self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone())
def extra_conds(self, **kwargs):
out = {}
@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)

View File

@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "ernie"
return dit_config
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "SAM3"
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
dit_config["image_model"] = "SAM31"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config
def unet_prefix_from_state_dict(state_dict):
# SAM3: detector.* and tracker.* at top level, no common prefix
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
return ""
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
"net.", #cosmos

View File

@ -1801,7 +1801,7 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""
class InterruptProcessingException(Exception):
class InterruptProcessingException(BaseException):
pass
interrupt_processing_mutex = threading.RLock()

View File

@ -685,9 +685,9 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.patches:
if key not in self.patches and not force_cast:
return weight
inplace_update = self.weight_inplace_update or inplace_update
@ -695,7 +695,7 @@ class ModelPatcher:
if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
@ -703,9 +703,10 @@ class ModelPatcher:
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if key in self.patches:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight:
return out_weight
elif inplace_update:
@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
for param in params:
if param not in ("weight", "bias"):
force_load_param(self, param, device_to)
else:
for param in params:
key = key_param_name_to_key(n, param)

View File

@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
class SAM3(supported_models_base.BASE):
unet_config = {"image_model": "SAM3"}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
unet_extra_prefix = ""
def process_clip_state_dict(self, state_dict):
clip_keys = getattr(self, "_clip_stash", {})
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
def process_unet_state_dict(self, state_dict):
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
# SAM3.1: remap tracker.model.* -> tracker.*
for k in list(state_dict.keys()):
if k.startswith("tracker.model."):
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
state_dict.pop(k)
# Split fused QKV projections
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
t = state_dict.pop(k)
base, suffix = k.rsplit(".in_proj_", 1)
s = ".weight" if suffix == "weight" else ".bias"
d = t.shape[0] // 3
state_dict[base + ".q_proj" + s] = t[:d]
state_dict[base + ".k_proj" + s] = t[d:2*d]
state_dict[base + ".v_proj" + s] = t[2*d:]
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
for k in list(state_dict.keys()):
if "sam_mask_decoder.transformer." not in k:
continue
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
return state_dict
def get_model(self, state_dict, prefix="", device=None):
return model_base.SAM3(self, device=device)
def clip_target(self, state_dict={}):
import comfy.text_encoders.sam3_clip
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
class SAM31(SAM3):
unet_config = {"image_model": "SAM31"}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
models += [SVD_img2vid]

View File

@ -0,0 +1,97 @@
import re
from comfy import sd1_clip
SAM3_CLIP_CONFIG = {
"architectures": ["CLIPTextModel"],
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"max_position_embeddings": 32,
"projection_dim": 512,
"vocab_size": 49408,
"layer_norm_eps": 1e-5,
"eos_token_id": 49407,
}
class SAM3ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
class SAM3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
self.disable_weights = True
def _parse_prompts(text):
"""Split comma-separated prompts with optional :N max detections per category"""
text = text.replace("(", "").replace(")", "")
parts = [p.strip() for p in text.split(",") if p.strip()]
result = []
for part in parts:
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
if m:
text_part = m.group(1).strip()
val = m.group(2)
max_det = max(1, round(float(val)))
result.append((text_part, max_det))
else:
result.append((part, 1))
return result
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
parsed = _parse_prompts(text)
if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
# Tokenize each prompt part separately, store per-part batches and metadata
inner = getattr(self, self.clip)
per_prompt = []
for prompt_text, max_det in parsed:
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
per_prompt.append((batches, max_det))
# Main output uses first prompt's tokens (for compatibility)
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
return out
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
def encode_token_weights(self, token_weight_pairs):
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
if per_prompt is None:
return super().encode_token_weights(token_weight_pairs)
# Encode each prompt separately, pack into extra dict
inner = getattr(self, self.clip)
multi_cond = []
first_pooled = None
for batches, max_det in per_prompt:
out = inner.encode_token_weights(batches)
cond, pooled = out[0], out[1]
extra = out[2] if len(out) > 2 else {}
if first_pooled is None:
first_pooled = pooled
multi_cond.append({
"cond": cond,
"attention_mask": extra.get("attention_mask"),
"max_detections": max_det,
})
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
main = multi_cond[0]
main_extra = {}
if main["attention_mask"] is not None:
main_extra["attention_mask"] = main["attention_mask"]
main_extra["sam3_multi_cond"] = multi_cond
return (main["cond"], first_pooled, main_extra)

View File

@ -9,6 +9,7 @@ from comfy_api.latest._input import (
CurveInput,
MonotoneCubicCurve,
LinearCurve,
RangeInput,
)
__all__ = [
@ -21,4 +22,5 @@ __all__ = [
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
"RangeInput",
]

View File

@ -1,5 +1,6 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .range_types import RangeInput
from .video_types import VideoInput
__all__ = [
@ -12,4 +13,5 @@ __all__ = [
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
"RangeInput",
]

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import logging
import math
import numpy as np
logger = logging.getLogger(__name__)
class RangeInput:
"""Represents a levels/range adjustment: input range [min, max] with
optional midpoint (gamma control).
Generates a 1D LUT identical to GIMP's levels mapping:
1. Normalize input to [0, 1] using [min, max]
2. Apply gamma correction: pow(value, 1/gamma)
3. Clamp to [0, 1]
The midpoint field is a position in [0, 1] representing where the
midtone falls within [min, max]. It maps to gamma via:
gamma = -log2(midpoint)
So midpoint=0.5 gamma=1.0 (linear).
"""
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
self.min_val = min_val
self.max_val = max_val
self.midpoint = midpoint
@staticmethod
def from_raw(data) -> RangeInput:
if isinstance(data, RangeInput):
return data
if isinstance(data, dict):
return RangeInput(
min_val=float(data.get("min", 0.0)),
max_val=float(data.get("max", 1.0)),
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
)
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table mapping [0, 1] input through this
levels adjustment.
The LUT maps normalized input values (0..1) to output values (0..1),
matching the GIMP levels formula.
"""
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
in_range = self.max_val - self.min_val
if abs(in_range) < 1e-10:
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
# Normalize: map [min, max] → [0, 1]
result = (xs - self.min_val) / in_range
result = np.clip(result, 0.0, 1.0)
# Gamma correction from midpoint
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
gamma = max(-math.log2(self.midpoint), 0.001)
inv_gamma = 1.0 / gamma
mask = result > 0
result[mask] = np.power(result[mask], inv_gamma)
return result
def __repr__(self) -> str:
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"

View File

@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO):
Type = list[int]
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
from comfy_api.input import RangeInput
if TYPE_CHECKING:
Type = RangeInput
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None,
display: str=None,
gradient_stops: list=None,
show_midpoint: bool=None,
midpoint_scale: str=None,
value_min: float=None,
value_max: float=None,
advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {"min": 0.0, "max": 1.0}
self.display = display
self.gradient_stops = gradient_stops
self.show_midpoint = show_midpoint
self.midpoint_scale = midpoint_scale
self.value_min = value_min
self.value_max = value_max
def as_dict(self):
return super().as_dict() | prune_dict({
"display": self.display,
"gradient_stops": self.gradient_stops,
"show_midpoint": self.show_midpoint,
"midpoint_scale": self.midpoint_scale,
"value_min": self.value_min,
"value_max": self.value_max,
})
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2276,5 +2313,6 @@ __all__ = [
"BoundingBox",
"Curve",
"Histogram",
"Range",
"NodeReplace",
]

View File

@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel):
usage: TaskStatusUsage | None = Field(None)
class GetAssetResponse(BaseModel):
id: str = Field(...)
name: str | None = Field(None)
url: str | None = Field(None)
asset_type: str = Field(...)
group_id: str = Field(...)
status: str = Field(...)
error: TaskStatusError | None = Field(None)
class SeedanceCreateVisualValidateSessionResponse(BaseModel):
session_id: str = Field(...)
h5_link: str = Field(...)
class SeedanceGetVisualValidateSessionResponse(BaseModel):
session_id: str = Field(...)
status: str = Field(...)
group_id: str | None = Field(None)
error_code: str | None = Field(None)
error_message: str | None = Field(None)
class SeedanceCreateAssetRequest(BaseModel):
group_id: str = Field(...)
url: str = Field(...)
asset_type: str = Field(...)
name: str | None = Field(None, max_length=64)
project_name: str | None = Field(None)
class SeedanceCreateAssetResponse(BaseModel):
asset_id: str = Field(...)
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False): 0.007,

View File

@ -1,5 +1,6 @@
import logging
import math
import re
import torch
from typing_extensions import override
@ -11,9 +12,14 @@ from comfy_api_nodes.apis.bytedance import (
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME,
GetAssetResponse,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedance2TaskCreationRequest,
SeedanceCreateAssetRequest,
SeedanceCreateAssetResponse,
SeedanceCreateVisualValidateSessionResponse,
SeedanceGetVisualValidateSessionResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskAudioContent,
@ -44,10 +50,16 @@ from comfy_api_nodes.util import (
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
validate_video_dimensions,
validate_video_duration,
)
from server import PromptServer
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
_VERIFICATION_POLL_TIMEOUT_SEC = 120
_VERIFICATION_POLL_INTERVAL_SEC = 3
SEEDREAM_MODELS = {
"seedream 5.0 lite": "seedream-5-0-260128",
"seedream-4-5-251128": "seedream-4-5-251128",
@ -96,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
)
async def _resolve_reference_assets(
cls: type[IO.ComfyNode],
asset_ids: list[str],
) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
"""Look up each asset, validate Active status, group by asset_type.
Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://<asset_id>".
"""
image_assets: dict[str, str] = {}
video_assets: dict[str, str] = {}
audio_assets: dict[str, str] = {}
for i, raw_id in enumerate(asset_ids, 1):
asset_id = (raw_id or "").strip()
if not asset_id:
continue
result = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
response_model=GetAssetResponse,
)
if result.status != "Active":
extra = f" {result.error.code}: {result.error.message}" if result.error else ""
raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}")
asset_uri = f"asset://{asset_id}"
if result.asset_type == "Image":
image_assets[asset_id] = asset_uri
elif result.asset_type == "Video":
video_assets[asset_id] = asset_uri
elif result.asset_type == "Audio":
audio_assets[asset_id] = asset_uri
return image_assets, video_assets, audio_assets
_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE)
def _build_asset_labels(
reference_assets: dict[str, str],
image_asset_uris: dict[str, str],
video_asset_uris: dict[str, str],
audio_asset_uris: dict[str, str],
n_reference_images: int,
n_reference_videos: int,
n_reference_audios: int,
) -> dict[int, str]:
"""Map asset slot number (from 'asset_N' keys) to its positional label.
Asset entries are appended to `content` after the reference_images/videos/audios,
so their 1-indexed labels continue from the count of existing same-type refs:
one reference_images entry + one Image-type asset -> asset labelled "Image 2".
"""
image_n = n_reference_images
video_n = n_reference_videos
audio_n = n_reference_audios
labels: dict[int, str] = {}
for slot_key, raw_id in reference_assets.items():
asset_id = (raw_id or "").strip()
if not asset_id:
continue
try:
slot_num = int(slot_key.rsplit("_", 1)[-1])
except ValueError:
continue
if asset_id in image_asset_uris:
image_n += 1
labels[slot_num] = f"Image {image_n}"
elif asset_id in video_asset_uris:
video_n += 1
labels[slot_num] = f"Video {video_n}"
elif asset_id in audio_asset_uris:
audio_n += 1
labels[slot_num] = f"Audio {audio_n}"
return labels
def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str:
"""Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels."""
if not labels:
return prompt
def _sub(m: "re.Match[str]") -> str:
return labels.get(int(m.group(1)), m.group(0))
return _ASSET_REF_RE.sub(_sub, prompt)
async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str:
session = await sync_op(
cls,
ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"),
response_model=SeedanceCreateVisualValidateSessionResponse,
)
logger.warning("Seedance authentication required. Open link: %s", session.h5_link)
h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}"
result = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
response_model=SeedanceGetVisualValidateSessionResponse,
status_extractor=lambda r: r.status,
completed_statuses=["completed"],
failed_statuses=["failed"],
poll_interval=_VERIFICATION_POLL_INTERVAL_SEC,
max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1,
estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1,
extra_text=h5_text,
)
if not result.group_id:
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
PromptServer.instance.send_progress_text(
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
)
return result.group_id
async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str:
if group_id and group_id.strip():
return group_id.strip()
return await _obtain_group_id_via_h5_auth(cls)
async def _create_seedance_asset(
cls: type[IO.ComfyNode],
*,
group_id: str,
url: str,
name: str,
asset_type: str,
) -> str:
req = SeedanceCreateAssetRequest(
group_id=group_id,
url=url,
asset_type=asset_type,
name=name or None,
)
result = await sync_op(
cls,
ApiEndpoint(path="/proxy/seedance/assets", method="POST"),
response_model=SeedanceCreateAssetResponse,
data=req,
)
return result.asset_id
async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse:
"""Poll the newly created asset until its status becomes Active."""
return await poll_op(
cls,
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
response_model=GetAssetResponse,
status_extractor=lambda r: r.status,
completed_statuses=["Active"],
failed_statuses=["Failed"],
poll_interval=5,
max_poll_attempts=1200,
extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}",
)
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
@ -1228,12 +1403,27 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.Image.Input(
"first_frame",
tooltip="First frame image for the video.",
optional=True,
),
IO.Image.Input(
"last_frame",
tooltip="Last frame image for the video.",
optional=True,
),
IO.String.Input(
"first_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the first frame. "
"Mutually exclusive with the first_frame image input.",
optional=True,
),
IO.String.Input(
"last_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the last frame. "
"Mutually exclusive with the last_frame image input.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
@ -1286,24 +1476,54 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
model: dict,
first_frame: Input.Image,
seed: int,
watermark: bool,
first_frame: Input.Image | None = None,
last_frame: Input.Image | None = None,
first_frame_asset_id: str = "",
last_frame_asset_id: str = "",
) -> IO.NodeOutput:
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
model_id = SEEDANCE_MODELS[model["model"]]
first_frame_asset_id = first_frame_asset_id.strip()
last_frame_asset_id = last_frame_asset_id.strip()
if first_frame is not None and first_frame_asset_id:
raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.")
if first_frame is None and not first_frame_asset_id:
raise ValueError("Either first_frame or first_frame_asset_id is required.")
if last_frame is not None and last_frame_asset_id:
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
image_assets: dict[str, str] = {}
if asset_ids_to_resolve:
image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve)
for aid in asset_ids_to_resolve:
if aid not in image_assets:
raise ValueError(f"Asset {aid} is not an Image asset.")
if first_frame_asset_id:
first_frame_url = image_assets[first_frame_asset_id]
else:
first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
content: list[TaskTextContent | TaskImageContent] = [
TaskTextContent(text=model["prompt"]),
TaskImageContent(
image_url=TaskImageContentUrl(
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
),
image_url=TaskImageContentUrl(url=first_frame_url),
role="first_frame",
),
]
if last_frame is not None:
if last_frame_asset_id:
content.append(
TaskImageContent(
image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]),
role="last_frame",
),
)
elif last_frame is not None:
content.append(
TaskImageContent(
image_url=TaskImageContentUrl(
@ -1385,6 +1605,24 @@ def _seedance2_reference_inputs(resolutions: list[str]):
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
IO.Autogrow.Input(
"reference_assets",
template=IO.Autogrow.TemplateNames(
IO.String.Input("reference_asset"),
names=[
"asset_1",
"asset_2",
"asset_3",
"asset_4",
"asset_5",
"asset_6",
"asset_7",
"asset_8",
"asset_9",
],
min=0,
),
),
]
@ -1486,24 +1724,42 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
reference_images = model.get("reference_images", {})
reference_videos = model.get("reference_videos", {})
reference_audios = model.get("reference_audios", {})
reference_assets = model.get("reference_assets", {})
if not reference_images and not reference_videos:
raise ValueError("At least one reference image or video is required.")
reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets(
cls, list(reference_assets.values())
)
if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets:
raise ValueError("At least one reference image or video or asset is required.")
total_images = len(reference_images) + len(reference_image_assets)
if total_images > 9:
raise ValueError(
f"Too many reference images: {total_images} "
f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9."
)
total_videos = len(reference_videos) + len(reference_video_assets)
if total_videos > 3:
raise ValueError(
f"Too many reference videos: {total_videos} "
f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3."
)
total_audios = len(reference_audios) + len(reference_audio_assets)
if total_audios > 3:
raise ValueError(
f"Too many reference audios: {total_audios} "
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
)
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = len(reference_videos) > 0
has_video_input = total_videos > 0
if model.get("auto_downscale") and reference_videos:
max_px = (
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
.get(model["resolution"], {})
.get("max")
)
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
if max_px:
for key in reference_videos:
reference_videos[key] = resize_video_to_pixel_budget(
reference_videos[key], max_px
)
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):
@ -1531,8 +1787,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
if total_audio_duration > 15.1:
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
asset_labels = _build_asset_labels(
reference_assets,
reference_image_assets,
reference_video_assets,
reference_audio_assets,
len(reference_images),
len(reference_videos),
len(reference_audios),
)
prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels)
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
TaskTextContent(text=model["prompt"]),
TaskTextContent(text=prompt_text),
]
for i, key in enumerate(reference_images, 1):
content.append(
@ -1573,6 +1840,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
),
),
)
for url in reference_image_assets.values():
content.append(
TaskImageContent(
image_url=TaskImageContentUrl(url=url),
role="reference_image",
),
)
for url in reference_video_assets.values():
content.append(
TaskVideoContent(video_url=TaskVideoContentUrl(url=url)),
)
for url in reference_audio_assets.values():
content.append(
TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)),
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
@ -1627,6 +1909,156 @@ async def process_video_task(
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
class ByteDanceCreateImageAsset(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ByteDanceCreateImageAsset",
display_name="ByteDance Create Image Asset",
category="api node/image/ByteDance",
description=(
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
"registers it in the given asset group. If group_id is empty, runs a real-person "
"H5 authentication flow to create a new group before adding the asset."
),
inputs=[
IO.Image.Input("image", tooltip="Image to register as a personal asset."),
IO.String.Input(
"group_id",
default="",
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
),
# IO.String.Input(
# "name",
# default="",
# tooltip="Asset name (up to 64 characters).",
# ),
],
outputs=[
IO.String.Output(display_name="asset_id"),
IO.String.Output(display_name="group_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
# is_api_node=True,
)
@classmethod
async def execute(
cls,
image: Input.Image,
group_id: str = "",
# name: str = "",
) -> IO.NodeOutput:
# if len(name) > 64:
# raise ValueError("Name of asset can not be greater then 64 symbols")
validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000)
validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1))
resolved_group = await _resolve_group_id(cls, group_id)
asset_id = await _create_seedance_asset(
cls,
group_id=resolved_group,
url=await upload_image_to_comfyapi(cls, image),
name="",
asset_type="Image",
)
await _wait_for_asset_active(cls, asset_id, resolved_group)
PromptServer.instance.send_progress_text(
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
f"group_id: {resolved_group}",
cls.hidden.unique_id,
)
return IO.NodeOutput(asset_id, resolved_group)
class ByteDanceCreateVideoAsset(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ByteDanceCreateVideoAsset",
display_name="ByteDance Create Video Asset",
category="api node/video/ByteDance",
description=(
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
"registers it in the given asset group. If group_id is empty, runs a real-person "
"H5 authentication flow to create a new group before adding the asset."
),
inputs=[
IO.Video.Input("video", tooltip="Video to register as a personal asset."),
IO.String.Input(
"group_id",
default="",
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
),
# IO.String.Input(
# "name",
# default="",
# tooltip="Asset name (up to 64 characters).",
# ),
],
outputs=[
IO.String.Output(display_name="asset_id"),
IO.String.Output(display_name="group_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
# is_api_node=True,
)
@classmethod
async def execute(
cls,
video: Input.Video,
group_id: str = "",
# name: str = "",
) -> IO.NodeOutput:
# if len(name) > 64:
# raise ValueError("Name of asset can not be greater then 64 symbols")
validate_video_duration(video, min_duration=2, max_duration=15)
validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000)
w, h = video.get_dimensions()
if h > 0:
ratio = w / h
if not (0.4 <= ratio <= 2.5):
raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).")
pixels = w * h
if not (409_600 <= pixels <= 927_408):
raise ValueError(
f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})."
)
fps = float(video.get_frame_rate())
if not (24 <= fps <= 60):
raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.")
resolved_group = await _resolve_group_id(cls, group_id)
asset_id = await _create_seedance_asset(
cls,
group_id=resolved_group,
url=await upload_video_to_comfyapi(cls, video),
name="",
asset_type="Video",
)
await _wait_for_asset_active(cls, asset_id, resolved_group)
PromptServer.instance.send_progress_text(
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
f"group_id: {resolved_group}",
cls.hidden.unique_id,
)
return IO.NodeOutput(asset_id, resolved_group)
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1640,6 +2072,8 @@ class ByteDanceExtension(ComfyExtension):
ByteDance2TextToVideoNode,
ByteDance2FirstLastFrameNode,
ByteDance2ReferenceNode,
ByteDanceCreateImageAsset,
ByteDanceCreateVideoAsset,
]

View File

@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input(
"storyboards",
options=[
@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio;
$rates := $audio
? {"std": 0.112, "pro": 0.14}
: {"std": 0.084, "pro": 0.112};
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.")
@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode):
f"must equal the global duration ({duration}s)."
)
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
mode=mode,
multi_shot=multi_shot,
multi_prompt=multi_prompt_list,
shot_type="customize" if multi_shot else None,
@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True,
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input(
"storyboards",
options=[
@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio;
$rates := $audio
? {"std": 0.112, "pro": 0.14}
: {"std": 0.084, "pro": 0.112};
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.")
@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
prompt=prompt,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
mode=mode,
sound="on" if generate_audio else "off",
multi_shot=multi_shot,
multi_prompt=multi_prompt_list,
@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input(
"storyboards",
options=[
@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio;
$rates := $audio
? {"std": 0.112, "pro": 0.14}
: {"std": 0.084, "pro": 0.112};
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.")
@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode):
image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
mode=mode,
sound="on" if generate_audio else "off",
multi_shot=multi_shot,
multi_prompt=multi_prompt_list,
@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode):
IO.DynamicCombo.Option(
"kling-v3",
[
IO.Combo.Input("resolution", options=["1080p", "720p"]),
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1"],
@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode):
),
expr="""
(
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode):
start_frame: Input.Image | None = None,
) -> IO.NodeOutput:
_ = seed
mode = "pro" if model["resolution"] == "1080p" else "std"
if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
custom_multi_shot = False
if multi_shot["multi_shot"] == "disabled":
shot_type = None
@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode):
cls,
ApiEndpoint(path=poll_path),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Option(
"kling-v3",
[
IO.Combo.Input("resolution", options=["1080p", "720p"]),
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
],
),
],
@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
),
expr="""
(
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
image=image_url,
image_tail=image_tail_url,
prompt=prompt,
mode="pro" if model["resolution"] == "1080p" else "std",
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -357,6 +357,10 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Combo.Input(
"size",
default="auto",
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
options=[
"auto",
"1024x1024",
"1024x1536",
"1536x1024",
"2048x2048",
"2048x1152",
"1152x2048",
"3840x2160",
"2160x3840",
],
tooltip="Image size",
optional=True,
),
@ -427,7 +441,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
),
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5", 'gpt-image-2'],
options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"],
default="gpt-image-2",
optional=True,
),
@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
expr="""
(
$ranges := {
"low": [0.011, 0.02],
"medium": [0.046, 0.07],
"high": [0.167, 0.3]
"gpt-image-1": {
"low": [0.011, 0.02],
"medium": [0.042, 0.07],
"high": [0.167, 0.25]
},
"gpt-image-1.5": {
"low": [0.009, 0.02],
"medium": [0.034, 0.062],
"high": [0.133, 0.22]
},
"gpt-image-2": {
"low": [0.0048, 0.012],
"medium": [0.041, 0.112],
"high": [0.165, 0.43]
}
};
$range := $lookup($ranges, widgets.quality);
$n := widgets.n;
$range := $lookup($lookup($ranges, widgets.model), widgets.quality);
$nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
: {
"type":"range_usd",
"min_usd": $range[0],
"max_usd": $range[1],
"format": { "suffix": " x " & $string($n) & "/Run" }
"min_usd": $range[0] * $n,
"max_usd": $range[1] * $n,
"format": { "suffix": "/Run", "approximate": true }
}
)
""",
@ -483,12 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode):
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
if model in ("gpt-image-1", "gpt-image-1.5"):
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
elif model == "gpt-image-2":
price_extractor = calculate_tokens_price_image_1_5
price_extractor = calculate_tokens_price_image_2_0
if background == "transparent":
raise ValueError("Transparent background is not supported for GPT Image 2 model")
else:
raise ValueError(f"Unknown model: {model}")

View File

@ -156,6 +156,7 @@ async def poll_op(
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> M:
raw = await poll_op_raw(
cls,
@ -176,6 +177,7 @@ async def poll_op(
estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout,
extra_text=extra_text,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@ -260,6 +262,7 @@ async def poll_op_raw(
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
extra_text: str | None = None,
) -> dict[str, Any]:
"""
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
@ -299,6 +302,7 @@ async def poll_op_raw(
price=state.price,
is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed),
extra_text=extra_text,
)
await asyncio.sleep(1.0)
except Exception as exc:
@ -389,6 +393,7 @@ async def poll_op_raw(
price=state.price,
is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed),
extra_text=extra_text,
)
return resp_json
@ -462,6 +467,7 @@ def _display_time_progress(
price: float | None = None,
is_queued: bool | None = None,
processing_elapsed_seconds: int | None = None,
extra_text: str | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -469,7 +475,8 @@ def _display_time_progress(
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price)
text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
_display_text(node_cls, text, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]:

View File

@ -1,6 +1,7 @@
import nodes
import node_helpers
import torch
import torchaudio
import comfy.model_management
import comfy.model_sampling
import comfy.samplers
@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode):
@classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio)
sample_rate = reference_audio["sample_rate"]
vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = reference_audio["waveform"]
audio_latents = audio_vae.encode(waveform.movedim(1, -1))
b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens}

View File

@ -1,5 +1,6 @@
import json
from comfy.comfy_types.node_typing import IO
import torch
# Preview Any - original implement from
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
@ -19,6 +20,7 @@ class PreviewAny():
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
def main(self, source=None):
torch.set_printoptions(edgeitems=6)
value = 'None'
if isinstance(source, str):
value = source
@ -33,6 +35,7 @@ class PreviewAny():
except Exception:
value = 'source exists, but could not be serialized.'
torch.set_printoptions()
return {"ui": {"text": (value,)}, "result": (value,)}
NODE_CLASS_MAPPINGS = {

529
comfy_extras/nodes_sam3.py Normal file
View File

@ -0,0 +1,529 @@
"""
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
"""
from typing_extensions import override
import json
import os
import torch
import torch.nn.functional as F
import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io, ui
import av
from fractions import Fraction
def _extract_text_prompts(conditioning, device, dtype):
"""Extract list of (text_embeddings, text_mask) from conditioning."""
cond_meta = conditioning[0][1]
multi = cond_meta.get("sam3_multi_cond")
prompts = []
if multi is not None:
for entry in multi:
emb = entry["cond"].to(device=device, dtype=dtype)
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
if mask is None:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, entry.get("max_detections", 1)))
else:
emb = conditioning[0][0].to(device=device, dtype=dtype)
mask = cond_meta.get("attention_mask")
if mask is not None:
mask = mask.to(device)
else:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, 1))
return prompts
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
"""Refine a coarse detector mask via SAM decoder, cropping to the detection box.
Returns: [1, H, W] binary mask
"""
def _coarse_fallback():
return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
mode="bilinear", align_corners=False)[0] > 0).float()
if iterations <= 0:
return _coarse_fallback()
pad_frac = 0.1
x1, y1, x2, y2 = box_xyxy.tolist()
bw, bh = x2 - x1, y2 - y1
cx1 = max(0, int(x1 - bw * pad_frac))
cy1 = max(0, int(y1 - bh * pad_frac))
cx2 = min(W, int(x2 + bw * pad_frac))
cy2 = min(H, int(y2 + bh * pad_frac))
if cx2 <= cx1 or cy2 <= cy1:
return _coarse_fallback()
crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3]
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
crop_frame = crop_1008.to(device=device, dtype=dtype)
crop_h, crop_w = cy2 - cy1, cx2 - cx1
# Crop coarse mask and refine via SAM on the cropped image
mask_h, mask_w = coarse_mask.shape[-2:]
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
if mx2 <= mx1 or my2 <= my1:
return _coarse_fallback()
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
for _ in range(iterations):
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input)
refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop
coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float()
class SAM3_Detect(io.ComfyNode):
"""Open-vocabulary detection and segmentation using text, box, or point prompts."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_Detect",
display_name="SAM3 Detect",
category="detection/",
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
inputs=[
io.Model.Input("model", display_name="model"),
io.Image.Input("image", display_name="image"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
],
outputs=[
io.Mask.Output("masks"),
io.BoundingBox.Output("bboxes"),
],
)
@classmethod
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
B, H, W, C = image.shape
image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
def _boxes_to_tensor(box_list):
coords = []
for d in box_list:
cx = (d["x"] + d["width"] / 2) / W
cy = (d["y"] + d["height"] / 2) / H
coords.append([cx, cy, d["width"] / W, d["height"] / H])
return torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
per_frame_boxes = None
if bboxes is not None:
if isinstance(bboxes, dict):
# Single box → same for all frames
shared = _boxes_to_tensor([bboxes])
per_frame_boxes = [shared] * B
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
# list[list[dict]] → per-frame boxes
per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
# Pad to B if fewer frames provided
while len(per_frame_boxes) < B:
per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
elif isinstance(bboxes, list) and len(bboxes) > 0:
# list[dict] → same boxes for all frames
shared = _boxes_to_tensor(bboxes)
per_frame_boxes = [shared] * B
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
pos_pts = json.loads(positive_coords) if positive_coords else []
neg_pts = json.loads(negative_coords) if negative_coords else []
has_points = len(pos_pts) > 0 or len(neg_pts) > 0
comfy.model_management.load_model_gpu(model)
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
# Build point inputs for tracker SAM decoder path
point_inputs = None
if has_points:
all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
[[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
point_inputs = {
"point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
"point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
}
cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
has_text = len(cond_list) > 0
# Run per-image through detector (text/boxes) and/or tracker (points)
all_bbox_dicts = []
all_masks = []
pbar = comfy.utils.ProgressBar(B)
for b in range(B):
frame = image_in[b:b+1].to(device=device, dtype=dtype)
b_boxes = None
if per_frame_boxes is not None and per_frame_boxes[b] is not None:
b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
frame_bbox_dicts = []
frame_masks = []
# Point prompts: tracker SAM decoder path with iterative refinement
if point_inputs is not None:
mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
for _ in range(max(0, refine_iterations - 1)):
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
frame_masks.append((mask[0] > 0).float())
# Box prompts: SAM decoder path (segment inside each box)
if b_boxes is not None and not has_text:
for box_cxcywh in b_boxes[0]:
cx, cy, bw, bh = box_cxcywh.tolist()
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
[(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
device=device, dtype=dtype)
mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
for _ in range(max(0, refine_iterations - 1)):
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
frame_masks.append((mask[0] > 0).float())
# Text prompts: run detector per text prompt (each detects one category)
for text_embeddings, text_mask, max_det in cond_list:
results = sam3_model(
frame, text_embeddings=text_embeddings, text_mask=text_mask,
boxes=b_boxes, threshold=threshold, orig_size=(H, W))
pred_boxes = results["boxes"][0]
scores = results["scores"][0]
masks = results["masks"][0]
probs = scores.sigmoid()
keep = probs > threshold
kept_boxes = pred_boxes[keep].cpu()
kept_scores = probs[keep].cpu()
kept_masks = masks[keep]
order = kept_scores.argsort(descending=True)[:max_det]
kept_boxes = kept_boxes[order]
kept_scores = kept_scores[order]
kept_masks = kept_masks[order]
for box, score in zip(kept_boxes, kept_scores):
frame_bbox_dicts.append({
"x": float(box[0]), "y": float(box[1]),
"width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
"score": float(score),
})
for m, box in zip(kept_masks, kept_boxes):
frame_masks.append(_refine_mask(
sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
all_bbox_dicts.append(frame_bbox_dicts)
if len(frame_masks) > 0:
combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
if individual_masks:
all_masks.append(combined)
else:
all_masks.append((combined > 0).any(dim=0).float())
else:
if individual_masks:
all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
else:
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
pbar.update(1)
idev = comfy.model_management.intermediate_device()
all_masks = [m.to(idev) for m in all_masks]
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
return io.NodeOutput(mask_out, all_bbox_dicts)
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
class SAM3_VideoTrack(io.ComfyNode):
"""Track objects across video frames using SAM3's memory-based tracker."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_VideoTrack",
display_name="SAM3 Video Track",
category="detection/",
search_aliases=["sam3", "video", "track", "propagate"],
inputs=[
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
io.Model.Input("model", display_name="model"),
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
],
outputs=[
SAM3TrackData.Output("track_data", display_name="track_data"),
],
)
@classmethod
def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
N, H, W, C = images.shape
comfy.model_management.load_model_gpu(model)
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
frames = images[..., :3].movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
init_masks = None
if initial_mask is not None:
init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
pbar = comfy.utils.ProgressBar(N)
text_prompts = None
if conditioning is not None and len(conditioning) > 0:
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
elif initial_mask is None:
raise ValueError("Either initial_mask or conditioning must be provided")
result = sam3_model.forward_video(
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
new_det_thresh=detection_threshold, max_objects=max_objects,
detect_interval=detect_interval)
result["orig_size"] = (H, W)
return io.NodeOutput(result)
class SAM3_TrackPreview(io.ComfyNode):
"""Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_TrackPreview",
display_name="SAM3 Track Preview",
category="detection/",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.Image.Input("images", display_name="images", optional=True),
io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
],
is_output_node=True,
)
COLORS = [
(0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
(0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
(0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
]
# 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
_glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
@staticmethod
def _get_glyphs(device, scale=3):
key = (device, scale)
if key in SAM3_TrackPreview._glyph_cache:
return SAM3_TrackPreview._glyph_cache[key]
atlas = torch.tensor([
[[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
[[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
[[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
[[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
[[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
[[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
[[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
[[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
[[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
[[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
], dtype=torch.bool)
glyphs, outlines = [], []
for d in range(10):
g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
glyphs.append(g.to(device))
outlines.append(o.to(device))
gh, gw = glyphs[0].shape
oh, ow = outlines[0].shape
SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
return SAM3_TrackPreview._glyph_cache[key]
@staticmethod
def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
"""Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
H, W = frame.shape[:2]
device = frame.device
glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
color_t = torch.tensor(color, device=device, dtype=frame.dtype)
digs = [int(d) for d in str(number)]
total_w = len(digs) * (gw + scale) - scale
x0 = cx - total_w // 2
y0 = cy - gh // 2
for i, d in enumerate(digs):
dx = x0 + i * (gw + scale)
# Black outline
oy0, ox0 = y0 - 1, dx - 1
osy1, osx1 = max(0, -oy0), max(0, -ox0)
osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
if osy2 > osy1 and osx2 > osx1:
fy1, fx1 = oy0 + osy1, ox0 + osx1
frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
# Colored fill
sy1, sx1 = max(0, -y0), max(0, -dx)
sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
if sy2 > sy1 and sx2 > sx1:
fy1, fx1 = y0 + sy1, dx + sx1
frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
@classmethod
def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
from comfy.ldm.sam3.tracker import unpack_masks
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
if images is not None:
H, W = images.shape[1], images.shape[2]
if packed is None:
N, N_obj = track_data["n_frames"], 0
else:
N, N_obj = packed.shape[0], packed.shape[1]
import uuid
gpu = comfy.model_management.get_torch_device()
temp_dir = folder_paths.get_temp_directory()
filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4"
filepath = os.path.join(temp_dir, filename)
with av.open(filepath, mode='w') as output:
stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
stream.width = W
stream.height = H
stream.pix_fmt = 'yuv420p'
frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
frame_np = frame_cpu.numpy()
if N_obj > 0:
colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
device=gpu, dtype=torch.float32)
grid_y = torch.arange(H, device=gpu).view(1, H, 1)
grid_x = torch.arange(W, device=gpu).view(1, 1, W)
for t in range(N):
if images is not None and t < images.shape[0]:
frame = images[t].clone()
else:
frame = torch.zeros(H, W, 3)
if N_obj > 0:
frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
frame_gpu = frame.to(gpu)
bool_masks = frame_masks > 0.5
any_mask = bool_masks.any(dim=0)
if any_mask.any():
obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
color_overlay = colors_t[obj_idx_map]
mask_3d = any_mask.unsqueeze(-1)
frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
has = area > 1
scores = track_data.get("scores", [])
for obj_idx in range(N_obj):
if has[obj_idx]:
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
color = cls.COLORS[obj_idx % len(cls.COLORS)]
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
_cx, _cy + 5 * 3 + 3, color, scale=2)
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
else:
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
output.mux(stream.encode(vframe.reformat(format='yuv420p')))
output.mux(stream.encode(None))
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
class SAM3_TrackToMask(io.ComfyNode):
"""Select tracked objects by index and output as mask."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_TrackToMask",
display_name="SAM3 Track to Mask",
category="detection/",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.String.Input("object_indices", display_name="object_indices", default="",
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
],
outputs=[
io.Mask.Output("masks", display_name="masks"),
],
)
@classmethod
def execute(cls, track_data, object_indices="") -> io.NodeOutput:
from comfy.ldm.sam3.tracker import unpack_masks
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
if packed is None:
N = track_data["n_frames"]
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
N, N_obj = packed.shape[0], packed.shape[1]
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
indices = [i for i in indices if 0 <= i < N_obj]
else:
indices = list(range(N_obj))
if not indices:
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
selected = packed[:, indices]
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
union = binary.any(dim=1, keepdim=True).float()
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
return io.NodeOutput(mask_out)
class SAM3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SAM3_Detect,
SAM3_VideoTrack,
SAM3_TrackPreview,
SAM3_TrackToMask,
]
async def comfy_entrypoint() -> SAM3Extension:
return SAM3Extension()

View File

@ -811,11 +811,30 @@ class PromptExecutor:
self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
if visiting is None:
visiting = []
unique_id = item
if unique_id in validated:
return validated[unique_id]
if unique_id in visiting:
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
for node_id in cycle_nodes:
validated[node_id] = (False, [{
"type": "dependency_cycle",
"message": "Dependency cycle detected",
"details": cycle_path,
"extra_info": {
"node_id": node_id,
"cycle_nodes": cycle_nodes,
}
}], node_id)
return validated[unique_id]
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error)
continue
try:
r = await validate_inputs(prompt_id, prompt, o_id, validated)
visiting.append(unique_id)
try:
r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting)
finally:
visiting.pop()
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
ret = (True, [], unique_id)
ret = validated.get(unique_id, (True, [], unique_id))
# Recursive cycle detection may have already populated an error on us. Join it.
ret = (
ret[0] and valid is True and not errors,
ret[1] + [error for error in errors if error not in ret[1]],
unique_id,
)
validated[unique_id] = ret
return ret

View File

@ -1 +1 @@
comfyui_manager==4.1
comfyui_manager==4.2.1

View File

@ -2459,6 +2459,7 @@ async def init_builtin_extra_nodes():
"nodes_curve.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py"
]
import_failed = []

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.42.14
comfyui-workflow-templates==0.9.59
comfyui-embedded-docs==0.4.3
comfyui-workflow-templates==0.9.62
comfyui-embedded-docs==0.4.4
torch
torchsde
torchvision
@ -23,7 +23,7 @@ SQLAlchemy>=2.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.12
comfy-aimdo==0.2.14
requests
simpleeval>=1.0.0
blake3

View File

@ -39,7 +39,7 @@ def get_required_packages_versions():
if len(s) == 2:
version_str = s[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}")
continue
out[s[0]] = version_str
return out.copy()