unipc warning

pull/7/head
ljleb 2023-05-20 19:12:49 -04:00
parent c72e0b3a7f
commit 67357e3525
6 changed files with 44 additions and 16 deletions

View File

@ -2,6 +2,8 @@ from lib_neutral_prompt import hijacker, global_state, prompt_parser
from modules import script_callbacks, sd_samplers
import functools
import torch
import sys
import textwrap
def combine_denoised_hijack(x_out, batch_cond_indices, noisy_uncond, cond_scale, original_function):
@ -77,7 +79,10 @@ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name, model, original_function):
sampler = original_function(name, model)
if not global_state.is_enabled:
if name in ('DDIM', 'PLMS', 'UniPC'):
if global_state.is_enabled:
warn_unsupported_sampler()
return sampler
sampler.model_wrap_cfg.combine_denoised = functools.partial(
@ -85,3 +90,15 @@ def create_sampler_hijack(name, model, original_function):
original_function=sampler.model_wrap_cfg.combine_denoised
)
return sampler
def warn_unsupported_sampler():
if not global_state.verbose:
return
print(textwrap.dedent('''
[sd-webui-neutral-prompt extension]
Neutral prompt relies on composition via AND, which the webui does not support when using any of the DDIM, PLMS and UniPC samplers
The sampler will NOT be patched
Falling back on original sampler implementation...
'''), file=sys.stderr)

View File

@ -4,3 +4,4 @@ from typing import List
is_enabled: bool = False
perp_profile: List[List[str]] = []
cfg_rescale: float = 0.
verbose: bool = True

View File

@ -1,5 +1,6 @@
import functools
class ModuleHijacker:
def __init__(self, module):
self.__module = module
@ -21,7 +22,6 @@ class ModuleHijacker:
self.__original_functions.clear()
@staticmethod
def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
if not hasattr(module, hijacker_attribute):

View File

@ -1,7 +1,6 @@
from lib_neutral_prompt import hijacker, global_state
from modules import script_callbacks, prompt_parser
from enum import Enum
import torch
import re

View File

@ -1,5 +1,5 @@
from lib_neutral_prompt import global_state, prompt_parser
from modules import script_callbacks
from modules import script_callbacks, shared
from typing import Dict
import gradio as gr
import dataclasses
@ -10,9 +10,8 @@ img2img_prompt_textbox = None
@dataclasses.dataclass
class GradioUserInterface:
class AccordionInterface:
def __post_init__(self):
self.is_enabled = gr.Checkbox(label='Enable', value=False)
self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
self.neutral_cond_scale = gr.Slider(label='Neutral CFG', minimum=-3, maximum=3, value=-1)
@ -20,7 +19,6 @@ class GradioUserInterface:
def arrange_components(self, is_img2img: bool):
with gr.Accordion(label='Neutral Prompt', open=False):
self.is_enabled.render()
self.cfg_rescale.render()
with gr.Accordion(label='Prompt formatter', open=False):
@ -38,17 +36,26 @@ class GradioUserInterface:
def get_components(self):
return (
self.is_enabled,
self.cfg_rescale,
)
def unpack_processing_args(self, is_enabled: bool, cfg_rescale: float) -> Dict:
def unpack_processing_args(
self,
cfg_rescale: float,
) -> Dict:
return {
'is_enabled': is_enabled,
'cfg_rescale': cfg_rescale,
}
def on_ui_settings():
section = ('neutral_prompt', 'Neutral Prompt')
shared.opts.add_option('neutral_prompt_enabled', shared.OptionInfo(True, 'Enabled', section=section))
script_callbacks.on_ui_settings(on_ui_settings)
def on_after_component(component, **_kwargs):
if getattr(component, 'elem_id', None) == 'txt2img_prompt':
global txt2img_prompt_textbox

View File

@ -5,12 +5,12 @@ importlib.reload(hijacker)
importlib.reload(prompt_parser)
importlib.reload(cfg_denoiser)
importlib.reload(ui)
from modules import scripts, processing
from modules import scripts, processing, shared
class NeutralPromptScript(scripts.Script):
def __init__(self):
self.gui = ui.GradioUserInterface()
self.accordion_interface = ui.AccordionInterface()
def title(self) -> str:
return "Neutral Prompt"
@ -19,12 +19,16 @@ class NeutralPromptScript(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img: bool):
self.gui.arrange_components(is_img2img)
self.gui.connect_events(is_img2img)
return self.gui.get_components()
self.accordion_interface.arrange_components(is_img2img)
self.accordion_interface.connect_events(is_img2img)
return self.accordion_interface.get_components()
def process(self, p: processing.StableDiffusionProcessing, *args):
for k, v in self.gui.unpack_processing_args(*args).items():
if shared.state.job_no > 0:
return
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
for k, v in self.accordion_interface.unpack_processing_args(*args).items():
try:
getattr(global_state, k)
except AttributeError: