diff --git a/scripts/npw.py b/scripts/npw.py index 16b6b5f..67066c2 100644 --- a/scripts/npw.py +++ b/scripts/npw.py @@ -56,6 +56,8 @@ class Script(scripts.Script): weight = getattr(p, 'NPW_weight', weight) if weight != 1 : self.print_warning(weight) + self.width = p.width + self.height = p.height self.weight = weight self.empty_uncond = None @@ -66,7 +68,7 @@ class Script(scripts.Script): # print('NPW callback removed') if self.weight != 1.0: - self.empty_uncond = self.make_empty_uncond(p.width, p.height) + self.empty_uncond = self.make_empty_uncond(self.width, self.height) on_cfg_denoiser(self.denoiser_callback) # print('NPW callback added') self.callbacks_added = True @@ -98,9 +100,11 @@ class Script(scripts.Script): new_tensor = torch.lerp(empty, tensor, weight) return new_tensor - uncond = params.text_uncond - empty_uncond = self.empty_uncond + uncond = params.text_uncond is_dict = isinstance(uncond, dict) + if type(self.empty_uncond) != type(uncond): + self.empty_uncond = self.make_empty_uncond(self.width, self.height) + empty_uncond = self.empty_uncond if is_dict: uncond, cross = uncond['vector'], uncond['crossattn']