actually just fix every single casing mixup
parent
9fe47f4e7a
commit
83f619ab37
|
|
@ -63,13 +63,13 @@ class DynamicThresholdingSimpleComfyNode:
|
|||
dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, "CONSTANT", 0, "CONSTANT", 0, 0, 0, 999, False, "MEAN", "AD", 1)
|
||||
|
||||
def sampler_dyn_thrash(args):
|
||||
x_out = args["cond"]
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
time_step = args["timestep"]
|
||||
dynamic_thresh.step = 999 - time_step[0]
|
||||
|
||||
return dynamic_thresh.dynthresh(x_out, uncond, cond_scale, None)
|
||||
return dynamic_thresh.dynthresh(cond, uncond, cond_scale, None)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(sampler_dyn_thrash)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class DynThresh:
|
|||
self.variability_measure = variability_measure
|
||||
self.interpolate_phi = interpolate_phi
|
||||
|
||||
def interpretScale(self, scale, mode, min):
|
||||
def interpret_scale(self, scale, mode, min):
|
||||
scale -= min
|
||||
max = self.max_steps - 1
|
||||
frac = self.step / max
|
||||
|
|
@ -56,8 +56,8 @@ class DynThresh:
|
|||
return scale
|
||||
|
||||
def dynthresh(self, cond, uncond, cfg_scale, weights):
|
||||
mimic_scale = self.interpretScale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
|
||||
cfg_scale = self.interpretScale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
|
||||
mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
|
||||
cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
|
||||
# uncond shape is (batch, 4, height, width)
|
||||
conds_per_batch = cond.shape[0] / uncond.shape[0]
|
||||
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
|
||||
|
|
@ -116,13 +116,13 @@ class DynThresh:
|
|||
### Now add it back onto the averages to get into real scale again and return
|
||||
result = cfg_renormalized + cfg_means
|
||||
|
||||
actualRes = result.unflatten(2, mim_target.shape[2:])
|
||||
actual_res = result.unflatten(2, mim_target.shape[2:])
|
||||
|
||||
if self.interpolate_phi != 1.0:
|
||||
actualRes = actualRes * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
|
||||
actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
|
||||
|
||||
if self.experiment_mode == 1:
|
||||
num = actualRes.cpu().numpy()
|
||||
num = actual_res.cpu().numpy()
|
||||
for y in range(0, 64):
|
||||
for x in range (0, 64):
|
||||
if num[0][0][y][x] > 1.0:
|
||||
|
|
@ -131,19 +131,19 @@ class DynThresh:
|
|||
num[0][1][y][x] *= 0.5
|
||||
if num[0][2][y][x] > 1.5:
|
||||
num[0][2][y][x] *= 0.5
|
||||
actualRes = torch.from_numpy(num).to(device=uncond.device)
|
||||
actual_res = torch.from_numpy(num).to(device=uncond.device)
|
||||
elif self.experiment_mode == 2:
|
||||
num = actualRes.cpu().numpy()
|
||||
num = actual_res.cpu().numpy()
|
||||
for y in range(0, 64):
|
||||
for x in range (0, 64):
|
||||
overScale = False
|
||||
over_scale = False
|
||||
for z in range(0, 4):
|
||||
if abs(num[0][z][y][x]) > 1.5:
|
||||
overScale = True
|
||||
if overScale:
|
||||
over_scale = True
|
||||
if over_scale:
|
||||
for z in range(0, 4):
|
||||
num[0][z][y][x] *= 0.7
|
||||
actualRes = torch.from_numpy(num).to(device=uncond.device)
|
||||
actual_res = torch.from_numpy(num).to(device=uncond.device)
|
||||
elif self.experiment_mode == 3:
|
||||
coefs = torch.tensor([
|
||||
# R G B W
|
||||
|
|
@ -152,16 +152,16 @@ class DynThresh:
|
|||
[-0.158, 0.189, 0.264, 0.0], # L3
|
||||
[-0.184, -0.271, -0.473, 1.0], # L4
|
||||
], device=uncond.device)
|
||||
resRGB = torch.einsum("laxy,ab -> lbxy", actualRes, coefs)
|
||||
maxR, maxG, maxB, maxW = resRGB[0][0].max(), resRGB[0][1].max(), resRGB[0][2].max(), resRGB[0][3].max()
|
||||
maxRGB = max(maxR, maxG, maxB)
|
||||
print(f"test max = r={maxR}, g={maxG}, b={maxB}, w={maxW}, rgb={maxRGB}")
|
||||
res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
|
||||
max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
|
||||
max_rgb = max(max_r, max_g, max_b)
|
||||
print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
|
||||
if self.step / (self.max_steps - 1) > 0.2:
|
||||
if maxRGB < 2.0 and maxW < 3.0:
|
||||
resRGB /= maxRGB / 2.4
|
||||
if max_rgb < 2.0 and max_w < 3.0:
|
||||
res_rgb /= max_rgb / 2.4
|
||||
else:
|
||||
if maxRGB > 2.4 and maxW > 3.0:
|
||||
resRGB /= maxRGB / 2.4
|
||||
actualRes = torch.einsum("laxy,ab -> lbxy", resRGB, coefs.inverse())
|
||||
if max_rgb > 2.4 and max_w > 3.0:
|
||||
res_rgb /= max_rgb / 2.4
|
||||
actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())
|
||||
|
||||
return actualRes
|
||||
return actual_res
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ except Exception as e:
|
|||
# (It has hooks but not in useful locations)
|
||||
# I stripped the original comments for brevity.
|
||||
# Some never-used code (scheduler modes, noise modes, guidance modes) have been removed as well for brevity.
|
||||
# The actual impl comes down to just the last line in particular, and the `beforeSample` insert to track step count.
|
||||
# The actual impl comes down to just the last line in particular, and the `before_sample` insert to track step count.
|
||||
|
||||
class CustomUniPCSampler(uni_pc.sampler.UniPCSampler):
|
||||
def __init__(self, model, **kwargs):
|
||||
|
|
@ -50,16 +50,16 @@ class CustomUniPCSampler(uni_pc.sampler.UniPCSampler):
|
|||
img = x_T
|
||||
ns = uni_pc.uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
model_type = "v" if self.model.parameterization == "v" else "noise"
|
||||
model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dtData=self.main_class)
|
||||
model_fn = CustomUniPC_model_wrapper(lambda x, t, c: self.model.apply_model(x, t, c), ns, model_type=model_type, guidance_scale=unconditional_guidance_scale, dt_data=self.main_class)
|
||||
self.main_class.step = 0
|
||||
def beforeSample(x, t, cond, uncond):
|
||||
def before_sample(x, t, cond, uncond):
|
||||
self.main_class.step += 1
|
||||
return self.before_sample(x, t, cond, uncond)
|
||||
uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=beforeSample, after_sample=self.after_sample, after_update=self.after_update)
|
||||
uni_pc_inst = uni_pc.uni_pc.UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
||||
x = uni_pc_inst.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
return x.to(device), None
|
||||
|
||||
def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dtData=None):
|
||||
def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={}, guidance_scale=1.0, dt_data=None):
|
||||
def expand_dims(v, dims):
|
||||
return v[(...,) + (None,)*(dims - 1)]
|
||||
def get_model_input_time(t_continuous):
|
||||
|
|
@ -107,5 +107,5 @@ def CustomUniPC_model_wrapper(model, noise_schedule, model_type="noise", model_k
|
|||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
#return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
return dtData.dynthresh(noise, noise_uncond, guidance_scale, None)
|
||||
return dt_data.dynthresh(noise, noise_uncond, guidance_scale, None)
|
||||
return model_fn
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ class Script(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def vis_change(isVis):
|
||||
return {"visible": isVis, "__type__": "update"}
|
||||
def vis_change(is_vis):
|
||||
return {"visible": is_vis, "__type__": "update"}
|
||||
# "Dynamic Thresholding (CFG Scale Fix)"
|
||||
dtrue = gr.Checkbox(value=True, visible=False)
|
||||
dfalse = gr.Checkbox(value=False, visible=False)
|
||||
|
|
@ -64,11 +64,11 @@ class Script(scripts.Script):
|
|||
separate_feature_channels = gr.Checkbox(value=True, label="Separate Feature Channels", elem_id='dynthres_separate_feature_channels')
|
||||
scaling_startpoint = gr.Radio(["ZERO", "MEAN"], value="MEAN", label="Scaling Startpoint")
|
||||
variability_measure = gr.Radio(["STD", "AD"], value="AD", label="Variability Measure")
|
||||
def shouldShowSchedulerValue(cfgMode, mimicMode):
|
||||
sched_vis = cfgMode in MODES_WITH_VALUE or mimicMode in MODES_WITH_VALUE
|
||||
return vis_change(sched_vis), vis_change(mimicMode != "Constant"), vis_change(cfgMode != "Constant")
|
||||
cfg_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
||||
mimic_mode.change(shouldShowSchedulerValue, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
||||
def should_show_scheduler_value(cfg_mode, mimic_mode):
|
||||
sched_vis = cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE
|
||||
return vis_change(sched_vis), vis_change(mimic_mode != "Constant"), vis_change(cfg_mode != "Constant")
|
||||
cfg_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
||||
mimic_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
||||
enabled.change(
|
||||
_js="dynthres_update_enabled",
|
||||
fn=lambda x, y: {"visible": x, "__type__": "update"},
|
||||
|
|
@ -143,19 +143,19 @@ class Script(scripts.Script):
|
|||
|
||||
# Make a placeholder sampler
|
||||
sampler = sd_samplers.all_samplers_map[orig_sampler_name]
|
||||
dtData = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
|
||||
dt_data = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
|
||||
if orig_sampler_name == "UniPC":
|
||||
def uniPCConstructor(model):
|
||||
return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dtData)
|
||||
newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, uniPCConstructor, sampler.aliases, sampler.options)
|
||||
def unipc_constructor(model):
|
||||
return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dt_data)
|
||||
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, unipc_constructor, sampler.aliases, sampler.options)
|
||||
else:
|
||||
def newConstructor(model):
|
||||
def new_constructor(model):
|
||||
result = sampler.constructor(model)
|
||||
cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dtData)
|
||||
cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dt_data)
|
||||
result.model_wrap_cfg = cfg
|
||||
return result
|
||||
newSampler = sd_samplers_common.SamplerData(fixed_sampler_name, newConstructor, sampler.aliases, sampler.options)
|
||||
return fixed_sampler_name, newSampler
|
||||
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, new_constructor, sampler.aliases, sampler.options)
|
||||
return fixed_sampler_name, new_sampler
|
||||
|
||||
# Apply for usage
|
||||
p.orig_sampler_name = orig_sampler_name
|
||||
|
|
@ -163,14 +163,14 @@ class Script(scripts.Script):
|
|||
p.fixed_samplers = []
|
||||
|
||||
if orig_latent_sampler_name:
|
||||
latent_sampler_name, latentSampler = make_sampler(orig_latent_sampler_name)
|
||||
sd_samplers.all_samplers_map[latent_sampler_name] = latentSampler
|
||||
latent_sampler_name, latent_sampler = make_sampler(orig_latent_sampler_name)
|
||||
sd_samplers.all_samplers_map[latent_sampler_name] = latent_sampler
|
||||
p.fixed_samplers.append(latent_sampler_name)
|
||||
p.latent_sampler = latent_sampler_name
|
||||
|
||||
if orig_sampler_name != orig_latent_sampler_name:
|
||||
p.sampler_name, newSampler = make_sampler(orig_sampler_name)
|
||||
sd_samplers.all_samplers_map[p.sampler_name] = newSampler
|
||||
p.sampler_name, new_sampler = make_sampler(orig_sampler_name)
|
||||
sd_samplers.all_samplers_map[p.sampler_name] = new_sampler
|
||||
p.fixed_samplers.append(p.sampler_name)
|
||||
else:
|
||||
p.sampler_name = p.latent_sampler
|
||||
|
|
@ -193,16 +193,16 @@ class Script(scripts.Script):
|
|||
######################### CompVis Implementation logic #########################
|
||||
|
||||
class CustomVanillaSDSampler(sd_samplers_compvis.VanillaStableDiffusionSampler):
|
||||
def __init__(self, constructor, sd_model, dtData):
|
||||
def __init__(self, constructor, sd_model, dt_data):
|
||||
super().__init__(constructor, sd_model)
|
||||
self.sampler.main_class = dtData
|
||||
self.sampler.main_class = dt_data
|
||||
|
||||
######################### K-Diffusion Implementation logic #########################
|
||||
|
||||
class CustomCFGDenoiser(cfgdenoisekdiff):
|
||||
def __init__(self, model, dtData):
|
||||
def __init__(self, model, dt_data):
|
||||
super().__init__(model)
|
||||
self.main_class = dtData
|
||||
self.main_class = dt_data
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
if isinstance(uncond, dict) and 'crossattn' in uncond:
|
||||
|
|
@ -258,11 +258,11 @@ def make_axis_options():
|
|||
if not any("[DynThres]" in x.label for x in xyz_grid.axis_options):
|
||||
xyz_grid.axis_options.extend(extra_axis_options)
|
||||
|
||||
def callbackBeforeUi():
|
||||
def callback_before_ui():
|
||||
try:
|
||||
make_axis_options()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Failed to add support for X/Y/Z Plot Script because: {e}")
|
||||
|
||||
script_callbacks.on_before_ui(callbackBeforeUi)
|
||||
script_callbacks.on_before_ui(callback_before_ui)
|
||||
|
|
|
|||
Loading…
Reference in New Issue