Merge branch 'master' into add-openapi-spec
commit
9f88368030
|
|
@ -2,7 +2,6 @@
|
|||
precision mediump float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blend mode
|
||||
uniform int u_int1; // Color tint
|
||||
uniform float u_float0; // Intensity
|
||||
|
|
@ -75,7 +74,7 @@ void main() {
|
|||
float t0 = threshold - 0.15;
|
||||
float t1 = threshold + 0.15;
|
||||
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
float radius2 = radius * radius;
|
||||
|
||||
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12;
|
|||
const float RADIAL_STRENGTH = 0.0003;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||
uniform float u_float0; // Blur radius/amount
|
||||
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||
|
|
@ -25,7 +24,7 @@ float gaussian(float x, float sigma) {
|
|||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
float radius = max(u_float0, 0.0);
|
||||
|
||||
// Radial (angular) blur - single pass, doesn't use separable
|
||||
|
|
|
|||
|
|
@ -2,14 +2,13 @@
|
|||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
|
||||
// Sample center and neighbors
|
||||
vec4 center = texture(u_image0, v_texCoord);
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
|
||||
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
|
||||
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
|
||||
|
|
@ -19,7 +18,7 @@ float getLuminance(vec3 color) {
|
|||
}
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
float radius = max(u_float1, 0.5);
|
||||
float amount = u_float0;
|
||||
float threshold = u_float2;
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@
|
|||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
|
||||
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
|
||||
"from_input"
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -331,7 +331,7 @@
|
|||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
|
||||
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
|
||||
"from_input"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@
|
|||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
|
||||
"from_input"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@
|
|||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
|
||||
"from_input"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,596 @@
|
|||
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops import roi_align
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
|
||||
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
|
||||
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
|
||||
|
||||
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
|
||||
from comfy.ops import cast_to_input
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
cx, cy, w, h = x.unbind(-1)
|
||||
return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
|
||||
|
||||
|
||||
def gen_sineembed_for_position(pos_tensor, num_feats=256):
|
||||
"""Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats)."""
|
||||
assert num_feats % 2 == 0
|
||||
hdim = num_feats // 2
|
||||
freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim)
|
||||
embeds = []
|
||||
for c in range(pos_tensor.shape[-1]):
|
||||
raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs
|
||||
embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2))
|
||||
return torch.cat(embeds, dim=-1).to(pos_tensor.dtype)
|
||||
|
||||
|
||||
class SplitMHA(nn.Module):
|
||||
"""Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
|
||||
def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, q_input, k_input=None, v_input=None, mask=None):
|
||||
q = self.q_proj(q_input)
|
||||
if k_input is None:
|
||||
k = self.k_proj(q_input)
|
||||
v = self.v_proj(q_input)
|
||||
else:
|
||||
k = self.k_proj(k_input)
|
||||
v = self.v_proj(v_input if v_input is not None else k_input)
|
||||
if mask is not None and mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
|
||||
dtype = q.dtype # manual_cast may produce mixed dtypes
|
||||
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class MLPWithNorm(nn.Module):
|
||||
"""MLP with residual connection and output LayerNorm."""
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
|
||||
self.layers = nn.ModuleList([
|
||||
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
|
||||
self.residual = residual and (input_dim == output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
orig = x
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i < len(self.layers) - 1:
|
||||
x = F.relu(x)
|
||||
if self.residual:
|
||||
x = x + orig
|
||||
return self.out_norm(x)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, pos, text_memory=None, text_mask=None):
|
||||
normed = self.norm1(x)
|
||||
q_k = normed + pos
|
||||
x = x + self.self_attn(q_k, q_k, normed)
|
||||
if text_memory is not None:
|
||||
normed = self.norm2(x)
|
||||
x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask)
|
||||
normed = self.norm3(x)
|
||||
x = x + self.linear2(F.relu(self.linear1(normed)))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""Checkpoint: transformer.encoder.layers.N.*"""
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, pos, text_memory=None, text_mask=None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos, text_memory, text_mask)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
|
||||
q_k = x + x_pos
|
||||
x = self.norm2(x + self.self_attn(q_k, q_k, x))
|
||||
if text_memory is not None:
|
||||
x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
|
||||
x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
|
||||
x = self.norm3(x + self.linear2(F.relu(self.linear1(x))))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
|
||||
num_queries=200, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
|
||||
self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes
|
||||
self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each)
|
||||
self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
|
||||
self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _inverse_sigmoid(x):
|
||||
return torch.log(x / (1 - x + 1e-6) + 1e-6)
|
||||
|
||||
def _compute_box_rpb(self, ref_points, H, W):
|
||||
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
|
||||
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
|
||||
B, Q, _ = boxes_xyxy.shape
|
||||
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
|
||||
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
|
||||
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
|
||||
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
|
||||
|
||||
log2_8 = float(math.log2(8))
|
||||
def log_scale(d):
|
||||
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
|
||||
|
||||
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
|
||||
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
|
||||
|
||||
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
|
||||
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
|
||||
return torch.cat([pres_bias, bias], dim=2)
|
||||
|
||||
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
|
||||
B = memory.shape[0]
|
||||
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
|
||||
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
|
||||
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
|
||||
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
|
||||
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
|
||||
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
|
||||
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
|
||||
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
|
||||
if layer_idx < len(self.layers) - 1:
|
||||
ref_inv = self._inverse_sigmoid(ref_points)
|
||||
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
|
||||
|
||||
query_out = self.norm(tgt)
|
||||
ref_inv = self._inverse_sigmoid(ref_points)
|
||||
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
|
||||
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
|
||||
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
|
||||
num_queries=200, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations)
|
||||
self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
class GeometryEncoder(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.roi_size = roi_size
|
||||
self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
|
||||
self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
|
||||
self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
|
||||
self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
|
||||
self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
|
||||
self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
|
||||
self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.encode = nn.ModuleList([
|
||||
EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
|
||||
def _encode_points(self, coords, labels, img_feat_2d):
|
||||
"""Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
|
||||
B, N, _ = coords.shape
|
||||
embed = self.points_direct_project(coords)
|
||||
# Pool features from backbone at point locations via grid_sample
|
||||
grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1]
|
||||
sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1]
|
||||
embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C]
|
||||
# Positional encoding of coordinates
|
||||
x, y = coords[:, :, 0], coords[:, :, 1] # [B, N]
|
||||
pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
|
||||
enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
|
||||
embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
|
||||
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
|
||||
return embed
|
||||
|
||||
def _encode_boxes(self, boxes, labels, img_feat_2d):
|
||||
"""Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
|
||||
B, N, _ = boxes.shape
|
||||
embed = self.boxes_direct_project(boxes)
|
||||
# ROI align from backbone at box regions
|
||||
H, W = img_feat_2d.shape[-2:]
|
||||
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
|
||||
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
|
||||
boxes_scaled = boxes_xyxy * scale
|
||||
sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
|
||||
proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
|
||||
embed = embed + proj
|
||||
# Positional encoding of box center + size
|
||||
cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
|
||||
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
|
||||
enc = enc.view(B, N, -1)
|
||||
embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
|
||||
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
|
||||
return embed
|
||||
|
||||
def forward(self, points=None, boxes=None, image_features=None):
|
||||
"""Encode geometry prompts. image_features: [B, HW, C] flattened backbone features."""
|
||||
# Prepare 2D image features for pooling
|
||||
img_feat_2d = None
|
||||
if image_features is not None:
|
||||
B = image_features.shape[0]
|
||||
HW, C = image_features.shape[1], image_features.shape[2]
|
||||
hw = int(math.sqrt(HW))
|
||||
img_normed = self.img_pre_norm(image_features)
|
||||
img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
|
||||
|
||||
embeddings = []
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
embeddings.append(self._encode_points(coords, labels, img_feat_2d))
|
||||
if boxes is not None:
|
||||
B = boxes.shape[0]
|
||||
box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
|
||||
embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
|
||||
if not embeddings:
|
||||
return None
|
||||
geo = torch.cat(embeddings, dim=1)
|
||||
geo = self.norm(geo)
|
||||
if image_features is not None:
|
||||
for layer in self.encode:
|
||||
geo = layer(geo, torch.zeros_like(geo), image_features)
|
||||
geo = self.encode_norm(geo)
|
||||
return self.final_proj(geo)
|
||||
|
||||
|
||||
class PixelDecoder(nn.Module):
|
||||
"""Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""
|
||||
def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)])
|
||||
self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)])
|
||||
|
||||
def forward(self, backbone_features):
|
||||
prev = backbone_features[-1]
|
||||
for i, feat in enumerate(backbone_features[:-1][::-1]):
|
||||
prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest"))))
|
||||
return prev
|
||||
|
||||
|
||||
class MaskPredictor(nn.Module):
|
||||
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, query_embeddings, pixel_features):
|
||||
mask_embed = self.mask_embed(query_embeddings)
|
||||
return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features)
|
||||
|
||||
|
||||
class SegmentationHead(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
|
||||
self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
|
||||
if encoder_hidden_states is not None and prompt is not None:
|
||||
enc_normed = self.cross_attn_norm(encoder_hidden_states)
|
||||
enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
|
||||
encoder_hidden_states = enc_cross + encoder_hidden_states
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
|
||||
encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
|
||||
backbone_features = list(backbone_features)
|
||||
backbone_features[-1] = encoder_visual
|
||||
|
||||
pixel_features = self.pixel_decoder(backbone_features)
|
||||
instance_features = self.instance_seg_head(pixel_features)
|
||||
masks = self.mask_predictor(query_embeddings, instance_features)
|
||||
return masks
|
||||
|
||||
|
||||
class DotProductScoring(nn.Module):
|
||||
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.scale = 1.0 / (d_model ** 0.5)
|
||||
|
||||
def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
|
||||
prompt = self.prompt_mlp(prompt_embeddings)
|
||||
if prompt_mask is not None:
|
||||
weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
|
||||
pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1)
|
||||
else:
|
||||
pooled = prompt.mean(dim=1)
|
||||
hs = self.hs_proj(query_embeddings)
|
||||
pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
|
||||
scores = torch.matmul(hs, pp)
|
||||
return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1)
|
||||
|
||||
|
||||
class SAM3Detector(nn.Module):
|
||||
def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
image_model = kwargs.pop("image_model", "SAM3")
|
||||
for k in ("num_heads", "num_head_channels"):
|
||||
kwargs.pop(k, None)
|
||||
multiplex = image_model == "SAM31"
|
||||
# SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
|
||||
self.scalp = 0 if multiplex else 1
|
||||
self.backbone = nn.ModuleDict({
|
||||
"vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
|
||||
"language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
|
||||
})
|
||||
self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
|
||||
self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def _get_backbone_features(self, images):
|
||||
"""Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
|
||||
bb = self.backbone["vision_backbone"]
|
||||
if bb.multiplex:
|
||||
all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
|
||||
else:
|
||||
all_f, all_p, tf, tp = bb(images, need_tracker=True)
|
||||
return all_f, all_p, tf, tp
|
||||
|
||||
@staticmethod
|
||||
def _run_geo_layer(layer, x, memory, memory_pos):
|
||||
x = x + layer.self_attn(layer.norm1(x))
|
||||
x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
|
||||
x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
|
||||
return x
|
||||
|
||||
def _detect(self, features, positions, text_embeddings=None, text_mask=None,
|
||||
points=None, boxes=None):
|
||||
"""Shared detection: geometry encoding, transformer, scoring, segmentation."""
|
||||
B = features[0].shape[0]
|
||||
# Scalp for encoder (use top-level feature), but keep all levels for segmentation head
|
||||
seg_features = features
|
||||
if self.scalp > 0:
|
||||
features = features[:-self.scalp]
|
||||
positions = positions[:-self.scalp]
|
||||
enc_feat, enc_pos = features[-1], positions[-1]
|
||||
_, _, H, W = enc_feat.shape
|
||||
img_flat = enc_feat.flatten(2).permute(0, 2, 1)
|
||||
pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
|
||||
|
||||
has_prompts = text_embeddings is not None or points is not None or boxes is not None
|
||||
if has_prompts:
|
||||
geo_enc = self.geometry_encoder
|
||||
geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
|
||||
geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
|
||||
for layer in geo_enc.encode:
|
||||
geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
|
||||
geo_cls = geo_enc.encode_norm(geo_cls)
|
||||
if text_embeddings is not None and text_embeddings.shape[0] != B:
|
||||
text_embeddings = text_embeddings.expand(B, -1, -1)
|
||||
if text_mask is not None and text_mask.shape[0] != B:
|
||||
text_mask = text_mask.expand(B, -1)
|
||||
parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
|
||||
text_embeddings = torch.cat(parts, dim=1)
|
||||
n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
|
||||
if text_mask is not None:
|
||||
text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
|
||||
else:
|
||||
text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
|
||||
|
||||
memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
|
||||
dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
|
||||
query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
|
||||
|
||||
if text_embeddings is not None:
|
||||
scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
|
||||
else:
|
||||
scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
|
||||
|
||||
masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
|
||||
return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out
|
||||
|
||||
def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
|
||||
features, positions, _, _ = self._get_backbone_features(images)
|
||||
|
||||
if text_embeddings is not None:
|
||||
text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
|
||||
if text_mask is not None:
|
||||
text_mask = text_mask.bool()
|
||||
|
||||
boxes_xyxy, scores, masks, dec_out = self._detect(
|
||||
features, positions, text_embeddings, text_mask, points, boxes)
|
||||
|
||||
if orig_size is not None:
|
||||
oh, ow = orig_size
|
||||
boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
|
||||
masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
|
||||
|
||||
return {
|
||||
"boxes": boxes_xyxy,
|
||||
"scores": scores,
|
||||
"masks": masks,
|
||||
"presence": dec_out.get("presence"),
|
||||
}
|
||||
|
||||
def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
|
||||
"""Run detection using a pre-computed ViTDet trunk output.
|
||||
|
||||
text_embeddings must already be resized through language_backbone.resizer.
|
||||
Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
|
||||
"""
|
||||
bb = self.backbone["vision_backbone"]
|
||||
features = [conv(trunk_out) for conv in bb.convs]
|
||||
positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
|
||||
|
||||
if text_mask is not None:
|
||||
text_mask = text_mask.bool()
|
||||
|
||||
boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
|
||||
return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
|
||||
|
||||
|
||||
class SAM3Model(nn.Module):
|
||||
def __init__(self, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
image_model = kwargs.get("image_model", "SAM3")
|
||||
tracker_cls = TRACKER_CLASSES[image_model]
|
||||
self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
|
||||
def forward(self, images, **kwargs):
|
||||
return self.detector(images, **kwargs)
|
||||
|
||||
def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
|
||||
"""Interactive segmentation using SAM decoder with point/box/mask prompts.
|
||||
|
||||
Args:
|
||||
images: [B, 3, 1008, 1008] preprocessed images
|
||||
point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
|
||||
box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
|
||||
mask_inputs: [B, 1, H, W] coarse mask logits to refine
|
||||
Returns:
|
||||
[B, 1, image_size, image_size] high-res mask logits
|
||||
"""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
if bb.multiplex:
|
||||
_, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
|
||||
else:
|
||||
_, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
|
||||
if self.detector.scalp > 0:
|
||||
tracker_features = tracker_features[:-self.detector.scalp]
|
||||
tracker_positions = tracker_positions[:-self.detector.scalp]
|
||||
|
||||
high_res = list(tracker_features[:-1])
|
||||
backbone_feat = tracker_features[-1]
|
||||
B, C, H, W = backbone_feat.shape
|
||||
# Add no-memory embedding (init frame path)
|
||||
no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
|
||||
if no_mem is None:
|
||||
no_mem = getattr(self.tracker, 'no_mem_embed', None)
|
||||
if no_mem is not None:
|
||||
feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
|
||||
feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
|
||||
backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
|
||||
|
||||
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
|
||||
_, high_res_masks, _, _ = self.tracker._forward_sam_heads(
|
||||
backbone_features=backbone_feat,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
box_inputs=box_inputs,
|
||||
high_res_features=high_res,
|
||||
multimask_output=(0 < num_pts <= 1),
|
||||
)
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
def backbone_fn(frame, frame_idx=None):
|
||||
trunk_out = bb.trunk(frame)
|
||||
if bb.multiplex:
|
||||
_, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
|
||||
else:
|
||||
_, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
|
||||
return tf, tp, trunk_out
|
||||
|
||||
detect_fn = None
|
||||
if text_prompts:
|
||||
resizer = self.detector.backbone["language_backbone"]["resizer"]
|
||||
resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
|
||||
def detect_fn(trunk_out):
|
||||
all_scores, all_masks = [], []
|
||||
for emb, mask in resized:
|
||||
det = self.detector.forward_from_trunk(trunk_out, emb, mask)
|
||||
all_scores.append(det["scores"])
|
||||
all_masks.append(det["masks"])
|
||||
return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
|
||||
|
||||
if hasattr(self.tracker, 'track_video_with_detection'):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
|
|
@ -0,0 +1,425 @@
|
|||
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ops import cast_to_input
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
|
||||
self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
|
||||
return torch.sigmoid(x) if self.sigmoid_output else x
|
||||
|
||||
|
||||
class SAMAttention(nn.Module):
|
||||
def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
internal_dim = embedding_dim // downsample_rate
|
||||
kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
||||
self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
|
||||
self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, queries, keys, query_pe, key_pe):
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.norm1(self.self_attn(queries, queries, queries))
|
||||
else:
|
||||
q = queries + query_pe
|
||||
queries = self.norm1(queries + self.self_attn(q, q, queries))
|
||||
q, k = queries + query_pe, keys + key_pe
|
||||
queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
|
||||
queries = self.norm3(queries + self.mlp(queries))
|
||||
q, k = queries + query_pe, keys + key_pe
|
||||
keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, image_embedding, image_pe, point_embedding):
|
||||
queries, keys = point_embedding, image_embedding
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(queries, keys, point_embedding, image_pe)
|
||||
q, k = queries + point_embedding, keys + image_pe
|
||||
queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
|
||||
return queries, keys
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""Fourier feature positional encoding with random gaussian projection."""
|
||||
def __init__(self, num_pos_feats=64, scale=None):
|
||||
super().__init__()
|
||||
self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))
|
||||
|
||||
def _encode(self, normalized_coords):
|
||||
"""Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
|
||||
orig_dtype = normalized_coords.dtype
|
||||
proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
|
||||
projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
|
||||
return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)
|
||||
|
||||
def forward(self, size, device=None):
|
||||
h, w = size
|
||||
dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
|
||||
ones = torch.ones((h, w), device=dev, dtype=torch.float32)
|
||||
norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
|
||||
return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
def forward_with_coords(self, pixel_coords, image_size):
|
||||
norm = pixel_coords.clone()
|
||||
norm[:, :, 0] /= image_size[1]
|
||||
norm[:, :, 1] /= image_size[0]
|
||||
return self._encode(norm)
|
||||
|
||||
|
||||
# ViTDet backbone + FPN neck
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int):
|
||||
B, H, W, C = x.shape
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
|
||||
"""Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
|
||||
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
||||
ids = torch.stack([(t % end_x) * scale_pos,
|
||||
torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
|
||||
return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))
|
||||
|
||||
|
||||
class _ViTMLP(nn.Module):
|
||||
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden = int(dim * mlp_ratio)
|
||||
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
|
||||
self.act = nn.GELU()
|
||||
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""ViTDet multi-head attention with fused QKV projection."""
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs_cis=None):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
|
||||
if self.use_rope and freqs_cis is not None:
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, freqs_cis=None):
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
x = x.view(x.shape[0], self.window_size * self.window_size, -1)
|
||||
x = self.attn(x, freqs_cis=freqs_cis)
|
||||
x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
else:
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H * W, C)
|
||||
x = self.attn(x, freqs_cis=freqs_cis)
|
||||
x = x.view(B, H, W, C)
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class ViTDet(nn.Module):
|
||||
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
|
||||
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.global_att_blocks = set(global_att_blocks)
|
||||
|
||||
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
|
||||
|
||||
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
grid_size = img_size // patch_size
|
||||
pretrain_grid = pretrain_img_size // patch_size
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
is_global = i in self.global_att_blocks
|
||||
self.blocks.append(Block(
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias,
|
||||
window_size=0 if is_global else window_size,
|
||||
use_rope=use_rope,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
))
|
||||
|
||||
if use_rope:
|
||||
rope_scale = pretrain_grid / grid_size
|
||||
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
|
||||
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
|
||||
else:
|
||||
self.freqs_cis = None
|
||||
self.freqs_cis_window = None
|
||||
|
||||
def _get_pos_embed(self, num_tokens):
|
||||
pos = self.pos_embed
|
||||
if pos.shape[1] == num_tokens:
|
||||
return pos
|
||||
cls_pos = pos[:, :1]
|
||||
spatial_pos = pos[:, 1:]
|
||||
old_size = int(math.sqrt(spatial_pos.shape[1]))
|
||||
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
|
||||
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
|
||||
tiles_h = new_size // old_size + 1
|
||||
tiles_w = new_size // old_size + 1
|
||||
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
|
||||
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
|
||||
return torch.cat([cls_pos, tiled], dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
B, C, Hp, Wp = x.shape
|
||||
x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
|
||||
|
||||
pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
|
||||
x = x + pos[:, 1:Hp * Wp + 1]
|
||||
|
||||
x = x.view(B, Hp, Wp, C)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
freqs_cis_global = self.freqs_cis
|
||||
freqs_cis_win = self.freqs_cis_window
|
||||
if freqs_cis_global is not None:
|
||||
freqs_cis_global = cast_to_input(freqs_cis_global, x)
|
||||
if freqs_cis_win is not None:
|
||||
freqs_cis_win = cast_to_input(freqs_cis_win, x)
|
||||
|
||||
for block in self.blocks:
|
||||
fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
|
||||
x = block(x, freqs_cis=fc)
|
||||
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class FPNScaleConv(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
if scale == 4.0:
|
||||
self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
proj_in = in_dim // 4
|
||||
elif scale == 2.0:
|
||||
self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
proj_in = in_dim // 2
|
||||
elif scale == 1.0:
|
||||
proj_in = in_dim
|
||||
elif scale == 0.5:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
proj_in = in_dim
|
||||
self.scale = scale
|
||||
self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
|
||||
self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale == 4.0:
|
||||
x = F.gelu(self.dconv_2x2_0(x))
|
||||
x = self.dconv_2x2_1(x)
|
||||
elif self.scale == 2.0:
|
||||
x = self.dconv_2x2(x)
|
||||
elif self.scale == 0.5:
|
||||
x = self.pool(x)
|
||||
x = self.conv_1x1(x)
|
||||
x = self.conv_3x3(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""2D sinusoidal position encoding (DETR-style) with result caching."""
|
||||
def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
|
||||
super().__init__()
|
||||
assert num_pos_feats % 2 == 0
|
||||
self.half_dim = num_pos_feats // 2
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
self.scale = scale if scale is not None else 2 * math.pi
|
||||
self._cache = {}
|
||||
|
||||
def _sincos(self, vals):
|
||||
"""Encode 1D values to interleaved sin/cos features."""
|
||||
freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
|
||||
raw = vals[..., None] * self.scale / freqs
|
||||
return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
"""Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
|
||||
dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
|
||||
pos_x = x[:, None] * self.scale / dim_t
|
||||
pos_y = y[:, None] * self.scale / dim_t
|
||||
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
|
||||
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
|
||||
return pos_x, pos_y
|
||||
|
||||
def encode_boxes(self, cx, cy, w, h):
|
||||
"""Encode box center + size to [N, d_model+2] features."""
|
||||
pos_x, pos_y = self._encode_xy(cx, cy)
|
||||
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
key = (H, W, x.device)
|
||||
if key not in self._cache:
|
||||
gy = torch.arange(H, dtype=torch.float32, device=x.device)
|
||||
gx = torch.arange(W, dtype=torch.float32, device=x.device)
|
||||
if self.normalize:
|
||||
gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
|
||||
yy, xx = torch.meshgrid(gy, gx, indexing="ij")
|
||||
self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
|
||||
return self._cache[key].expand(B, -1, -1, -1)
|
||||
|
||||
|
||||
class SAM3VisionBackbone(nn.Module):
|
||||
def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
|
||||
self.multiplex = multiplex
|
||||
|
||||
fpn_args = dict(device=device, dtype=dtype, operations=operations)
|
||||
if multiplex:
|
||||
scales = [4.0, 2.0, 1.0]
|
||||
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
else:
|
||||
scales = [4.0, 2.0, 1.0, 0.5]
|
||||
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
|
||||
def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
|
||||
backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
|
||||
|
||||
if tracker_only:
|
||||
# Skip detector FPN when only tracker features are needed (video tracking)
|
||||
if self.multiplex:
|
||||
tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
|
||||
else:
|
||||
tracker_convs = self.sam2_convs
|
||||
tracker_features = [conv(backbone_out) for conv in tracker_convs]
|
||||
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
|
||||
return None, None, tracker_features, tracker_positions
|
||||
|
||||
features = [conv(backbone_out) for conv in self.convs]
|
||||
positions = [cast_to_input(self.position_encoding(f), f) for f in features]
|
||||
|
||||
if self.multiplex:
|
||||
if tracker_mode == "propagation":
|
||||
tracker_convs = self.propagation_convs
|
||||
elif tracker_mode == "interactive":
|
||||
tracker_convs = self.interactive_convs
|
||||
else:
|
||||
return features, positions, None, None
|
||||
elif need_tracker:
|
||||
tracker_convs = self.sam2_convs
|
||||
else:
|
||||
return features, positions, None, None
|
||||
|
||||
tracker_features = [conv(backbone_out) for conv in tracker_convs]
|
||||
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
|
||||
return features, positions, tracker_features, tracker_positions
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -54,6 +54,7 @@ import comfy.ldm.anima.model
|
|||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
|
@ -578,8 +579,8 @@ class Stable_Zero123(BaseModel):
|
|||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
|
||||
self.cc_projection.weight.copy_(cc_projection_weight)
|
||||
self.cc_projection.bias.copy_(cc_projection_bias)
|
||||
self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone())
|
||||
self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone())
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
|
@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel):
|
|||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class SAM3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
|
||||
|
|
|
|||
|
|
@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["image_model"] = "ernie"
|
||||
return dit_config
|
||||
|
||||
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
|
||||
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "SAM3"
|
||||
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
|
||||
dit_config["image_model"] = "SAM31"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
|
@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
# SAM3: detector.* and tracker.* at top level, no common prefix
|
||||
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
|
||||
return ""
|
||||
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
"model.model.", #audio models
|
||||
"net.", #cosmos
|
||||
|
|
|
|||
|
|
@ -1801,7 +1801,7 @@ def debug_memory_summary():
|
|||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
class InterruptProcessingException(BaseException):
|
||||
pass
|
||||
|
||||
interrupt_processing_mutex = threading.RLock()
|
||||
|
|
|
|||
|
|
@ -685,9 +685,9 @@ class ModelPatcher:
|
|||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if key not in self.patches:
|
||||
if key not in self.patches and not force_cast:
|
||||
return weight
|
||||
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
|
@ -695,7 +695,7 @@ class ModelPatcher:
|
|||
if key not in self.backup and not return_weight:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
else:
|
||||
|
|
@ -703,9 +703,10 @@ class ModelPatcher:
|
|||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
if key in self.patches:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
if return_weight:
|
||||
return out_weight
|
||||
elif inplace_update:
|
||||
|
|
@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||
key = key_param_name_to_key(n, param_key)
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is not None:
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
|
|
@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
for param in params:
|
||||
if param not in ("weight", "bias"):
|
||||
force_load_param(self, param, device_to)
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
|
|
|
|||
|
|
@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE):
|
|||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
class SAM3(supported_models_base.BASE):
|
||||
unet_config = {"image_model": "SAM3"}
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
|
||||
unet_extra_prefix = ""
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
clip_keys = getattr(self, "_clip_stash", {})
|
||||
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
|
||||
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
|
||||
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
|
||||
# SAM3.1: remap tracker.model.* -> tracker.*
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("tracker.model."):
|
||||
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
|
||||
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
|
||||
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
|
||||
state_dict.pop(k)
|
||||
# Split fused QKV projections
|
||||
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
|
||||
t = state_dict.pop(k)
|
||||
base, suffix = k.rsplit(".in_proj_", 1)
|
||||
s = ".weight" if suffix == "weight" else ".bias"
|
||||
d = t.shape[0] // 3
|
||||
state_dict[base + ".q_proj" + s] = t[:d]
|
||||
state_dict[base + ".k_proj" + s] = t[d:2*d]
|
||||
state_dict[base + ".v_proj" + s] = t[2*d:]
|
||||
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
|
||||
for k in list(state_dict.keys()):
|
||||
if "sam_mask_decoder.transformer." not in k:
|
||||
continue
|
||||
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
|
||||
if new_k != k:
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SAM3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
import comfy.text_encoders.sam3_clip
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
|
||||
|
||||
|
||||
class SAM31(SAM3):
|
||||
unet_config = {"image_model": "SAM31"}
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
import re
|
||||
from comfy import sd1_clip
|
||||
|
||||
SAM3_CLIP_CONFIG = {
|
||||
"architectures": ["CLIPTextModel"],
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 4096,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"max_position_embeddings": 32,
|
||||
"projection_dim": 512,
|
||||
"vocab_size": 49408,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"eos_token_id": 49407,
|
||||
}
|
||||
|
||||
|
||||
class SAM3ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
|
||||
|
||||
|
||||
class SAM3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
|
||||
self.disable_weights = True
|
||||
|
||||
|
||||
def _parse_prompts(text):
|
||||
"""Split comma-separated prompts with optional :N max detections per category"""
|
||||
text = text.replace("(", "").replace(")", "")
|
||||
parts = [p.strip() for p in text.split(",") if p.strip()]
|
||||
result = []
|
||||
for part in parts:
|
||||
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
|
||||
if m:
|
||||
text_part = m.group(1).strip()
|
||||
val = m.group(2)
|
||||
max_det = max(1, round(float(val)))
|
||||
result.append((text_part, max_det))
|
||||
else:
|
||||
result.append((part, 1))
|
||||
return result
|
||||
|
||||
|
||||
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
parsed = _parse_prompts(text)
|
||||
if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
|
||||
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
# Tokenize each prompt part separately, store per-part batches and metadata
|
||||
inner = getattr(self, self.clip)
|
||||
per_prompt = []
|
||||
for prompt_text, max_det in parsed:
|
||||
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
|
||||
per_prompt.append((batches, max_det))
|
||||
# Main output uses first prompt's tokens (for compatibility)
|
||||
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
|
||||
return out
|
||||
|
||||
|
||||
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
|
||||
if per_prompt is None:
|
||||
return super().encode_token_weights(token_weight_pairs)
|
||||
|
||||
# Encode each prompt separately, pack into extra dict
|
||||
inner = getattr(self, self.clip)
|
||||
multi_cond = []
|
||||
first_pooled = None
|
||||
for batches, max_det in per_prompt:
|
||||
out = inner.encode_token_weights(batches)
|
||||
cond, pooled = out[0], out[1]
|
||||
extra = out[2] if len(out) > 2 else {}
|
||||
if first_pooled is None:
|
||||
first_pooled = pooled
|
||||
multi_cond.append({
|
||||
"cond": cond,
|
||||
"attention_mask": extra.get("attention_mask"),
|
||||
"max_detections": max_det,
|
||||
})
|
||||
|
||||
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
|
||||
main = multi_cond[0]
|
||||
main_extra = {}
|
||||
if main["attention_mask"] is not None:
|
||||
main_extra["attention_mask"] = main["attention_mask"]
|
||||
main_extra["sam3_multi_cond"] = multi_cond
|
||||
return (main["cond"], first_pooled, main_extra)
|
||||
|
|
@ -9,6 +9,7 @@ from comfy_api.latest._input import (
|
|||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
RangeInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -21,4 +22,5 @@ __all__ = [
|
|||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .range_types import RangeInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -12,4 +13,5 @@ __all__ = [
|
|||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RangeInput:
|
||||
"""Represents a levels/range adjustment: input range [min, max] with
|
||||
optional midpoint (gamma control).
|
||||
|
||||
Generates a 1D LUT identical to GIMP's levels mapping:
|
||||
1. Normalize input to [0, 1] using [min, max]
|
||||
2. Apply gamma correction: pow(value, 1/gamma)
|
||||
3. Clamp to [0, 1]
|
||||
|
||||
The midpoint field is a position in [0, 1] representing where the
|
||||
midtone falls within [min, max]. It maps to gamma via:
|
||||
gamma = -log2(midpoint)
|
||||
So midpoint=0.5 → gamma=1.0 (linear).
|
||||
"""
|
||||
|
||||
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
self.midpoint = midpoint
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> RangeInput:
|
||||
if isinstance(data, RangeInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return RangeInput(
|
||||
min_val=float(data.get("min", 0.0)),
|
||||
max_val=float(data.get("max", 1.0)),
|
||||
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
|
||||
)
|
||||
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table mapping [0, 1] input through this
|
||||
levels adjustment.
|
||||
|
||||
The LUT maps normalized input values (0..1) to output values (0..1),
|
||||
matching the GIMP levels formula.
|
||||
"""
|
||||
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
|
||||
|
||||
in_range = self.max_val - self.min_val
|
||||
if abs(in_range) < 1e-10:
|
||||
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
|
||||
|
||||
# Normalize: map [min, max] → [0, 1]
|
||||
result = (xs - self.min_val) / in_range
|
||||
result = np.clip(result, 0.0, 1.0)
|
||||
|
||||
# Gamma correction from midpoint
|
||||
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
|
||||
gamma = max(-math.log2(self.midpoint), 0.001)
|
||||
inv_gamma = 1.0 / gamma
|
||||
mask = result > 0
|
||||
result[mask] = np.power(result[mask], inv_gamma)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
|
||||
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"
|
||||
|
|
@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO):
|
|||
Type = list[int]
|
||||
|
||||
|
||||
@comfytype(io_type="RANGE")
|
||||
class Range(ComfyTypeIO):
|
||||
from comfy_api.input import RangeInput
|
||||
if TYPE_CHECKING:
|
||||
Type = RangeInput
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None,
|
||||
display: str=None,
|
||||
gradient_stops: list=None,
|
||||
show_midpoint: bool=None,
|
||||
midpoint_scale: str=None,
|
||||
value_min: float=None,
|
||||
value_max: float=None,
|
||||
advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = {"min": 0.0, "max": 1.0}
|
||||
self.display = display
|
||||
self.gradient_stops = gradient_stops
|
||||
self.show_midpoint = show_midpoint
|
||||
self.midpoint_scale = midpoint_scale
|
||||
self.value_min = value_min
|
||||
self.value_max = value_max
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"display": self.display,
|
||||
"gradient_stops": self.gradient_stops,
|
||||
"show_midpoint": self.show_midpoint,
|
||||
"midpoint_scale": self.midpoint_scale,
|
||||
"value_min": self.value_min,
|
||||
"value_max": self.value_max,
|
||||
})
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
|
|
@ -2276,5 +2313,6 @@ __all__ = [
|
|||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel):
|
|||
usage: TaskStatusUsage | None = Field(None)
|
||||
|
||||
|
||||
class GetAssetResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
name: str | None = Field(None)
|
||||
url: str | None = Field(None)
|
||||
asset_type: str = Field(...)
|
||||
group_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateVisualValidateSessionResponse(BaseModel):
|
||||
session_id: str = Field(...)
|
||||
h5_link: str = Field(...)
|
||||
|
||||
|
||||
class SeedanceGetVisualValidateSessionResponse(BaseModel):
|
||||
session_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
group_id: str | None = Field(None)
|
||||
error_code: str | None = Field(None)
|
||||
error_message: str | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateAssetRequest(BaseModel):
|
||||
group_id: str = Field(...)
|
||||
url: str = Field(...)
|
||||
asset_type: str = Field(...)
|
||||
name: str | None = Field(None, max_length=64)
|
||||
project_name: str | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateAssetResponse(BaseModel):
|
||||
asset_id: str = Field(...)
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
|
@ -11,9 +12,14 @@ from comfy_api_nodes.apis.bytedance import (
|
|||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
GetAssetResponse,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
SeedanceCreateAssetRequest,
|
||||
SeedanceCreateAssetResponse,
|
||||
SeedanceCreateVisualValidateSessionResponse,
|
||||
SeedanceGetVisualValidateSessionResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskAudioContent,
|
||||
|
|
@ -44,10 +50,16 @@ from comfy_api_nodes.util import (
|
|||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
||||
|
||||
_VERIFICATION_POLL_TIMEOUT_SEC = 120
|
||||
_VERIFICATION_POLL_INTERVAL_SEC = 3
|
||||
|
||||
SEEDREAM_MODELS = {
|
||||
"seedream 5.0 lite": "seedream-5-0-260128",
|
||||
"seedream-4-5-251128": "seedream-4-5-251128",
|
||||
|
|
@ -96,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
|
|||
)
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
|
||||
"""Look up each asset, validate Active status, group by asset_type.
|
||||
|
||||
Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://<asset_id>".
|
||||
"""
|
||||
image_assets: dict[str, str] = {}
|
||||
video_assets: dict[str, str] = {}
|
||||
audio_assets: dict[str, str] = {}
|
||||
for i, raw_id in enumerate(asset_ids, 1):
|
||||
asset_id = (raw_id or "").strip()
|
||||
if not asset_id:
|
||||
continue
|
||||
result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
|
||||
response_model=GetAssetResponse,
|
||||
)
|
||||
if result.status != "Active":
|
||||
extra = f" {result.error.code}: {result.error.message}" if result.error else ""
|
||||
raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}")
|
||||
asset_uri = f"asset://{asset_id}"
|
||||
if result.asset_type == "Image":
|
||||
image_assets[asset_id] = asset_uri
|
||||
elif result.asset_type == "Video":
|
||||
video_assets[asset_id] = asset_uri
|
||||
elif result.asset_type == "Audio":
|
||||
audio_assets[asset_id] = asset_uri
|
||||
return image_assets, video_assets, audio_assets
|
||||
|
||||
|
||||
_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _build_asset_labels(
|
||||
reference_assets: dict[str, str],
|
||||
image_asset_uris: dict[str, str],
|
||||
video_asset_uris: dict[str, str],
|
||||
audio_asset_uris: dict[str, str],
|
||||
n_reference_images: int,
|
||||
n_reference_videos: int,
|
||||
n_reference_audios: int,
|
||||
) -> dict[int, str]:
|
||||
"""Map asset slot number (from 'asset_N' keys) to its positional label.
|
||||
|
||||
Asset entries are appended to `content` after the reference_images/videos/audios,
|
||||
so their 1-indexed labels continue from the count of existing same-type refs:
|
||||
one reference_images entry + one Image-type asset -> asset labelled "Image 2".
|
||||
"""
|
||||
image_n = n_reference_images
|
||||
video_n = n_reference_videos
|
||||
audio_n = n_reference_audios
|
||||
labels: dict[int, str] = {}
|
||||
for slot_key, raw_id in reference_assets.items():
|
||||
asset_id = (raw_id or "").strip()
|
||||
if not asset_id:
|
||||
continue
|
||||
try:
|
||||
slot_num = int(slot_key.rsplit("_", 1)[-1])
|
||||
except ValueError:
|
||||
continue
|
||||
if asset_id in image_asset_uris:
|
||||
image_n += 1
|
||||
labels[slot_num] = f"Image {image_n}"
|
||||
elif asset_id in video_asset_uris:
|
||||
video_n += 1
|
||||
labels[slot_num] = f"Video {video_n}"
|
||||
elif asset_id in audio_asset_uris:
|
||||
audio_n += 1
|
||||
labels[slot_num] = f"Audio {audio_n}"
|
||||
return labels
|
||||
|
||||
|
||||
def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str:
|
||||
"""Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels."""
|
||||
if not labels:
|
||||
return prompt
|
||||
|
||||
def _sub(m: "re.Match[str]") -> str:
|
||||
return labels.get(int(m.group(1)), m.group(0))
|
||||
|
||||
return _ASSET_REF_RE.sub(_sub, prompt)
|
||||
|
||||
|
||||
async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str:
|
||||
session = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"),
|
||||
response_model=SeedanceCreateVisualValidateSessionResponse,
|
||||
)
|
||||
logger.warning("Seedance authentication required. Open link: %s", session.h5_link)
|
||||
|
||||
h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}"
|
||||
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
|
||||
response_model=SeedanceGetVisualValidateSessionResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
poll_interval=_VERIFICATION_POLL_INTERVAL_SEC,
|
||||
max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1,
|
||||
estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1,
|
||||
extra_text=h5_text,
|
||||
)
|
||||
|
||||
if not result.group_id:
|
||||
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
|
||||
|
||||
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
|
||||
)
|
||||
return result.group_id
|
||||
|
||||
|
||||
async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str:
|
||||
if group_id and group_id.strip():
|
||||
return group_id.strip()
|
||||
return await _obtain_group_id_via_h5_auth(cls)
|
||||
|
||||
|
||||
async def _create_seedance_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
group_id: str,
|
||||
url: str,
|
||||
name: str,
|
||||
asset_type: str,
|
||||
) -> str:
|
||||
req = SeedanceCreateAssetRequest(
|
||||
group_id=group_id,
|
||||
url=url,
|
||||
asset_type=asset_type,
|
||||
name=name or None,
|
||||
)
|
||||
result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=req,
|
||||
)
|
||||
return result.asset_id
|
||||
|
||||
|
||||
async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse:
|
||||
"""Poll the newly created asset until its status becomes Active."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
|
||||
response_model=GetAssetResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
completed_statuses=["Active"],
|
||||
failed_statuses=["Failed"],
|
||||
poll_interval=5,
|
||||
max_poll_attempts=1200,
|
||||
extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}",
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
|
|
@ -1228,12 +1403,27 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="First frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Image.Input(
|
||||
"last_frame",
|
||||
tooltip="Last frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"first_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the first frame. "
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"last_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the last frame. "
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
|
|
@ -1286,24 +1476,54 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
first_frame: Input.Image,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
first_frame: Input.Image | None = None,
|
||||
last_frame: Input.Image | None = None,
|
||||
first_frame_asset_id: str = "",
|
||||
last_frame_asset_id: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
|
||||
first_frame_asset_id = first_frame_asset_id.strip()
|
||||
last_frame_asset_id = last_frame_asset_id.strip()
|
||||
|
||||
if first_frame is not None and first_frame_asset_id:
|
||||
raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.")
|
||||
if first_frame is None and not first_frame_asset_id:
|
||||
raise ValueError("Either first_frame or first_frame_asset_id is required.")
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
if asset_ids_to_resolve:
|
||||
image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve)
|
||||
for aid in asset_ids_to_resolve:
|
||||
if aid not in image_assets:
|
||||
raise ValueError(f"Asset {aid} is not an Image asset.")
|
||||
|
||||
if first_frame_asset_id:
|
||||
first_frame_url = image_assets[first_frame_asset_id]
|
||||
else:
|
||||
first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
),
|
||||
image_url=TaskImageContentUrl(url=first_frame_url),
|
||||
role="first_frame",
|
||||
),
|
||||
]
|
||||
if last_frame is not None:
|
||||
if last_frame_asset_id:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]),
|
||||
role="last_frame",
|
||||
),
|
||||
)
|
||||
elif last_frame is not None:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
|
|
@ -1385,6 +1605,24 @@ def _seedance2_reference_inputs(resolutions: list[str]):
|
|||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_assets",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.String.Input("reference_asset"),
|
||||
names=[
|
||||
"asset_1",
|
||||
"asset_2",
|
||||
"asset_3",
|
||||
"asset_4",
|
||||
"asset_5",
|
||||
"asset_6",
|
||||
"asset_7",
|
||||
"asset_8",
|
||||
"asset_9",
|
||||
],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -1486,24 +1724,42 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
|||
reference_images = model.get("reference_images", {})
|
||||
reference_videos = model.get("reference_videos", {})
|
||||
reference_audios = model.get("reference_audios", {})
|
||||
reference_assets = model.get("reference_assets", {})
|
||||
|
||||
if not reference_images and not reference_videos:
|
||||
raise ValueError("At least one reference image or video is required.")
|
||||
reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets(
|
||||
cls, list(reference_assets.values())
|
||||
)
|
||||
|
||||
if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets:
|
||||
raise ValueError("At least one reference image or video or asset is required.")
|
||||
|
||||
total_images = len(reference_images) + len(reference_image_assets)
|
||||
if total_images > 9:
|
||||
raise ValueError(
|
||||
f"Too many reference images: {total_images} "
|
||||
f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9."
|
||||
)
|
||||
total_videos = len(reference_videos) + len(reference_video_assets)
|
||||
if total_videos > 3:
|
||||
raise ValueError(
|
||||
f"Too many reference videos: {total_videos} "
|
||||
f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3."
|
||||
)
|
||||
total_audios = len(reference_audios) + len(reference_audio_assets)
|
||||
if total_audios > 3:
|
||||
raise ValueError(
|
||||
f"Too many reference audios: {total_audios} "
|
||||
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
|
||||
)
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = len(reference_videos) > 0
|
||||
has_video_input = total_videos > 0
|
||||
|
||||
if model.get("auto_downscale") and reference_videos:
|
||||
max_px = (
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
|
||||
.get(model["resolution"], {})
|
||||
.get("max")
|
||||
)
|
||||
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
|
||||
if max_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = resize_video_to_pixel_budget(
|
||||
reference_videos[key], max_px
|
||||
)
|
||||
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
|
||||
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
|
|
@ -1531,8 +1787,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
|||
if total_audio_duration > 15.1:
|
||||
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
asset_labels = _build_asset_labels(
|
||||
reference_assets,
|
||||
reference_image_assets,
|
||||
reference_video_assets,
|
||||
reference_audio_assets,
|
||||
len(reference_images),
|
||||
len(reference_videos),
|
||||
len(reference_audios),
|
||||
)
|
||||
prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels)
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskTextContent(text=prompt_text),
|
||||
]
|
||||
for i, key in enumerate(reference_images, 1):
|
||||
content.append(
|
||||
|
|
@ -1573,6 +1840,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
|||
),
|
||||
),
|
||||
)
|
||||
for url in reference_image_assets.values():
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(url=url),
|
||||
role="reference_image",
|
||||
),
|
||||
)
|
||||
for url in reference_video_assets.values():
|
||||
content.append(
|
||||
TaskVideoContent(video_url=TaskVideoContentUrl(url=url)),
|
||||
)
|
||||
for url in reference_audio_assets.values():
|
||||
content.append(
|
||||
TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)),
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
|
|
@ -1627,6 +1909,156 @@ async def process_video_task(
|
|||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDanceCreateImageAsset(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateImageAsset",
|
||||
display_name="ByteDance Create Image Asset",
|
||||
category="api node/image/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
"H5 authentication flow to create a new group before adding the asset."
|
||||
),
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Image to register as a personal asset."),
|
||||
IO.String.Input(
|
||||
"group_id",
|
||||
default="",
|
||||
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
|
||||
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
|
||||
),
|
||||
# IO.String.Input(
|
||||
# "name",
|
||||
# default="",
|
||||
# tooltip="Asset name (up to 64 characters).",
|
||||
# ),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="asset_id"),
|
||||
IO.String.Output(display_name="group_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
# is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
group_id: str = "",
|
||||
# name: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
# if len(name) > 64:
|
||||
# raise ValueError("Name of asset can not be greater then 64 symbols")
|
||||
validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000)
|
||||
validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1))
|
||||
resolved_group = await _resolve_group_id(cls, group_id)
|
||||
asset_id = await _create_seedance_asset(
|
||||
cls,
|
||||
group_id=resolved_group,
|
||||
url=await upload_image_to_comfyapi(cls, image),
|
||||
name="",
|
||||
asset_type="Image",
|
||||
)
|
||||
await _wait_for_asset_active(cls, asset_id, resolved_group)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
|
||||
f"group_id: {resolved_group}",
|
||||
cls.hidden.unique_id,
|
||||
)
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateVideoAsset",
|
||||
display_name="ByteDance Create Video Asset",
|
||||
category="api node/video/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
"H5 authentication flow to create a new group before adding the asset."
|
||||
),
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Video to register as a personal asset."),
|
||||
IO.String.Input(
|
||||
"group_id",
|
||||
default="",
|
||||
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
|
||||
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
|
||||
),
|
||||
# IO.String.Input(
|
||||
# "name",
|
||||
# default="",
|
||||
# tooltip="Asset name (up to 64 characters).",
|
||||
# ),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="asset_id"),
|
||||
IO.String.Output(display_name="group_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
# is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
group_id: str = "",
|
||||
# name: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
# if len(name) > 64:
|
||||
# raise ValueError("Name of asset can not be greater then 64 symbols")
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000)
|
||||
|
||||
w, h = video.get_dimensions()
|
||||
if h > 0:
|
||||
ratio = w / h
|
||||
if not (0.4 <= ratio <= 2.5):
|
||||
raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).")
|
||||
pixels = w * h
|
||||
if not (409_600 <= pixels <= 927_408):
|
||||
raise ValueError(
|
||||
f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})."
|
||||
)
|
||||
|
||||
fps = float(video.get_frame_rate())
|
||||
if not (24 <= fps <= 60):
|
||||
raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.")
|
||||
|
||||
resolved_group = await _resolve_group_id(cls, group_id)
|
||||
asset_id = await _create_seedance_asset(
|
||||
cls,
|
||||
group_id=resolved_group,
|
||||
url=await upload_video_to_comfyapi(cls, video),
|
||||
name="",
|
||||
asset_type="Video",
|
||||
)
|
||||
await _wait_for_asset_active(cls, asset_id, resolved_group)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
|
||||
f"group_id: {resolved_group}",
|
||||
cls.hidden.unique_id,
|
||||
)
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
|
@ -1640,6 +2072,8 @@ class ByteDanceExtension(ComfyExtension):
|
|||
ByteDance2TextToVideoNode,
|
||||
ByteDance2FirstLastFrameNode,
|
||||
ByteDance2ReferenceNode,
|
||||
ByteDanceCreateImageAsset,
|
||||
ByteDanceCreateVideoAsset,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
|
|||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
|
@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
|
||||
IO.DynamicCombo.Input(
|
||||
"storyboards",
|
||||
options=[
|
||||
|
|
@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$res := widgets.resolution;
|
||||
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
|
||||
$isV3 := $contains(widgets.model_name, "v3");
|
||||
$audio := $isV3 and widgets.generate_audio;
|
||||
$rates := $audio
|
||||
? {"std": 0.112, "pro": 0.14}
|
||||
: {"std": 0.084, "pro": 0.112};
|
||||
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
|
||||
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
|
|
@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
|
||||
if generate_audio:
|
||||
raise ValueError("kling-video-o1 does not support audio generation.")
|
||||
if resolution == "4k":
|
||||
raise ValueError("kling-video-o1 does not support 4k resolution.")
|
||||
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
|
||||
if stories_enabled and model_name == "kling-video-o1":
|
||||
raise ValueError("kling-video-o1 does not support storyboards.")
|
||||
|
|
@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
f"must equal the global duration ({duration}s)."
|
||||
)
|
||||
|
||||
if resolution == "4k":
|
||||
mode = "4k"
|
||||
elif resolution == "1080p":
|
||||
mode = "pro"
|
||||
else:
|
||||
mode = "std"
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
|
|
@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
mode=mode,
|
||||
multi_shot=multi_shot,
|
||||
multi_prompt=multi_prompt_list,
|
||||
shot_type="customize" if multi_shot else None,
|
||||
|
|
@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
|
||||
IO.DynamicCombo.Input(
|
||||
"storyboards",
|
||||
options=[
|
||||
|
|
@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$res := widgets.resolution;
|
||||
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
|
||||
$isV3 := $contains(widgets.model_name, "v3");
|
||||
$audio := $isV3 and widgets.generate_audio;
|
||||
$rates := $audio
|
||||
? {"std": 0.112, "pro": 0.14}
|
||||
: {"std": 0.084, "pro": 0.112};
|
||||
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
|
||||
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
|
|
@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
|
||||
if generate_audio:
|
||||
raise ValueError("kling-video-o1 does not support audio generation.")
|
||||
if resolution == "4k":
|
||||
raise ValueError("kling-video-o1 does not support 4k resolution.")
|
||||
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
|
||||
if stories_enabled and model_name == "kling-video-o1":
|
||||
raise ValueError("kling-video-o1 does not support storyboards.")
|
||||
|
|
@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
if resolution == "4k":
|
||||
mode = "4k"
|
||||
elif resolution == "1080p":
|
||||
mode = "pro"
|
||||
else:
|
||||
mode = "std"
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
|
|
@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
mode=mode,
|
||||
sound="on" if generate_audio else "off",
|
||||
multi_shot=multi_shot,
|
||||
multi_prompt=multi_prompt_list,
|
||||
|
|
@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
|
||||
IO.DynamicCombo.Input(
|
||||
"storyboards",
|
||||
options=[
|
||||
|
|
@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$res := widgets.resolution;
|
||||
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
|
||||
$isV3 := $contains(widgets.model_name, "v3");
|
||||
$audio := $isV3 and widgets.generate_audio;
|
||||
$rates := $audio
|
||||
? {"std": 0.112, "pro": 0.14}
|
||||
: {"std": 0.084, "pro": 0.112};
|
||||
? {"std": 0.112, "pro": 0.14, "4k": 0.42}
|
||||
: {"std": 0.084, "pro": 0.112, "4k": 0.42};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
|
|
@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
|
||||
if generate_audio:
|
||||
raise ValueError("kling-video-o1 does not support audio generation.")
|
||||
if resolution == "4k":
|
||||
raise ValueError("kling-video-o1 does not support 4k resolution.")
|
||||
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
|
||||
if stories_enabled and model_name == "kling-video-o1":
|
||||
raise ValueError("kling-video-o1 does not support storyboards.")
|
||||
|
|
@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
image_list: list[OmniParamImage] = []
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
if resolution == "4k":
|
||||
mode = "4k"
|
||||
elif resolution == "1080p":
|
||||
mode = "pro"
|
||||
else:
|
||||
mode = "std"
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
|
|
@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
mode=mode,
|
||||
sound="on" if generate_audio else "off",
|
||||
multi_shot=multi_shot,
|
||||
multi_prompt=multi_prompt_list,
|
||||
|
|
@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode):
|
|||
IO.DynamicCombo.Option(
|
||||
"kling-v3",
|
||||
[
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"]),
|
||||
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
|
|
@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode):
|
|||
),
|
||||
expr="""
|
||||
(
|
||||
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
|
||||
$rates := {
|
||||
"4k": {"off": 0.42, "on": 0.42},
|
||||
"1080p": {"off": 0.112, "on": 0.168},
|
||||
"720p": {"off": 0.084, "on": 0.126}
|
||||
};
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$audio := widgets.generate_audio ? "on" : "off";
|
||||
$rate := $lookup($lookup($rates, $res), $audio);
|
||||
|
|
@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode):
|
|||
start_frame: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
mode = "pro" if model["resolution"] == "1080p" else "std"
|
||||
if model["resolution"] == "4k":
|
||||
mode = "4k"
|
||||
elif model["resolution"] == "1080p":
|
||||
mode = "pro"
|
||||
else:
|
||||
mode = "std"
|
||||
custom_multi_shot = False
|
||||
if multi_shot["multi_shot"] == "disabled":
|
||||
shot_type = None
|
||||
|
|
@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode):
|
|||
cls,
|
||||
ApiEndpoint(path=poll_path),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
|
@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
|||
IO.DynamicCombo.Option(
|
||||
"kling-v3",
|
||||
[
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"]),
|
||||
IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
|
||||
],
|
||||
),
|
||||
],
|
||||
|
|
@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
|||
),
|
||||
expr="""
|
||||
(
|
||||
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
|
||||
$rates := {
|
||||
"4k": {"off": 0.42, "on": 0.42},
|
||||
"1080p": {"off": 0.112, "on": 0.168},
|
||||
"720p": {"off": 0.084, "on": 0.126}
|
||||
};
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$audio := widgets.generate_audio ? "on" : "off";
|
||||
$rate := $lookup($lookup($rates, $res), $audio);
|
||||
|
|
@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
|||
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
|
||||
image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
|
||||
image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
|
||||
if model["resolution"] == "4k":
|
||||
mode = "4k"
|
||||
elif model["resolution"] == "1080p":
|
||||
mode = "pro"
|
||||
else:
|
||||
mode = "std"
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
|
||||
|
|
@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
|||
image=image_url,
|
||||
image_tail=image_tail_url,
|
||||
prompt=prompt,
|
||||
mode="pro" if model["resolution"] == "1080p" else "std",
|
||||
mode=mode,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
|
|
@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
|||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
|
|
|||
|
|
@ -357,6 +357,10 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
|
|||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||
|
||||
|
||||
def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
|
||||
|
||||
|
||||
class OpenAIGPTImage1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
|
|
@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
options=[
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"2048x2048",
|
||||
"2048x1152",
|
||||
"1152x2048",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
],
|
||||
tooltip="Image size",
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -427,7 +441,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5", 'gpt-image-2'],
|
||||
options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"],
|
||||
default="gpt-image-2",
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.046, 0.07],
|
||||
"high": [0.167, 0.3]
|
||||
"gpt-image-1": {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.042, 0.07],
|
||||
"high": [0.167, 0.25]
|
||||
},
|
||||
"gpt-image-1.5": {
|
||||
"low": [0.009, 0.02],
|
||||
"medium": [0.034, 0.062],
|
||||
"high": [0.133, 0.22]
|
||||
},
|
||||
"gpt-image-2": {
|
||||
"low": [0.0048, 0.012],
|
||||
"medium": [0.041, 0.112],
|
||||
"high": [0.165, 0.43]
|
||||
}
|
||||
};
|
||||
$range := $lookup($ranges, widgets.quality);
|
||||
$n := widgets.n;
|
||||
$range := $lookup($lookup($ranges, widgets.model), widgets.quality);
|
||||
$nRaw := widgets.n;
|
||||
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $range[0],
|
||||
"max_usd": $range[1],
|
||||
"format": { "suffix": " x " & $string($n) & "/Run" }
|
||||
"min_usd": $range[0] * $n,
|
||||
"max_usd": $range[1] * $n,
|
||||
"format": { "suffix": "/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
|
|
@ -483,12 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if model in ("gpt-image-1", "gpt-image-1.5"):
|
||||
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
|
||||
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
|
||||
|
||||
if model == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
elif model == "gpt-image-2":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
price_extractor = calculate_tokens_price_image_2_0
|
||||
if background == "transparent":
|
||||
raise ValueError("Transparent background is not supported for GPT Image 2 model")
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
|
||||
|
|
|
|||
|
|
@ -156,6 +156,7 @@ async def poll_op(
|
|||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
|
|
@ -176,6 +177,7 @@ async def poll_op(
|
|||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
extra_text=extra_text,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
|
|
@ -260,6 +262,7 @@ async def poll_op_raw(
|
|||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
|
|
@ -299,6 +302,7 @@ async def poll_op_raw(
|
|||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
|
|
@ -389,6 +393,7 @@ async def poll_op_raw(
|
|||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
return resp_json
|
||||
|
||||
|
|
@ -462,6 +467,7 @@ def _display_time_progress(
|
|||
price: float | None = None,
|
||||
is_queued: bool | None = None,
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
extra_text: str | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
|
|
@ -469,7 +475,8 @@ def _display_time_progress(
|
|||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
|
||||
_display_text(node_cls, text, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import torchaudio
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.samplers
|
||||
|
|
@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode):
|
|||
@classmethod
|
||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||
# Encode reference audio to latents and patchify
|
||||
audio_latents = audio_vae.encode(reference_audio)
|
||||
sample_rate = reference_audio["sample_rate"]
|
||||
vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100)
|
||||
if vae_sample_rate != sample_rate:
|
||||
waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate)
|
||||
else:
|
||||
waveform = reference_audio["waveform"]
|
||||
|
||||
audio_latents = audio_vae.encode(waveform.movedim(1, -1))
|
||||
b, c, t, f = audio_latents.shape
|
||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||
ref_audio = {"tokens": ref_tokens}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
import torch
|
||||
|
||||
# Preview Any - original implement from
|
||||
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
|
||||
|
|
@ -19,6 +20,7 @@ class PreviewAny():
|
|||
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
|
||||
|
||||
def main(self, source=None):
|
||||
torch.set_printoptions(edgeitems=6)
|
||||
value = 'None'
|
||||
if isinstance(source, str):
|
||||
value = source
|
||||
|
|
@ -33,6 +35,7 @@ class PreviewAny():
|
|||
except Exception:
|
||||
value = 'source exists, but could not be serialized.'
|
||||
|
||||
torch.set_printoptions()
|
||||
return {"ui": {"text": (value,)}, "result": (value,)}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,529 @@
|
|||
"""
|
||||
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
import av
|
||||
from fractions import Fraction
|
||||
|
||||
|
||||
def _extract_text_prompts(conditioning, device, dtype):
|
||||
"""Extract list of (text_embeddings, text_mask) from conditioning."""
|
||||
cond_meta = conditioning[0][1]
|
||||
multi = cond_meta.get("sam3_multi_cond")
|
||||
prompts = []
|
||||
if multi is not None:
|
||||
for entry in multi:
|
||||
emb = entry["cond"].to(device=device, dtype=dtype)
|
||||
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
|
||||
if mask is None:
|
||||
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
||||
prompts.append((emb, mask, entry.get("max_detections", 1)))
|
||||
else:
|
||||
emb = conditioning[0][0].to(device=device, dtype=dtype)
|
||||
mask = cond_meta.get("attention_mask")
|
||||
if mask is not None:
|
||||
mask = mask.to(device)
|
||||
else:
|
||||
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
||||
prompts.append((emb, mask, 1))
|
||||
return prompts
|
||||
|
||||
|
||||
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
|
||||
"""Refine a coarse detector mask via SAM decoder, cropping to the detection box.
|
||||
|
||||
Returns: [1, H, W] binary mask
|
||||
"""
|
||||
def _coarse_fallback():
|
||||
return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
|
||||
mode="bilinear", align_corners=False)[0] > 0).float()
|
||||
|
||||
if iterations <= 0:
|
||||
return _coarse_fallback()
|
||||
|
||||
pad_frac = 0.1
|
||||
x1, y1, x2, y2 = box_xyxy.tolist()
|
||||
bw, bh = x2 - x1, y2 - y1
|
||||
cx1 = max(0, int(x1 - bw * pad_frac))
|
||||
cy1 = max(0, int(y1 - bh * pad_frac))
|
||||
cx2 = min(W, int(x2 + bw * pad_frac))
|
||||
cy2 = min(H, int(y2 + bh * pad_frac))
|
||||
if cx2 <= cx1 or cy2 <= cy1:
|
||||
return _coarse_fallback()
|
||||
|
||||
crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3]
|
||||
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
||||
crop_frame = crop_1008.to(device=device, dtype=dtype)
|
||||
crop_h, crop_w = cy2 - cy1, cx2 - cx1
|
||||
|
||||
# Crop coarse mask and refine via SAM on the cropped image
|
||||
mask_h, mask_w = coarse_mask.shape[-2:]
|
||||
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
||||
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
||||
if mx2 <= mx1 or my2 <= my1:
|
||||
return _coarse_fallback()
|
||||
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
||||
for _ in range(iterations):
|
||||
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
||||
mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input)
|
||||
|
||||
refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
|
||||
full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
|
||||
full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop
|
||||
coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
|
||||
return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float()
|
||||
|
||||
|
||||
|
||||
class SAM3_Detect(io.ComfyNode):
|
||||
"""Open-vocabulary detection and segmentation using text, box, or point prompts."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_Detect",
|
||||
display_name="SAM3 Detect",
|
||||
category="detection/",
|
||||
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Image.Input("image", display_name="image"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
|
||||
io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
|
||||
io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
||||
io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
||||
io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
|
||||
io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output("masks"),
|
||||
io.BoundingBox.Output("bboxes"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
||||
|
||||
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
|
||||
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
|
||||
def _boxes_to_tensor(box_list):
|
||||
coords = []
|
||||
for d in box_list:
|
||||
cx = (d["x"] + d["width"] / 2) / W
|
||||
cy = (d["y"] + d["height"] / 2) / H
|
||||
coords.append([cx, cy, d["width"] / W, d["height"] / H])
|
||||
return torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
|
||||
|
||||
per_frame_boxes = None
|
||||
if bboxes is not None:
|
||||
if isinstance(bboxes, dict):
|
||||
# Single box → same for all frames
|
||||
shared = _boxes_to_tensor([bboxes])
|
||||
per_frame_boxes = [shared] * B
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
|
||||
# list[list[dict]] → per-frame boxes
|
||||
per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
|
||||
# Pad to B if fewer frames provided
|
||||
while len(per_frame_boxes) < B:
|
||||
per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0:
|
||||
# list[dict] → same boxes for all frames
|
||||
shared = _boxes_to_tensor(bboxes)
|
||||
per_frame_boxes = [shared] * B
|
||||
|
||||
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
|
||||
pos_pts = json.loads(positive_coords) if positive_coords else []
|
||||
neg_pts = json.loads(negative_coords) if negative_coords else []
|
||||
has_points = len(pos_pts) > 0 or len(neg_pts) > 0
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
# Build point inputs for tracker SAM decoder path
|
||||
point_inputs = None
|
||||
if has_points:
|
||||
all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
|
||||
[[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
|
||||
all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
|
||||
point_inputs = {
|
||||
"point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
|
||||
"point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
|
||||
}
|
||||
|
||||
cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
|
||||
has_text = len(cond_list) > 0
|
||||
|
||||
# Run per-image through detector (text/boxes) and/or tracker (points)
|
||||
all_bbox_dicts = []
|
||||
all_masks = []
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
|
||||
for b in range(B):
|
||||
frame = image_in[b:b+1].to(device=device, dtype=dtype)
|
||||
b_boxes = None
|
||||
if per_frame_boxes is not None and per_frame_boxes[b] is not None:
|
||||
b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
|
||||
|
||||
frame_bbox_dicts = []
|
||||
frame_masks = []
|
||||
|
||||
# Point prompts: tracker SAM decoder path with iterative refinement
|
||||
if point_inputs is not None:
|
||||
mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
|
||||
for _ in range(max(0, refine_iterations - 1)):
|
||||
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
||||
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
||||
frame_masks.append((mask[0] > 0).float())
|
||||
|
||||
# Box prompts: SAM decoder path (segment inside each box)
|
||||
if b_boxes is not None and not has_text:
|
||||
for box_cxcywh in b_boxes[0]:
|
||||
cx, cy, bw, bh = box_cxcywh.tolist()
|
||||
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
|
||||
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
|
||||
[(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
|
||||
device=device, dtype=dtype)
|
||||
mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
|
||||
for _ in range(max(0, refine_iterations - 1)):
|
||||
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
||||
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
||||
frame_masks.append((mask[0] > 0).float())
|
||||
|
||||
# Text prompts: run detector per text prompt (each detects one category)
|
||||
for text_embeddings, text_mask, max_det in cond_list:
|
||||
results = sam3_model(
|
||||
frame, text_embeddings=text_embeddings, text_mask=text_mask,
|
||||
boxes=b_boxes, threshold=threshold, orig_size=(H, W))
|
||||
|
||||
pred_boxes = results["boxes"][0]
|
||||
scores = results["scores"][0]
|
||||
masks = results["masks"][0]
|
||||
|
||||
probs = scores.sigmoid()
|
||||
keep = probs > threshold
|
||||
kept_boxes = pred_boxes[keep].cpu()
|
||||
kept_scores = probs[keep].cpu()
|
||||
kept_masks = masks[keep]
|
||||
|
||||
order = kept_scores.argsort(descending=True)[:max_det]
|
||||
kept_boxes = kept_boxes[order]
|
||||
kept_scores = kept_scores[order]
|
||||
kept_masks = kept_masks[order]
|
||||
|
||||
for box, score in zip(kept_boxes, kept_scores):
|
||||
frame_bbox_dicts.append({
|
||||
"x": float(box[0]), "y": float(box[1]),
|
||||
"width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
|
||||
"score": float(score),
|
||||
})
|
||||
for m, box in zip(kept_masks, kept_boxes):
|
||||
frame_masks.append(_refine_mask(
|
||||
sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
|
||||
|
||||
all_bbox_dicts.append(frame_bbox_dicts)
|
||||
if len(frame_masks) > 0:
|
||||
combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
|
||||
if individual_masks:
|
||||
all_masks.append(combined)
|
||||
else:
|
||||
all_masks.append((combined > 0).any(dim=0).float())
|
||||
else:
|
||||
if individual_masks:
|
||||
all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
|
||||
else:
|
||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||
pbar.update(1)
|
||||
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
all_masks = [m.to(idev) for m in all_masks]
|
||||
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
||||
return io.NodeOutput(mask_out, all_bbox_dicts)
|
||||
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
class SAM3_VideoTrack(io.ComfyNode):
|
||||
"""Track objects across video frames using SAM3's memory-based tracker."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_VideoTrack",
|
||||
display_name="SAM3 Video Track",
|
||||
category="detection/",
|
||||
search_aliases=["sam3", "video", "track", "propagate"],
|
||||
inputs=[
|
||||
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
||||
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
||||
],
|
||||
outputs=[
|
||||
SAM3TrackData.Output("track_data", display_name="track_data"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
|
||||
N, H, W, C = images.shape
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
frames = images[..., :3].movedim(-1, 1)
|
||||
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
||||
|
||||
init_masks = None
|
||||
if initial_mask is not None:
|
||||
init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(N)
|
||||
|
||||
text_prompts = None
|
||||
if conditioning is not None and len(conditioning) > 0:
|
||||
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
||||
elif initial_mask is None:
|
||||
raise ValueError("Either initial_mask or conditioning must be provided")
|
||||
|
||||
result = sam3_model.forward_video(
|
||||
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
||||
new_det_thresh=detection_threshold, max_objects=max_objects,
|
||||
detect_interval=detect_interval)
|
||||
result["orig_size"] = (H, W)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class SAM3_TrackPreview(io.ComfyNode):
|
||||
"""Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackPreview",
|
||||
display_name="SAM3 Track Preview",
|
||||
category="detection/",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.Image.Input("images", display_name="images", optional=True),
|
||||
io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
|
||||
io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
|
||||
],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
COLORS = [
|
||||
(0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
|
||||
(0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
|
||||
(0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
|
||||
]
|
||||
|
||||
# 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
|
||||
_glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
|
||||
|
||||
@staticmethod
|
||||
def _get_glyphs(device, scale=3):
|
||||
key = (device, scale)
|
||||
if key in SAM3_TrackPreview._glyph_cache:
|
||||
return SAM3_TrackPreview._glyph_cache[key]
|
||||
atlas = torch.tensor([
|
||||
[[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
|
||||
[[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
||||
[[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
|
||||
[[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
|
||||
[[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
|
||||
[[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
|
||||
[[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
||||
], dtype=torch.bool)
|
||||
glyphs, outlines = [], []
|
||||
for d in range(10):
|
||||
g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
|
||||
padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
|
||||
o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
|
||||
glyphs.append(g.to(device))
|
||||
outlines.append(o.to(device))
|
||||
gh, gw = glyphs[0].shape
|
||||
oh, ow = outlines[0].shape
|
||||
SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
|
||||
return SAM3_TrackPreview._glyph_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
|
||||
"""Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
|
||||
H, W = frame.shape[:2]
|
||||
device = frame.device
|
||||
glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
|
||||
color_t = torch.tensor(color, device=device, dtype=frame.dtype)
|
||||
digs = [int(d) for d in str(number)]
|
||||
total_w = len(digs) * (gw + scale) - scale
|
||||
x0 = cx - total_w // 2
|
||||
y0 = cy - gh // 2
|
||||
for i, d in enumerate(digs):
|
||||
dx = x0 + i * (gw + scale)
|
||||
# Black outline
|
||||
oy0, ox0 = y0 - 1, dx - 1
|
||||
osy1, osx1 = max(0, -oy0), max(0, -ox0)
|
||||
osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
|
||||
if osy2 > osy1 and osx2 > osx1:
|
||||
fy1, fx1 = oy0 + osy1, ox0 + osx1
|
||||
frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
|
||||
# Colored fill
|
||||
sy1, sx1 = max(0, -y0), max(0, -dx)
|
||||
sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
|
||||
if sy2 > sy1 and sx2 > sx1:
|
||||
fy1, fx1 = y0 + sy1, dx + sx1
|
||||
frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
|
||||
|
||||
@classmethod
|
||||
def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
|
||||
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
if images is not None:
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
if packed is None:
|
||||
N, N_obj = track_data["n_frames"], 0
|
||||
else:
|
||||
N, N_obj = packed.shape[0], packed.shape[1]
|
||||
|
||||
import uuid
|
||||
gpu = comfy.model_management.get_torch_device()
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
with av.open(filepath, mode='w') as output:
|
||||
stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
|
||||
stream.width = W
|
||||
stream.height = H
|
||||
stream.pix_fmt = 'yuv420p'
|
||||
|
||||
frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
|
||||
frame_np = frame_cpu.numpy()
|
||||
if N_obj > 0:
|
||||
colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
|
||||
device=gpu, dtype=torch.float32)
|
||||
grid_y = torch.arange(H, device=gpu).view(1, H, 1)
|
||||
grid_x = torch.arange(W, device=gpu).view(1, 1, W)
|
||||
for t in range(N):
|
||||
if images is not None and t < images.shape[0]:
|
||||
frame = images[t].clone()
|
||||
else:
|
||||
frame = torch.zeros(H, W, 3)
|
||||
|
||||
if N_obj > 0:
|
||||
frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
|
||||
frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
|
||||
frame_gpu = frame.to(gpu)
|
||||
bool_masks = frame_masks > 0.5
|
||||
any_mask = bool_masks.any(dim=0)
|
||||
if any_mask.any():
|
||||
obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
|
||||
color_overlay = colors_t[obj_idx_map]
|
||||
mask_3d = any_mask.unsqueeze(-1)
|
||||
frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
|
||||
area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
|
||||
cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
|
||||
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
||||
has = area > 1
|
||||
scores = track_data.get("scores", [])
|
||||
for obj_idx in range(N_obj):
|
||||
if has[obj_idx]:
|
||||
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
||||
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
||||
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
||||
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
||||
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
||||
else:
|
||||
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
||||
|
||||
vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
||||
output.mux(stream.encode(vframe.reformat(format='yuv420p')))
|
||||
output.mux(stream.encode(None))
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
|
||||
|
||||
|
||||
class SAM3_TrackToMask(io.ComfyNode):
|
||||
"""Select tracked objects by index and output as mask."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackToMask",
|
||||
display_name="SAM3 Track to Mask",
|
||||
category="detection/",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.String.Input("object_indices", display_name="object_indices", default="",
|
||||
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output("masks", display_name="masks"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, track_data, object_indices="") -> io.NodeOutput:
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
|
||||
if packed is None:
|
||||
N = track_data["n_frames"]
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
N, N_obj = packed.shape[0], packed.shape[1]
|
||||
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
indices = [i for i in indices if 0 <= i < N_obj]
|
||||
else:
|
||||
indices = list(range(N_obj))
|
||||
|
||||
if not indices:
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
selected = packed[:, indices]
|
||||
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
||||
union = binary.any(dim=1, keepdim=True).float()
|
||||
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
||||
return io.NodeOutput(mask_out)
|
||||
|
||||
|
||||
class SAM3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SAM3_Detect,
|
||||
SAM3_VideoTrack,
|
||||
SAM3_TrackPreview,
|
||||
SAM3_TrackToMask,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SAM3Extension:
|
||||
return SAM3Extension()
|
||||
38
execution.py
38
execution.py
|
|
@ -811,11 +811,30 @@ class PromptExecutor:
|
|||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
if visiting is None:
|
||||
visiting = []
|
||||
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
|
||||
if unique_id in visiting:
|
||||
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
|
||||
cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
|
||||
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
|
||||
for node_id in cycle_nodes:
|
||||
validated[node_id] = (False, [{
|
||||
"type": "dependency_cycle",
|
||||
"message": "Dependency cycle detected",
|
||||
"details": cycle_path,
|
||||
"extra_info": {
|
||||
"node_id": node_id,
|
||||
"cycle_nodes": cycle_nodes,
|
||||
}
|
||||
}], node_id)
|
||||
return validated[unique_id]
|
||||
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
|
@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||
visiting.append(unique_id)
|
||||
try:
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting)
|
||||
finally:
|
||||
visiting.pop()
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
|
|
@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
else:
|
||||
ret = (True, [], unique_id)
|
||||
ret = validated.get(unique_id, (True, [], unique_id))
|
||||
# Recursive cycle detection may have already populated an error on us. Join it.
|
||||
ret = (
|
||||
ret[0] and valid is True and not errors,
|
||||
ret[1] + [error for error in errors if error not in ret[1]],
|
||||
unique_id,
|
||||
)
|
||||
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
comfyui_manager==4.1
|
||||
comfyui_manager==4.2.1
|
||||
|
|
|
|||
1
nodes.py
1
nodes.py
|
|
@ -2459,6 +2459,7 @@ async def init_builtin_extra_nodes():
|
|||
"nodes_curve.py",
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
comfyui-frontend-package==1.42.14
|
||||
comfyui-workflow-templates==0.9.59
|
||||
comfyui-embedded-docs==0.4.3
|
||||
comfyui-workflow-templates==0.9.62
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
|
@ -23,7 +23,7 @@ SQLAlchemy>=2.0
|
|||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.12
|
||||
comfy-aimdo==0.2.14
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def get_required_packages_versions():
|
|||
if len(s) == 2:
|
||||
version_str = s[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}")
|
||||
continue
|
||||
out[s[0]] = version_str
|
||||
return out.copy()
|
||||
|
|
|
|||
Loading…
Reference in New Issue