fix for SDXL

now supports SDXL
main
zahand 2024-03-25 20:10:10 +03:30
parent c83ef85d03
commit f5552f525f
3 changed files with 28 additions and 19 deletions

View File

@@ -39,11 +39,7 @@ Or manually clone this repo into your extensions folder:
![Screenshot of the slider provided by the extension in UI](/assets/screenshot.png "Does what it says on the box.")
After installing, you can locate the new parameter "Negative Prompt Weight" in the extensions area of txt2img and img2img tabs.
## Limitations
Doesn't work with SDXL yet.
After installing, you can find the new parameter "Negative Prompt Weight" in the extensions area of txt2img and img2img tabs.
## More Comparisons and Stuff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

After

Width:  |  Height:  |  Size: 6.7 KiB

View File

@@ -4,7 +4,8 @@ import gradio as gr
import modules.scripts as scripts
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser, remove_current_script_callbacks
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
from modules.prompt_parser import SdConditioning
class Script(scripts.Script):
@@ -65,7 +66,7 @@ class Script(scripts.Script):
# print('NPW callback removed')
if self.weight != 1.0:
self.empty_uncond = self.make_empty_uncond()
self.empty_uncond = self.make_empty_uncond(p.width, p.height)
on_cfg_denoiser(self.denoiser_callback)
# print('NPW callback added')
self.callbacks_added = True
@@ -79,25 +80,37 @@ class Script(scripts.Script):
def postprocess(self, p, processed, *args):
if hasattr(self, 'callbacks_added'):
remove_current_script_callbacks()
delattr(self, 'callbacks_added')
# print('NPW callback removed in post')
def denoiser_callback(self, params):
def concat_and_lerp(empty, tensor, weight):
if tensor.shape[1] > empty.shape[1]:
num_concatenations = tensor.shape[1] // empty.shape[1]
empty_concat = torch.cat([empty] * num_concatenations, dim=1)
if tensor.shape[1] == empty_concat.shape[1] + 1:
# assuming it's controlnet's marks(?)
empty_concat = torch.cat([tensor[:, :1, :], empty_concat], dim=1)
new_tensor = torch.lerp(empty_concat, tensor, weight)
else:
new_tensor = torch.lerp(empty, tensor, weight)
return new_tensor
uncond = params.text_uncond
empty_uncond = self.empty_uncond
is_dict = isinstance(uncond, dict)
if uncond.shape[1] > self.empty_uncond.shape[1]:
num_concatenations = uncond.shape[1] // self.empty_uncond.shape[1]
empty_uncond_concat = torch.cat([self.empty_uncond] * num_concatenations, dim=1)
if uncond.shape[1] == empty_uncond_concat.shape[1] + 1:
# assuming it's controlnet's marks!
empty_uncond_concat = torch.cat([uncond[:, :1, :], empty_uncond_concat], dim=1)
new_uncond = torch.lerp(empty_uncond_concat, uncond, self.weight)
if is_dict:
uncond, cross = uncond['vector'], uncond['crossattn']
empty_uncond, empty_cross = empty_uncond['vector'], empty_uncond['crossattn']
params.text_uncond['vector'] = concat_and_lerp(empty_uncond, uncond, self.weight)
params.text_uncond['crossattn'] = concat_and_lerp(empty_cross, cross, self.weight)
else:
new_uncond = torch.lerp(self.empty_uncond, uncond, self.weight)
params.text_uncond = new_uncond
params.text_uncond = concat_and_lerp(empty_uncond, uncond, self.weight)
def make_empty_uncond(self):
empty_uncond = shared.sd_model.get_learned_conditioning([""])
def make_empty_uncond(self, w, h):
prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h)
empty_uncond = shared.sd_model.get_learned_conditioning(prompt)
return empty_uncond
def print_warning(self, value):