diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/scripts/dynamic_thresholding.py b/scripts/dynamic_thresholding.py index 8217531..6e61073 100644 --- a/scripts/dynamic_thresholding.py +++ b/scripts/dynamic_thresholding.py @@ -103,32 +103,26 @@ class CustomCFGDenoiser(sd_samplers.CFGDenoiser): def dynthresh(self, cond, uncond, cfgScale, conds_list): mimicScale = self.mimic_scale - match self.mimic_mode: - case "Constant": - pass - case "Linear Down": - mimicScale *= 1.0 - (self.step / self.maxSteps) - case "Cosine Down": - mimicScale *= 1.0 - math.cos(self.step / self.maxSteps) - case "Linear Up": - mimicScale *= self.step / self.maxSteps - pass - case "Cosine Up": - mimicScale *= math.cos(self.step / self.maxSteps) - pass - match self.cfg_mode: - case "Constant": - pass - case "Linear Down": - cfgScale *= 1.0 - (self.step / self.maxSteps) - case "Cosine Down": - cfgScale *= 1.0 - math.cos(self.step / self.maxSteps) - case "Linear Up": - cfgScale *= self.step / self.maxSteps - pass - case "Cosine Up": - cfgScale *= math.cos(self.step / self.maxSteps) - pass + if self.mimic_mode == "Constant": + pass + elif self.mimic_mode == "Linear Down": + mimicScale *= 1.0 - (self.step / self.maxSteps) + elif self.mimic_mode == "Cosine Down": + mimicScale *= 1.0 - math.cos(self.step / self.maxSteps) + elif self.mimic_mode == "Linear Up": + mimicScale *= self.step / self.maxSteps + elif self.mimic_mode == "Cosine Up": + mimicScale *= math.cos(self.step / self.maxSteps) + if self.cfg_mode == "Constant": + pass + elif self.cfg_mode == "Linear Down": + cfgScale *= 1.0 - (self.step / self.maxSteps) + elif self.cfg_mode == "Cosine Down": + cfgScale *= 1.0 - math.cos(self.step / self.maxSteps) + elif self.cfg_mode == "Linear Up": + cfgScale *= self.step / self.maxSteps + elif self.cfg_mode == "Cosine Up": + cfgScale *= math.cos(self.step / self.maxSteps) # uncond shape is (batch, 4, height, width) conds_per_batch = cond.shape[0] / uncond.shape[0] assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"