feat(prompt): add vision/reasoning symbols and vision toggle

Add visual capability indicators and user control for image input:
- Vision symbol (eye icon) for VL-capable models in dropdown
- Reasoning symbol (lightbulb) for thinking-capable models
- "Use vision" checkbox to control image input for enhancement
- Toggle dims and unchecks when non-VL model selected
- Vision auto-enables when switching to VL model

Also:
- Rename "Do sample" to "Use samplers" for consistency with VQA
- Add tooltips/hints for all prompt enhance UI elements
- Add CSS styling for dimmed vision toggle appearance
pull/4544/head
CalamitousFelicitousness 2025-12-10 21:00:08 +00:00
parent c1f9646eaa
commit a77173881f
2 changed files with 99 additions and 16 deletions

View File

@ -241,6 +241,21 @@
{"id":"","label":"Sort by","localized":"","reload":"","hint":"Sort by"},
{"id":"","label":"Nudenet","localized":"","reload":"","hint":"Flexible extension that can detect and obfuscate nudity in images"},
{"id":"","label":"Prompt enhance","localized":"","reload":"","hint":"Extension that can use different LLMs to rewrite prompt for improved results"},
{"id":"","label":"Enhance now","localized":"","reload":"","hint":"Run prompt enhancement using the selected LLM model"},
{"id":"","label":"Apply to prompt","localized":"","reload":"","hint":"Automatically copy enhanced result to the prompt input box"},
{"id":"","label":"Auto enhance","localized":"","reload":"","hint":"Automatically enhance prompt before every image generation"},
{"id":"","label":"Use vision","localized":"","reload":"","hint":"Include input image when enhancing prompt.<br>Only available for vision-capable models, marked with \uf06e icon."},
{"id":"","label":"LLM model","localized":"","reload":"","hint":"Select the language model to use for prompt enhancement.<br>Models supporting vision are marked with \uf06e icon.<br>Models supporting thinking mode are marked with \uf0eb icon."},
{"id":"","label":"Model repo","localized":"","reload":"","hint":"HuggingFace repository ID for the model"},
{"id":"","label":"Model gguf","localized":"","reload":"","hint":"Optional GGUF quantized model repository on HuggingFace"},
{"id":"","label":"Model type","localized":"","reload":"","hint":"Optional GGUF model quantization type"},
{"id":"","label":"Model file","localized":"","reload":"","hint":"Optional specific GGUF model file inside the repository"},
{"id":"","label":"Load custom model","localized":"","reload":"","hint":"Load a custom model with the specified configuration"},
{"id":"","label":"NSFW allowed","localized":"","reload":"","hint":"Allow the model to generate adult content in enhanced prompts"},
{"id":"","label":"Prompt prefix","localized":"","reload":"","hint":"Text prepended to the enhanced prompt result.<br>Useful for adding consistent style tags or quality modifiers at the start."},
{"id":"","label":"Prompt suffix","localized":"","reload":"","hint":"Text appended to the enhanced prompt result.<br>Useful for adding quality tags like 'masterpiece, best quality' or artist names."},
{"id":"","label":"Enhanced prompt","localized":"","reload":"","hint":"The enhanced prompt output from the LLM"},
{"id":"","label":"Set prompt","localized":"","reload":"","hint":"Copy the enhanced prompt to the main prompt input"},
{"id":"","label":"Manage extensions","localized":"","reload":"","hint":"Manage extensions"},
{"id":"","label":"Manual install","localized":"","reload":"","hint":"Manually install extension"},
{"id":"","label":"Extension GIT repository URL","localized":"","reload":"","hint":"Specify extension repository URL on GitHub"},

View File

@ -9,7 +9,7 @@ import torch
import transformers
import gradio as gr
from PIL import Image
from modules import scripts_manager, shared, devices, errors, processing, sd_models, sd_modules, timer
from modules import scripts_manager, shared, devices, errors, processing, sd_models, sd_modules, timer, ui_symbols
debug_enabled = os.environ.get('SD_LLM_DEBUG', None) is not None
@ -28,6 +28,40 @@ def b64(image):
return encoded
def is_vision_model(model_name: str) -> bool:
    """Report whether `model_name` is registered as accepting image input.

    A model counts as vision-capable when its name is listed in
    `Options.img2img`. Empty or missing names are never vision-capable.
    """
    return bool(model_name) and model_name in Options.img2img
def is_thinking_model(model_name: str) -> bool:
    """Report whether `model_name` looks like a thinking/reasoning model.

    Detection is purely name-based: the name must contain the substring
    'thinking', compared case-insensitively. Empty or missing names yield
    False.
    """
    if model_name:
        lowered = model_name.lower()
        return 'thinking' in lowered
    return False
def get_model_display_name(model_repo: str) -> str:
    """Build the dropdown label for a model repo.

    The repo name is followed by the vision symbol when the model is
    vision-capable and by the reasoning symbol when it is a thinking model,
    each separated by a single space. A model with neither capability is
    returned unchanged.
    """
    # Use the shared predicates rather than re-testing Options.img2img here,
    # so the label can never disagree with the capability checks.
    capabilities = (
        (is_vision_model(model_repo), ui_symbols.vision),
        (is_thinking_model(model_repo), ui_symbols.reasoning),
    )
    symbols = [symbol for enabled, symbol in capabilities if enabled]
    if not symbols:
        return model_repo
    return f"{model_repo} {' '.join(symbols)}"
def get_model_repo_from_display(display_name: str) -> str:
    """Recover the plain repo name from a dropdown display name.

    Every occurrence of the vision and reasoning symbols is removed and the
    remainder is stripped of surrounding whitespace. Falsy input (None or an
    empty string) is handed back untouched.
    """
    if not display_name:
        return display_name
    without_vision = display_name.replace(ui_symbols.vision, '')
    without_symbols = without_vision.replace(ui_symbols.reasoning, '')
    return without_symbols.strip()
@dataclass
class Options:
img2img = [
@ -104,6 +138,16 @@ class Options:
repetition_penalty: float = 1.2
thinking_mode: bool = False
@staticmethod
def get_model_choices():
    """Return the dropdown choices: one display name per key of `Options.models`."""
    return list(map(get_model_display_name, Options.models.keys()))
@staticmethod
def get_default_display():
    """Return the display name (repo plus capability symbols) of `Options.default`."""
    default_repo = Options.default
    return get_model_display_name(default_repo)
class Script(scripts_manager.Script):
prompt: gr.Textbox = None
@ -127,7 +171,8 @@ class Script(scripts_manager.Script):
self.llm = compile_torch(self.llm)
def load(self, name:str=None, model_repo:str=None, model_gguf:str=None, model_type:str=None, model_file:str=None):
name = name or self.options.default
# Strip symbols from display name if present
name = get_model_repo_from_display(name) if name else self.options.default
if self.busy:
shared.log.debug('Prompt enhance: busy')
return
@ -275,10 +320,15 @@ class Script(scripts_manager.Script):
filtered = re.sub(pattern, '', prompt)
return filtered, matches
def enhance(self, model: str=None, prompt:str=None, system:str=None, prefix:str=None, suffix:str=None, sample:bool=None, tokens:int=None, temperature:float=None, penalty:float=None, thinking:bool=False, seed:int=-1, image=None, nsfw:bool=None):
model = model or self.options.default
def enhance(self, model: str=None, prompt:str=None, system:str=None, prefix:str=None, suffix:str=None, sample:bool=None, tokens:int=None, temperature:float=None, penalty:float=None, thinking:bool=False, seed:int=-1, image=None, nsfw:bool=None, use_vision:bool=True):
# Strip symbols from model name if present
model = get_model_repo_from_display(model) if model else self.options.default
prompt = prompt or (self.prompt.value if self.prompt else "") # Check if self.prompt is None
image = image or self.image
# Handle vision toggle - if disabled or non-VL model, don't use image
if use_vision and is_vision_model(model):
image = image or self.image
else:
image = None
prefix = prefix or ''
suffix = suffix or ''
tokens = tokens or self.options.max_tokens
@ -452,7 +502,7 @@ class Script(scripts_manager.Script):
return prompt # Return original full prompt on censorship
return response
def apply(self, prompt, image, apply_prompt, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode): # Added nsfw_mode
def apply(self, prompt, image, apply_prompt, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode, use_vision):
response = self.enhance(
prompt=prompt,
image=image,
@ -465,30 +515,45 @@ class Script(scripts_manager.Script):
temperature=temperature,
penalty=repetition_penalty,
thinking=thinking_mode,
nsfw=nsfw_mode # Pass nsfw_mode here
nsfw=nsfw_mode,
use_vision=use_vision,
)
if apply_prompt:
return [response, response]
return [response, gr.update()]
def get_custom(self, name):
    """Return `[repo, gguf, type, file]` for the model chosen in the dropdown.

    `name` may be a display name carrying capability symbols, so it is
    reduced to the plain repo name before `self.options.models` is consulted.
    When the model is not a known entry (or the entry has no 'repo'), the
    plain name itself is returned as the repo and the other fields are None.
    """
    # Look up by the symbol-stripped name; the decorated display name is
    # never a key of the models table.
    repo_name = get_model_repo_from_display(name)
    entry = self.options.models.get(repo_name, {})
    model_repo = entry.get('repo', None) or repo_name
    model_gguf = entry.get('gguf', None)
    model_type = entry.get('type', None)
    model_file = entry.get('file', None)
    return [model_repo, model_gguf, model_type, model_file]
def update_vision_toggle(self, model_name):
    """Produce the update for the 'Use vision' checkbox after a model change.

    The selected name is reduced to its plain repo name and checked for
    vision support. A vision-capable model makes the checkbox interactive
    and checked; any other model makes it non-interactive and unchecked.
    """
    supports_vision = is_vision_model(get_model_repo_from_display(model_name))
    return gr.update(interactive=supports_vision, value=supports_vision)
def ui(self, _is_img2img):
with gr.Accordion('Prompt enhance', open=False, elem_id='prompt_enhance'):
gr.HTML('<style>#prompt_enhance_use_vision:has(input:disabled) { opacity: 0.5; }</style>')
with gr.Row():
apply_btn = gr.Button(value='Enhance now', elem_id='prompt_enhance_apply', variant='primary')
with gr.Row():
apply_prompt = gr.Checkbox(label='Apply to prompt', value=False)
apply_auto = gr.Checkbox(label='Auto enhance', value=False)
with gr.Row():
# Set initial state based on whether default model supports vision
default_is_vl = is_vision_model(Options.default)
use_vision = gr.Checkbox(label='Use vision', value=default_is_vl, interactive=default_is_vl, elem_id='prompt_enhance_use_vision')
gr.HTML('<br>')
with gr.Group():
with gr.Row():
llm_model = gr.Dropdown(label='LLM model', choices=list(self.options.models), value=self.options.default, interactive=True, allow_custom_value=True, elem_id='prompt_enhance_model')
llm_model = gr.Dropdown(label='LLM model', choices=Options.get_model_choices(), value=Options.get_default_display(), interactive=True, allow_custom_value=True, elem_id='prompt_enhance_model')
with gr.Row():
load_btn = gr.Button(value='Load model', elem_id='prompt_enhance_load', variant='secondary')
load_btn.click(fn=self.load, inputs=[llm_model], outputs=[])
@ -511,7 +576,7 @@ class Script(scripts_manager.Script):
with gr.Accordion('Options', open=False, elem_id='prompt_enhance_options'):
with gr.Row():
max_tokens = gr.Slider(label='Max tokens', value=self.options.max_tokens, minimum=10, maximum=1024, step=1, interactive=True)
do_sample = gr.Checkbox(label='Do sample', value=self.options.do_sample, interactive=True)
do_sample = gr.Checkbox(label='Use samplers', value=self.options.do_sample, interactive=True)
with gr.Row():
temperature = gr.Slider(label='Temperature', value=self.options.temperature, minimum=0.0, maximum=1.0, step=0.01, interactive=True)
repetition_penalty = gr.Slider(label='Repetition penalty', value=self.options.repetition_penalty, minimum=0.0, maximum=2.0, step=0.01, interactive=True)
@ -536,8 +601,10 @@ class Script(scripts_manager.Script):
copy_btn.click(fn=lambda x: x, inputs=[prompt_output], outputs=[self.prompt])
if self.image is None:
self.image = gr.Image(type='pil', interactive=False, visible=False, width=64, height=64) # dummy image
apply_btn.click(fn=self.apply, inputs=[self.prompt, self.image, apply_prompt, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode], outputs=[prompt_output, self.prompt])
return [self.prompt, self.image, apply_auto, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode]
# Update vision toggle interactivity when model changes
llm_model.change(fn=self.update_vision_toggle, inputs=[llm_model], outputs=[use_vision], show_progress=False)
apply_btn.click(fn=self.apply, inputs=[self.prompt, self.image, apply_prompt, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode, use_vision], outputs=[prompt_output, self.prompt])
return [self.prompt, self.image, apply_auto, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode, use_vision]
def after_component(self, component, **kwargs): # searching for actual ui prompt components
if getattr(component, 'elem_id', '') in ['txt2img_prompt', 'img2img_prompt', 'control_prompt', 'video_prompt']:
@ -548,7 +615,7 @@ class Script(scripts_manager.Script):
self.image.use_original = True
def before_process(self, p: processing.StableDiffusionProcessing, *args, **kwargs): # pylint: disable=unused-argument
_self_prompt, self_image, apply_auto, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode = args
_self_prompt, self_image, apply_auto, llm_model, prompt_system, prompt_prefix, prompt_suffix, max_tokens, do_sample, temperature, repetition_penalty, thinking_mode, nsfw_mode, use_vision = args
if not apply_auto and not p.enhance_prompt:
return
if shared.state.skipped or shared.state.interrupted:
@ -572,6 +639,7 @@ class Script(scripts_manager.Script):
penalty=repetition_penalty,
thinking=thinking_mode,
nsfw=nsfw_mode,
use_vision=use_vision,
)
timer.process.record('prompt')
p.extra_generation_params['LLM'] = llm_model