diff --git a/README.md b/README.md index af6e5b1..9f6c94b 100644 --- a/README.md +++ b/README.md @@ -39,11 +39,7 @@ Or manually clone this repo into your extensions folder: ![Screenshot of the slider provided by the extension in UI](/assets/screenshot.png "Does what it says on the box.") -After installing, you can locate the new parameter "Negative Prompt Weight" in the extentions area of txt2img and img2img tabs. - -## Limitations - -Doesn't work with SDXL yet. +After installing, you can find the new parameter "Negative Prompt Weight" in the extentions area of txt2img and img2img tabs. ## More Comparisons and Stuff diff --git a/assets/screenshot.png b/assets/screenshot.png index 68f20fc..85f1055 100644 Binary files a/assets/screenshot.png and b/assets/screenshot.png differ diff --git a/scripts/npw.py b/scripts/npw.py index d388536..19c315a 100644 --- a/scripts/npw.py +++ b/scripts/npw.py @@ -4,7 +4,8 @@ import gradio as gr import modules.scripts as scripts import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser, remove_current_script_callbacks +from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks +from modules.prompt_parser import SdConditioning class Script(scripts.Script): @@ -65,7 +66,7 @@ class Script(scripts.Script): # print('NPW callback removed') if self.weight != 1.0: - self.empty_uncond = self.make_empty_uncond() + self.empty_uncond = self.make_empty_uncond(p.width, p.height) on_cfg_denoiser(self.denoiser_callback) # print('NPW callback added') self.callbacks_added = True @@ -79,25 +80,37 @@ class Script(scripts.Script): def postprocess(self, p, processed, *args): if hasattr(self, 'callbacks_added'): remove_current_script_callbacks() + delattr(self, 'callbacks_added') # print('NPW callback removed in post') def denoiser_callback(self, params): + def concat_and_lerp(empty, tensor, weight): + if tensor.shape[1] > empty.shape[1]: + num_concatenations = tensor.shape[1] // empty.shape[1] + empty_concat = torch.cat([empty] * num_concatenations, dim=1) + if tensor.shape[1] == empty_concat.shape[1] + 1: + # assuming it's controlnet's marks(?) + empty_concat = torch.cat([tensor[:, :1, :], empty_concat], dim=1) + new_tensor = torch.lerp(empty_concat, tensor, weight) + else: + new_tensor = torch.lerp(empty, tensor, weight) + return new_tensor + uncond = params.text_uncond + empty_uncond = self.empty_uncond + is_dict = isinstance(uncond, dict) - if uncond.shape[1] > self.empty_uncond.shape[1]: - num_concatenations = uncond.shape[1] // self.empty_uncond.shape[1] - empty_uncond_concat = torch.cat([self.empty_uncond] * num_concatenations, dim=1) - if uncond.shape[1] == empty_uncond_concat.shape[1] + 1: - # assuming it's controlnet's marks! - empty_uncond_concat = torch.cat([uncond[:, :1, :], empty_uncond_concat], dim=1) - new_uncond = torch.lerp(empty_uncond_concat, uncond, self.weight) + if is_dict: + uncond, cross = uncond['vector'], uncond['crossattn'] + empty_uncond, empty_cross = empty_uncond['vector'], empty_uncond['crossattn'] + params.text_uncond['vector'] = concat_and_lerp(empty_uncond, uncond, self.weight) + params.text_uncond['crossattn'] = concat_and_lerp(empty_cross, cross, self.weight) else: - new_uncond = torch.lerp(self.empty_uncond, uncond, self.weight) - - params.text_uncond = new_uncond + params.text_uncond = concat_and_lerp(empty_uncond, uncond, self.weight) - def make_empty_uncond(self): - empty_uncond = shared.sd_model.get_learned_conditioning([""]) + def make_empty_uncond(self, w, h): + prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h) + empty_uncond = shared.sd_model.get_learned_conditioning(prompt) return empty_uncond def print_warning(self, value):