unipc warning
parent
c72e0b3a7f
commit
67357e3525
|
|
@ -2,6 +2,8 @@ from lib_neutral_prompt import hijacker, global_state, prompt_parser
|
|||
from modules import script_callbacks, sd_samplers
|
||||
import functools
|
||||
import torch
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
|
||||
def combine_denoised_hijack(x_out, batch_cond_indices, noisy_uncond, cond_scale, original_function):
|
||||
|
|
@ -77,7 +79,10 @@ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
|
|||
@sd_samplers_hijacker.hijack('create_sampler')
|
||||
def create_sampler_hijack(name, model, original_function):
|
||||
sampler = original_function(name, model)
|
||||
if not global_state.is_enabled:
|
||||
if name in ('DDIM', 'PLMS', 'UniPC'):
|
||||
if global_state.is_enabled:
|
||||
warn_unsupported_sampler()
|
||||
|
||||
return sampler
|
||||
|
||||
sampler.model_wrap_cfg.combine_denoised = functools.partial(
|
||||
|
|
@ -85,3 +90,15 @@ def create_sampler_hijack(name, model, original_function):
|
|||
original_function=sampler.model_wrap_cfg.combine_denoised
|
||||
)
|
||||
return sampler
|
||||
|
||||
|
||||
def warn_unsupported_sampler():
|
||||
if not global_state.verbose:
|
||||
return
|
||||
|
||||
print(textwrap.dedent('''
|
||||
[sd-webui-neutral-prompt extension]
|
||||
Neutral prompt relies on composition via AND, which the webui does not support when using any of the DDIM, PLMS and UniPC samplers
|
||||
The sampler will NOT be patched
|
||||
Falling back on original sampler implementation...
|
||||
'''), file=sys.stderr)
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ from typing import List
|
|||
is_enabled: bool = False
|
||||
perp_profile: List[List[str]] = []
|
||||
cfg_rescale: float = 0.
|
||||
verbose: bool = True
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import functools
|
||||
|
||||
|
||||
class ModuleHijacker:
|
||||
def __init__(self, module):
|
||||
self.__module = module
|
||||
|
|
@ -21,7 +22,6 @@ class ModuleHijacker:
|
|||
|
||||
self.__original_functions.clear()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
|
||||
if not hasattr(module, hijacker_attribute):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from lib_neutral_prompt import hijacker, global_state
|
||||
from modules import script_callbacks, prompt_parser
|
||||
from enum import Enum
|
||||
import torch
|
||||
import re
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from lib_neutral_prompt import global_state, prompt_parser
|
||||
from modules import script_callbacks
|
||||
from modules import script_callbacks, shared
|
||||
from typing import Dict
|
||||
import gradio as gr
|
||||
import dataclasses
|
||||
|
|
@ -10,9 +10,8 @@ img2img_prompt_textbox = None
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GradioUserInterface:
|
||||
class AccordionInterface:
|
||||
def __post_init__(self):
|
||||
self.is_enabled = gr.Checkbox(label='Enable', value=False)
|
||||
self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
|
||||
self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
|
||||
self.neutral_cond_scale = gr.Slider(label='Neutral CFG', minimum=-3, maximum=3, value=-1)
|
||||
|
|
@ -20,7 +19,6 @@ class GradioUserInterface:
|
|||
|
||||
def arrange_components(self, is_img2img: bool):
|
||||
with gr.Accordion(label='Neutral Prompt', open=False):
|
||||
self.is_enabled.render()
|
||||
self.cfg_rescale.render()
|
||||
|
||||
with gr.Accordion(label='Prompt formatter', open=False):
|
||||
|
|
@ -38,17 +36,26 @@ class GradioUserInterface:
|
|||
|
||||
def get_components(self):
|
||||
return (
|
||||
self.is_enabled,
|
||||
self.cfg_rescale,
|
||||
)
|
||||
|
||||
def unpack_processing_args(self, is_enabled: bool, cfg_rescale: float) -> Dict:
|
||||
def unpack_processing_args(
|
||||
self,
|
||||
cfg_rescale: float,
|
||||
) -> Dict:
|
||||
return {
|
||||
'is_enabled': is_enabled,
|
||||
'cfg_rescale': cfg_rescale,
|
||||
}
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
section = ('neutral_prompt', 'Neutral Prompt')
|
||||
shared.opts.add_option('neutral_prompt_enabled', shared.OptionInfo(True, 'Enabled', section=section))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
||||
|
||||
def on_after_component(component, **_kwargs):
|
||||
if getattr(component, 'elem_id', None) == 'txt2img_prompt':
|
||||
global txt2img_prompt_textbox
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ importlib.reload(hijacker)
|
|||
importlib.reload(prompt_parser)
|
||||
importlib.reload(cfg_denoiser)
|
||||
importlib.reload(ui)
|
||||
from modules import scripts, processing
|
||||
from modules import scripts, processing, shared
|
||||
|
||||
|
||||
class NeutralPromptScript(scripts.Script):
|
||||
def __init__(self):
|
||||
self.gui = ui.GradioUserInterface()
|
||||
self.accordion_interface = ui.AccordionInterface()
|
||||
|
||||
def title(self) -> str:
|
||||
return "Neutral Prompt"
|
||||
|
|
@ -19,12 +19,16 @@ class NeutralPromptScript(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img: bool):
|
||||
self.gui.arrange_components(is_img2img)
|
||||
self.gui.connect_events(is_img2img)
|
||||
return self.gui.get_components()
|
||||
self.accordion_interface.arrange_components(is_img2img)
|
||||
self.accordion_interface.connect_events(is_img2img)
|
||||
return self.accordion_interface.get_components()
|
||||
|
||||
def process(self, p: processing.StableDiffusionProcessing, *args):
|
||||
for k, v in self.gui.unpack_processing_args(*args).items():
|
||||
if shared.state.job_no > 0:
|
||||
return
|
||||
|
||||
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
||||
for k, v in self.accordion_interface.unpack_processing_args(*args).items():
|
||||
try:
|
||||
getattr(global_state, k)
|
||||
except AttributeError:
|
||||
|
|
|
|||
Loading…
Reference in New Issue