ComfyUI Nodes

for #54
pull/75/head
Alex "mcmonkey" Goodwin 2023-09-21 10:28:22 -07:00
parent 55ca687fb8
commit 1fd915af73
4 changed files with 92 additions and 6 deletions

6
__init__.py Normal file
View File

@ -0,0 +1,6 @@
from . import dynthres_comfyui
NODE_CLASS_MAPPINGS = {
"DynamicThresholdingSimple": dynthres_comfyui.DynamicThresholdingSimpleComfyNode,
"DynamicThresholdingFull": dynthres_comfyui.DynamicThresholdingComfyNode,
}

76
dynthres_comfyui.py Normal file
View File

@ -0,0 +1,76 @@
from .dynthres_core import DynThresh
class DynamicThresholdingComfyNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"mimic_scale": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step": 0.5}),
"threshold_percentile": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"mimic_mode": (DynThresh.Modes, ),
"mimic_scale_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5}),
"cfg_mode": (DynThresh.Modes, ),
"cfg_scale_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5}),
"sched_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
"separate_feature_channels": (["enable", "disable"], ),
"scaling_startpoint": (DynThresh.Startpoints, ),
"variability_measure": (DynThresh.Variabilities, ),
"interpolate_phi": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/mcmonkey"
def patch(self, model, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, 0, 999, separate_feature_channels == "enable", scaling_startpoint, variability_measure, interpolate_phi)
def sampler_dyn_thrash(args):
x_out = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
time_step = args["timestep"]
dynamic_thresh.step = time_step[0]
return dynamic_thresh.dynthresh(x_out, uncond, cond_scale, None)
m = model.clone()
m.set_model_sampler_cfg_function(sampler_dyn_thrash)
return (m, )
class DynamicThresholdingSimpleComfyNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"mimic_scale": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step": 0.5}),
"threshold_percentile": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/mcmonkey"
def patch(self, model, mimic_scale, threshold_percentile):
dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, "CONSTANT", 0, "CONSTANT", 0, 0, 0, 999, False, "MEAN", "AD", 1)
def sampler_dyn_thrash(args):
x_out = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
time_step = args["timestep"]
dynamic_thresh.step = time_step[0]
return dynamic_thresh.dynthresh(x_out, uncond, cond_scale, None)
m = model.clone()
m.set_model_sampler_cfg_function(sampler_dyn_thrash)
return (m, )

View File

@ -3,6 +3,11 @@ import torch, math
######################### DynThresh Core #########################
class DynThresh:
Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
Startpoints = ["MEAN", "ZERO"]
Variabilities = ["AD", "STD"]
def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, maxSteps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
self.mimic_scale = mimic_scale
self.threshold_percentile = threshold_percentile

View File

@ -27,7 +27,6 @@ except Exception as e:
IS_AUTO_16 = False
######################### Data values #########################
VALID_MODES = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
MODES_WITH_VALUE = ["Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
######################### Script class entrypoint #########################
@ -53,8 +52,8 @@ class Script(scripts.Script):
threshold_percentile = gr.Slider(minimum=90.0, value=100.0, maximum=100.0, step=0.05, label='Top percentile of latents to clamp', elem_id='dynthres_threshold_percentile')
interpolate_phi = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Interpolate Phi", value=1.0, elem_id='dynthres_interpolate_phi')
with gr.Row():
mimic_mode = gr.Dropdown(VALID_MODES, value="Constant", label="Mimic Scale Scheduler", elem_id='dynthres_mimic_mode')
cfg_mode = gr.Dropdown(VALID_MODES, value="Constant", label="CFG Scale Scheduler", elem_id='dynthres_cfg_mode')
mimic_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="Mimic Scale Scheduler", elem_id='dynthres_mimic_mode')
cfg_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="CFG Scale Scheduler", elem_id='dynthres_cfg_mode')
mimic_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=False, label="Minimum value of the Mimic Scale Scheduler", elem_id='dynthres_mimic_scale_min')
cfg_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=False, label="Minimum value of the CFG Scale Scheduler", elem_id='dynthres_cfg_scale_min')
sched_val = gr.Slider(minimum=0.0, maximum=40.0, step=0.5, value=4.0, visible=False, label="Scheduler Value", info="Value unique to the scheduler mode - for Power Up/Down, this is the power. For Linear/Cosine Repeating, this is the number of repeats per image.", elem_id='dynthres_sched_val')
@ -236,7 +235,7 @@ def make_axis_options():
setattr(p, "dynthres_enabled", False)
def confirm_scheduler(p, xs):
for x in xs:
if x not in VALID_MODES:
if x not in dynthres_core.DynThresh.Modes:
raise RuntimeError(f"Unknown Scheduler: {x}")
extra_axis_options = [
xyz_grid.AxisOption("[DynThres] Mimic Scale", float, apply_mimic_scale),
@ -246,9 +245,9 @@ def make_axis_options():
xyz_grid.AxisOption("[DynThres] Variability Measure", str, xyz_grid.apply_field("dynthres_variability_measure"), choices=lambda:['STD', 'AD']),
xyz_grid.AxisOption("[DynThres] Interpolate Phi", float, xyz_grid.apply_field("dynthres_interpolate_phi")),
xyz_grid.AxisOption("[DynThres] Threshold Percentile", float, xyz_grid.apply_field("dynthres_threshold_percentile")),
xyz_grid.AxisOption("[DynThres] Mimic Scheduler", str, xyz_grid.apply_field("dynthres_mimic_mode"), confirm=confirm_scheduler, choices=lambda: VALID_MODES),
xyz_grid.AxisOption("[DynThres] Mimic Scheduler", str, xyz_grid.apply_field("dynthres_mimic_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
xyz_grid.AxisOption("[DynThres] Mimic minimum", float, xyz_grid.apply_field("dynthres_mimic_scale_min")),
xyz_grid.AxisOption("[DynThres] CFG Scheduler", str, xyz_grid.apply_field("dynthres_cfg_mode"), confirm=confirm_scheduler, choices=lambda: VALID_MODES),
xyz_grid.AxisOption("[DynThres] CFG Scheduler", str, xyz_grid.apply_field("dynthres_cfg_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
xyz_grid.AxisOption("[DynThres] CFG minimum", float, xyz_grid.apply_field("dynthres_cfg_scale_min")),
xyz_grid.AxisOption("[DynThres] Scheduler value", float, xyz_grid.apply_field("dynthres_scheduler_val"))
]