diff --git a/dynthres_core.py b/dynthres_core.py index 723a635..11b951c 100644 --- a/dynthres_core.py +++ b/dynthres_core.py @@ -55,9 +55,9 @@ class DynThresh: scale += min return scale - def dynthresh(self, cond, uncond, cfgScale, weights): - mimicScale = self.interpretScale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min) - cfgScale = self.interpretScale(cfgScale, self.cfg_mode, self.cfg_scale_min) + def dynthresh(self, cond, uncond, cfg_scale, weights): + mimic_scale = self.interpretScale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min) + cfg_scale = self.interpretScale(cfg_scale, self.cfg_mode, self.cfg_scale_min) # uncond shape is (batch, 4, height, width) conds_per_batch = cond.shape[0] / uncond.shape[0] assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches" @@ -70,8 +70,8 @@ class DynThresh: relative = diff.sum(1) ### Get the normal result for both mimic and normal scale - mim_target = uncond + relative * mimicScale - cfg_target = uncond + relative * cfgScale + mim_target = uncond + relative * mimic_scale + cfg_target = uncond + relative * cfg_scale ### If we weren't doing mimic scale, we'd just return cfg_target here ### Now recenter the values relative to their average rather than absolute, to allow scaling from average