actually just fix every single casing mixup

pull/87/head
Alex "mcmonkey" Goodwin 2023-10-29 07:34:22 -07:00
parent 9fe47f4e7a
commit 83f619ab37
4 changed files with 55 additions and 55 deletions

View File

@ -63,13 +63,13 @@ class DynamicThresholdingSimpleComfyNode:
dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, "CONSTANT", 0, "CONSTANT", 0, 0, 0, 999, False, "MEAN", "AD", 1)
def sampler_dyn_thrash(args):
x_out = args["cond"]
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
time_step = args["timestep"]
dynamic_thresh.step = 999 - time_step[0]
return dynamic_thresh.dynthresh(x_out, uncond, cond_scale, None)
return dynamic_thresh.dynthresh(cond, uncond, cond_scale, None)
m = model.clone()
m.set_model_sampler_cfg_function(sampler_dyn_thrash)

View File

@ -23,7 +23,7 @@ class DynThresh:
self.variability_measure = variability_measure
self.interpolate_phi = interpolate_phi
def interpretScale(self, scale, mode, min):
def interpret_scale(self, scale, mode, min):
scale -= min
max = self.max_steps - 1
frac = self.step / max
@ -56,8 +56,8 @@ class DynThresh:
return scale
def dynthresh(self, cond, uncond, cfg_scale, weights):
mimic_scale = self.interpretScale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
cfg_scale = self.interpretScale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
# uncond shape is (batch, 4, height, width)
conds_per_batch = cond.shape[0] / uncond.shape[0]
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
@ -116,13 +116,13 @@ class DynThresh:
### Now add it back onto the averages to get into real scale again and return
result = cfg_renormalized + cfg_means
actualRes = result.unflatten(2, mim_target.shape[2:])
actual_res = result.unflatten(2, mim_target.shape[2:])
if self.interpolate_phi != 1.0:
actualRes = actualRes * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
if self.experiment_mode == 1:
num = actualRes.cpu().numpy()
num = actual_res.cpu().numpy()
for y in range(0, 64):
for x in range (0, 64):
if num[0][0][y][x] > 1.0:
@ -131,19 +131,19 @@ class DynThresh:
num[0][1][y][x] *= 0.5
if num[0][2][y][x] > 1.5:
num[0][2][y][x] *= 0.5
actualRes = torch.from_numpy(num).to(device=uncond.device)
actual_res = torch.from_numpy(num).to(device=uncond.device)
elif self.experiment_mode == 2:
num = actualRes.cpu().numpy()
num = actual_res.cpu().numpy()
for y in range(0, 64):
for x in range (0, 64):
overScale = False
over_scale = False
for z in range(0, 4):
if abs(num[0][z][y][x]) > 1.5:
overScale = True
if overScale:
over_scale = True
if over_scale:
for z in range(0, 4):
num[0][z][y][x] *= 0.7
actualRes = torch.from_numpy(num).to(device=uncond.device)
actual_res = torch.from_numpy(num).to(device=uncond.device)
elif self.experiment_mode == 3:
coefs = torch.tensor([
# R G B W
@ -152,16 +152,16 @@ class DynThresh:
[-0.158, 0.189, 0.264, 0.0], # L3
[-0.184, -0.271, -0.473, 1.0], # L4
], device=uncond.device)
resRGB = torch.einsum("laxy,ab -> lbxy", actualRes, coefs)
maxR, maxG, maxB, maxW = resRGB[0][0].max(), resRGB[0][1].max(), resRGB[0][2].max(), resRGB[0][3].max()
maxRGB = max(maxR, maxG, maxB)
print(f"test max = r={maxR}, g={maxG}, b={maxB}, w={maxW}, rgb={maxRGB}")
res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
max_rgb = max(max_r, max_g, max_b)
print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
if self.step / (self.max_steps - 1) > 0.2:
if maxRGB < 2.0 and maxW < 3.0:
resRGB /= maxRGB / 2.4
if max_rgb < 2.0 and max_w < 3.0:
res_rgb /= max_rgb / 2.4
else:
if maxRGB > 2.4 and maxW > 3.0:
resRGB /= maxRGB / 2.4
actualRes = torch.einsum("laxy,ab -> lbxy", resRGB, coefs.inverse())
if max_rgb > 2.4 and max_w > 3.0:
res_rgb /= max_rgb / 2.4
actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())
return actualRes
return actual_res

View File

@ -16,7 +16,7 @@ except Exception as e:
# (It has hooks but not in useful locations)
# I stripped the original comments for brevity.
# Some never-used code (scheduler modes, noise modes, guidance modes) have been removed as well for brevity.
# The actual impl comes down to just the last line in particular, and the `beforeSample` insert to track step count.
# The actual impl comes down to just the last line in particular, and the `before_sample` insert to track step count.
class CustomUniPCSampler(uni_pc.sampler.UniPCSampler):
def __init__(self, model, **kwargs):
@ -50,16 +50,16 @@ class CustomUniPCSampler(uni_pc.sampler.UniPCSampler):
img = x_T
ns = uni_pc.uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_type = "v" if self.model.parameterization == "v" else "noise"
model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dtData=self.main_class)
model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dt_data=self.main_class)
self.main_class.step = 0
def beforeSample(x, t, cond, uncond):
def before_sample(x, t, cond, uncond):
self.main_class.step += 1
return self.before_sample(x, t, cond, uncond)
uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=beforeSample, after_sample=self.after_sample, after_update=self.after_update)
uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=before_sample, after_sample=self.after_sample, after_update=self.after_update)
x = uni_pc_inst.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
return x.to(device), None
def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dtData=None):
def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dt_data=None):
def expand_dims(v, dims):
return v[(...,) + (None,)*(dims - 1)]
def get_model_input_time(t_continuous):
@ -107,5 +107,5 @@ def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_k
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
#return noise_uncond + guidance_scale * (noise - noise_uncond)
return dtData.dynthresh(noise, noise_uncond, guidance_scale, None)
return dt_data.dynthresh(noise, noise_uncond, guidance_scale, None)
return model_fn

View File

@ -39,8 +39,8 @@ class Script(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img):
def vis_change(isVis):
return {"visible": isVis, "__type__": "update"}
def vis_change(is_vis):
return {"visible": is_vis, "__type__": "update"}
# "Dynamic Thresholding (CFG Scale Fix)"
dtrue = gr.Checkbox(value=True, visible=False)
dfalse = gr.Checkbox(value=False, visible=False)
@ -64,11 +64,11 @@ class Script(scripts.Script):
separate_feature_channels = gr.Checkbox(value=True, label="Separate Feature Channels", elem_id='dynthres_separate_feature_channels')
scaling_startpoint = gr.Radio(["ZERO", "MEAN"], value="MEAN", label="Scaling Startpoint")
variability_measure = gr.Radio(["STD", "AD"], value="AD", label="Variability Measure")
def shouldShowSchedulerValue(cfgMode, mimicMode):
sched_vis = cfgMode in MODES_WITH_VALUE or mimicMode in MODES_WITH_VALUE
return vis_change(sched_vis), vis_change(mimicMode != "Constant"), vis_change(cfgMode != "Constant")
cfg_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
mimic_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
def should_show_scheduler_value(cfg_mode, mimic_mode):
sched_vis = cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE
return vis_change(sched_vis), vis_change(mimic_mode != "Constant"), vis_change(cfg_mode != "Constant")
cfg_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
mimic_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
enabled.change(
_js="dynthres_update_enabled",
fn=lambda x, y: {"visible": x, "__type__": "update"},
@ -143,19 +143,19 @@ class Script(scripts.Script):
# Make a placeholder sampler
sampler = sd_samplers.all_samplers_map[orig_sampler_name]
dtData = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
dt_data = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
if orig_sampler_name == "UniPC":
def uniPCConstructor(model):
return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dtData)
newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, uniPCConstructor, sampler.aliases, sampler.options)
def unipc_constructor(model):
return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dt_data)
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, unipc_constructor, sampler.aliases, sampler.options)
else:
def newConstructor(model):
def new_constructor(model):
result = sampler.constructor(model)
cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dtData)
cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dt_data)
result.model_wrap_cfg = cfg
return result
newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, newConstructor, sampler.aliases, sampler.options)
return fixed_sampler_name, newSampler
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, new_constructor, sampler.aliases, sampler.options)
return fixed_sampler_name, new_sampler
# Apply for usage
p.orig_sampler_name = orig_sampler_name
@ -163,14 +163,14 @@ class Script(scripts.Script):
p.fixed_samplers = []
if orig_latent_sampler_name:
latent_sampler_name, latentSampler = make_sampler(orig_latent_sampler_name)
sd_samplers.all_samplers_map[latent_sampler_name] = latentSampler
latent_sampler_name, latent_sampler = make_sampler(orig_latent_sampler_name)
sd_samplers.all_samplers_map[latent_sampler_name] = latent_sampler
p.fixed_samplers.append(latent_sampler_name)
p.latent_sampler = latent_sampler_name
if orig_sampler_name != orig_latent_sampler_name:
p.sampler_name, newSampler = make_sampler(orig_sampler_name)
sd_samplers.all_samplers_map[p.sampler_name] = newSampler
p.sampler_name, new_sampler = make_sampler(orig_sampler_name)
sd_samplers.all_samplers_map[p.sampler_name] = new_sampler
p.fixed_samplers.append(p.sampler_name)
else:
p.sampler_name = p.latent_sampler
@ -193,16 +193,16 @@ class Script(scripts.Script):
######################### CompVis Implementation logic #########################
class CustomVanillaSDSampler(sd_samplers_compvis.VanillaStableDiffusionSampler):
def __init__(self, constructor, sd_model, dtData):
def __init__(self, constructor, sd_model, dt_data):
super().__init__(constructor, sd_model)
self.sampler.main_class = dtData
self.sampler.main_class = dt_data
######################### K-Diffusion Implementation logic #########################
class CustomCFGDenoiser(cfgdenoisekdiff):
def __init__(self, model, dtData):
def __init__(self, model, dt_data):
super().__init__(model)
self.main_class = dtData
self.main_class = dt_data
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
if isinstance(uncond, dict) and 'crossattn' in uncond:
@ -258,11 +258,11 @@ def make_axis_options():
if not any("[DynThres]" in x.label for x in xyz_grid.axis_options):
xyz_grid.axis_options.extend(extra_axis_options)
def callbackBeforeUi():
def callback_before_ui():
try:
make_axis_options()
except Exception as e:
traceback.print_exc()
print(f"Failed to add support for X/Y/Z Plot Script because: {e}")
script_callbacks.on_before_ui(callbackBeforeUi)
script_callbacks.on_before_ui(callback_before_ui)