diff --git a/dynthres_comfyui.py b/dynthres_comfyui.py index 2e4d9f8..320d798 100644 --- a/dynthres_comfyui.py +++ b/dynthres_comfyui.py @@ -63,13 +63,13 @@ class DynamicThresholdingSimpleComfyNode: dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, "CONSTANT", 0, "CONSTANT", 0, 0, 0, 999, False, "MEAN", "AD", 1) def sampler_dyn_thrash(args): - x_out = args["cond"] + cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] time_step = args["timestep"] dynamic_thresh.step = 999 - time_step[0] - return dynamic_thresh.dynthresh(x_out, uncond, cond_scale, None) + return dynamic_thresh.dynthresh(cond, uncond, cond_scale, None) m = model.clone() m.set_model_sampler_cfg_function(sampler_dyn_thrash) diff --git a/dynthres_core.py b/dynthres_core.py index 11b951c..292c204 100644 --- a/dynthres_core.py +++ b/dynthres_core.py @@ -23,7 +23,7 @@ class DynThresh: self.variability_measure = variability_measure self.interpolate_phi = interpolate_phi - def interpretScale(self, scale, mode, min): + def interpret_scale(self, scale, mode, min): scale -= min max = self.max_steps - 1 frac = self.step / max @@ -56,8 +56,8 @@ class DynThresh: return scale def dynthresh(self, cond, uncond, cfg_scale, weights): - mimic_scale = self.interpretScale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min) - cfg_scale = self.interpretScale(cfg_scale, self.cfg_mode, self.cfg_scale_min) + mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min) + cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min) # uncond shape is (batch, 4, height, width) conds_per_batch = cond.shape[0] / uncond.shape[0] assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches" @@ -116,13 +116,13 @@ class DynThresh: ### Now add it back onto the averages to get into real scale again and return result = cfg_renormalized + cfg_means - actualRes = result.unflatten(2, mim_target.shape[2:]) + actual_res = result.unflatten(2, mim_target.shape[2:]) if self.interpolate_phi != 1.0: - actualRes = actualRes * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi) + actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi) if self.experiment_mode == 1: - num = actualRes.cpu().numpy() + num = actual_res.cpu().numpy() for y in range(0, 64): for x in range (0, 64): if num[0][0][y][x] > 1.0: @@ -131,19 +131,19 @@ class DynThresh: num[0][1][y][x] *= 0.5 if num[0][2][y][x] > 1.5: num[0][2][y][x] *= 0.5 - actualRes = torch.from_numpy(num).to(device=uncond.device) + actual_res = torch.from_numpy(num).to(device=uncond.device) elif self.experiment_mode == 2: - num = actualRes.cpu().numpy() + num = actual_res.cpu().numpy() for y in range(0, 64): for x in range (0, 64): - overScale = False + over_scale = False for z in range(0, 4): if abs(num[0][z][y][x]) > 1.5: - overScale = True - if overScale: + over_scale = True + if over_scale: for z in range(0, 4): num[0][z][y][x] *= 0.7 - actualRes = torch.from_numpy(num).to(device=uncond.device) + actual_res = torch.from_numpy(num).to(device=uncond.device) elif self.experiment_mode == 3: coefs = torch.tensor([ # R G B W @@ -152,16 +152,16 @@ class DynThresh: [-0.158, 0.189, 0.264, 0.0], # L3 [-0.184, -0.271, -0.473, 1.0], # L4 ], device=uncond.device) - resRGB = torch.einsum("laxy,ab -> lbxy", actualRes, coefs) - maxR, maxG, maxB, maxW = resRGB[0][0].max(), resRGB[0][1].max(), resRGB[0][2].max(), resRGB[0][3].max() - maxRGB = max(maxR, maxG, maxB) - print(f"test max = r={maxR}, g={maxG}, b={maxB}, w={maxW}, rgb={maxRGB}") + res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs) + max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max() + max_rgb = max(max_r, max_g, max_b) + print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}") if self.step / (self.max_steps - 1) > 0.2: - if maxRGB < 2.0 and maxW < 3.0: - resRGB /= maxRGB / 2.4 + if max_rgb < 2.0 and max_w < 3.0: + res_rgb /= max_rgb / 2.4 else: - if maxRGB > 2.4 and maxW > 3.0: - resRGB /= maxRGB / 2.4 - actualRes = torch.einsum("laxy,ab -> lbxy", resRGB, coefs.inverse()) + if max_rgb > 2.4 and max_w > 3.0: + res_rgb /= max_rgb / 2.4 + actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse()) - return actualRes + return actual_res diff --git a/dynthres_unipc.py b/dynthres_unipc.py index 30ffdb9..cb33051 100644 --- a/dynthres_unipc.py +++ b/dynthres_unipc.py @@ -16,7 +16,7 @@ except Exception as e: # (It has hooks but not in useful locations) # I stripped the original comments for brevity. # Some never-used code (scheduler modes, noise modes, guidance modes) have been removed as well for brevity. -# The actual impl comes down to just the last line in particular, and the `beforeSample` insert to track step count. +# The actual impl comes down to just the last line in particular, and the `before_sample` insert to track step count. class CustomUniPCSampler(uni_pc.sampler.UniPCSampler): def __init__(self, model, **kwargs): @@ -50,16 +50,16 @@ class CustomUniPCSampler(uni_pc.sampler.UniPCSampler): img = x_T ns = uni_pc.uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) model_type = "v" if self.model.parameterization == "v" else "noise" - model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dtData=self.main_class) + model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dt_data=self.main_class) self.main_class.step = 0 - def beforeSample(x, t, cond, uncond): + def before_sample(x, t, cond, uncond): self.main_class.step += 1 return self.before_sample(x, t, cond, uncond) - uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=beforeSample, after_sample=self.after_sample, after_update=self.after_update) + uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc_inst.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None -def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dtData=None): +def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dt_data=None): def expand_dims(v, dims): return v[(...,) + (None,)*(dims - 1)] def get_model_input_time(t_continuous): @@ -107,5 +107,5 @@ def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_k c_in = torch.cat([unconditional_condition, condition]) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) #return noise_uncond + guidance_scale * (noise - noise_uncond) - return dtData.dynthresh(noise, noise_uncond, guidance_scale, None) + return dt_data.dynthresh(noise, noise_uncond, guidance_scale, None) return model_fn diff --git a/scripts/dynamic_thresholding.py b/scripts/dynamic_thresholding.py index 1068976..f3f98e4 100644 --- a/scripts/dynamic_thresholding.py +++ b/scripts/dynamic_thresholding.py @@ -39,8 +39,8 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - def vis_change(isVis): - return {"visible": isVis, "__type__": "update"} + def vis_change(is_vis): + return {"visible": is_vis, "__type__": "update"} # "Dynamic Thresholding (CFG Scale Fix)" dtrue = gr.Checkbox(value=True, visible=False) dfalse = gr.Checkbox(value=False, visible=False) @@ -64,11 +64,11 @@ class Script(scripts.Script): separate_feature_channels = gr.Checkbox(value=True, label="Separate Feature Channels", elem_id='dynthres_separate_feature_channels') scaling_startpoint = gr.Radio(["ZERO", "MEAN"], value="MEAN", label="Scaling Startpoint") variability_measure = gr.Radio(["STD", "AD"], value="AD", label="Variability Measure") - def shouldShowSchedulerValue(cfgMode, mimicMode): - sched_vis = cfgMode in MODES_WITH_VALUE or mimicMode in MODES_WITH_VALUE - return vis_change(sched_vis), vis_change(mimicMode != "Constant"), vis_change(cfgMode != "Constant") - cfg_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min]) - mimic_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min]) + def should_show_scheduler_value(cfg_mode, mimic_mode): + sched_vis = cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE + return vis_change(sched_vis), vis_change(mimic_mode != "Constant"), vis_change(cfg_mode != "Constant") + cfg_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min]) + mimic_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min]) enabled.change( _js="dynthres_update_enabled", fn=lambda x, y: {"visible": x, "__type__": "update"}, @@ -143,19 +143,19 @@ class Script(scripts.Script): # Make a placeholder sampler sampler = sd_samplers.all_samplers_map[orig_sampler_name] - dtData = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi) + dt_data = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi) if orig_sampler_name == "UniPC": - def uniPCConstructor(model): - return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dtData) - newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, uniPCConstructor, sampler.aliases, sampler.options) + def unipc_constructor(model): + return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dt_data) + new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, unipc_constructor, sampler.aliases, sampler.options) else: - def newConstructor(model): + def new_constructor(model): result = sampler.constructor(model) - cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dtData) + cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dt_data) result.model_wrap_cfg = cfg return result - newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, newConstructor, sampler.aliases, sampler.options) - return fixed_sampler_name, newSampler + new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, new_constructor, sampler.aliases, sampler.options) + return fixed_sampler_name, new_sampler # Apply for usage p.orig_sampler_name = orig_sampler_name @@ -163,14 +163,14 @@ class Script(scripts.Script): p.fixed_samplers = [] if orig_latent_sampler_name: - latent_sampler_name, latentSampler = make_sampler(orig_latent_sampler_name) - sd_samplers.all_samplers_map[latent_sampler_name] = latentSampler + latent_sampler_name, latent_sampler = make_sampler(orig_latent_sampler_name) + sd_samplers.all_samplers_map[latent_sampler_name] = latent_sampler p.fixed_samplers.append(latent_sampler_name) p.latent_sampler = latent_sampler_name if orig_sampler_name != orig_latent_sampler_name: - p.sampler_name, newSampler = make_sampler(orig_sampler_name) - sd_samplers.all_samplers_map[p.sampler_name] = newSampler + p.sampler_name, new_sampler = make_sampler(orig_sampler_name) + sd_samplers.all_samplers_map[p.sampler_name] = new_sampler p.fixed_samplers.append(p.sampler_name) else: p.sampler_name = p.latent_sampler @@ -193,16 +193,16 @@ class Script(scripts.Script): ######################### CompVis Implementation logic ######################### class CustomVanillaSDSampler(sd_samplers_compvis.VanillaStableDiffusionSampler): - def __init__(self, constructor, sd_model, dtData): + def __init__(self, constructor, sd_model, dt_data): super().__init__(constructor, sd_model) - self.sampler.main_class = dtData + self.sampler.main_class = dt_data ######################### K-Diffusion Implementation logic ######################### class CustomCFGDenoiser(cfgdenoisekdiff): - def __init__(self, model, dtData): + def __init__(self, model, dt_data): super().__init__(model) - self.main_class = dtData + self.main_class = dt_data def combine_denoised(self, x_out, conds_list, uncond, cond_scale): if isinstance(uncond, dict) and 'crossattn' in uncond: @@ -258,11 +258,11 @@ def make_axis_options(): if not any("[DynThres]" in x.label for x in xyz_grid.axis_options): xyz_grid.axis_options.extend(extra_axis_options) -def callbackBeforeUi(): +def callback_before_ui(): try: make_axis_options() except Exception as e: traceback.print_exc() print(f"Failed to add support for X/Y/Z Plot Script because: {e}") -script_callbacks.on_before_ui(callbackBeforeUi) +script_callbacks.on_before_ui(callback_before_ui)