fix for SDXL

now supports SDXL
main
zahand 2024-03-25 20:10:10 +03:30
parent c83ef85d03
commit f5552f525f
3 changed files with 28 additions and 19 deletions

View File

@ -39,11 +39,7 @@ Or manually clone this repo into your extensions folder:
![Screenshot of the slider provided by the extension in UI](/assets/screenshot.png "Does what it says on the box.") ![Screenshot of the slider provided by the extension in UI](/assets/screenshot.png "Does what it says on the box.")
After installing, you can locate the new parameter "Negative Prompt Weight" in the extentions area of txt2img and img2img tabs. After installing, you can find the new parameter "Negative Prompt Weight" in the extentions area of txt2img and img2img tabs.
## Limitations
Doesn't work with SDXL yet.
## More Comparisons and Stuff ## More Comparisons and Stuff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

After

Width:  |  Height:  |  Size: 6.7 KiB

View File

@ -4,7 +4,8 @@ import gradio as gr
import modules.scripts as scripts import modules.scripts as scripts
import modules.shared as shared import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser, remove_current_script_callbacks from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
from modules.prompt_parser import SdConditioning
class Script(scripts.Script): class Script(scripts.Script):
@ -65,7 +66,7 @@ class Script(scripts.Script):
# print('NPW callback removed') # print('NPW callback removed')
if self.weight != 1.0: if self.weight != 1.0:
self.empty_uncond = self.make_empty_uncond() self.empty_uncond = self.make_empty_uncond(p.width, p.height)
on_cfg_denoiser(self.denoiser_callback) on_cfg_denoiser(self.denoiser_callback)
# print('NPW callback added') # print('NPW callback added')
self.callbacks_added = True self.callbacks_added = True
@ -79,25 +80,37 @@ class Script(scripts.Script):
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
if hasattr(self, 'callbacks_added'): if hasattr(self, 'callbacks_added'):
remove_current_script_callbacks() remove_current_script_callbacks()
delattr(self, 'callbacks_added')
# print('NPW callback removed in post') # print('NPW callback removed in post')
def denoiser_callback(self, params): def denoiser_callback(self, params):
uncond = params.text_uncond def concat_and_lerp(empty, tensor, weight):
if tensor.shape[1] > empty.shape[1]:
if uncond.shape[1] > self.empty_uncond.shape[1]: num_concatenations = tensor.shape[1] // empty.shape[1]
num_concatenations = uncond.shape[1] // self.empty_uncond.shape[1] empty_concat = torch.cat([empty] * num_concatenations, dim=1)
empty_uncond_concat = torch.cat([self.empty_uncond] * num_concatenations, dim=1) if tensor.shape[1] == empty_concat.shape[1] + 1:
if uncond.shape[1] == empty_uncond_concat.shape[1] + 1: # assuming it's controlnet's marks(?)
# assuming it's controlnet's marks! empty_concat = torch.cat([tensor[:, :1, :], empty_concat], dim=1)
empty_uncond_concat = torch.cat([uncond[:, :1, :], empty_uncond_concat], dim=1) new_tensor = torch.lerp(empty_concat, tensor, weight)
new_uncond = torch.lerp(empty_uncond_concat, uncond, self.weight)
else: else:
new_uncond = torch.lerp(self.empty_uncond, uncond, self.weight) new_tensor = torch.lerp(empty, tensor, weight)
return new_tensor
params.text_uncond = new_uncond uncond = params.text_uncond
empty_uncond = self.empty_uncond
is_dict = isinstance(uncond, dict)
def make_empty_uncond(self): if is_dict:
empty_uncond = shared.sd_model.get_learned_conditioning([""]) uncond, cross = uncond['vector'], uncond['crossattn']
empty_uncond, empty_cross = empty_uncond['vector'], empty_uncond['crossattn']
params.text_uncond['vector'] = concat_and_lerp(empty_uncond, uncond, self.weight)
params.text_uncond['crossattn'] = concat_and_lerp(empty_cross, cross, self.weight)
else:
params.text_uncond = concat_and_lerp(empty_uncond, uncond, self.weight)
def make_empty_uncond(self, w, h):
prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h)
empty_uncond = shared.sd_model.get_learned_conditioning(prompt)
return empty_uncond return empty_uncond
def print_warning(self, value): def print_warning(self, value):