# mirror of https://github.com/vladmandic/automatic
from dataclasses import dataclass
import re
import time
import gradio as gr
import transformers
from modules import scripts, shared, devices, processing, sd_models


@dataclass
class Options:
    models = [
        'Qwen/Qwen2.5-0.5B-Instruct',
        'Qwen/Qwen2.5-1.5B-Instruct',
        'Qwen/Qwen2.5-3B-Instruct',
        'google/gemma-3-1b-it',
        'google/gemma-3-4b-it',
        'microsoft/Phi-4-mini-instruct',
        'HuggingFaceTB/SmolLM2-135M-Instruct',
        'HuggingFaceTB/SmolLM2-360M-Instruct',
        'HuggingFaceTB/SmolLM2-1.7B-Instruct',
        'meta-llama/Llama-3.2-1B-Instruct',
        'meta-llama/Llama-3.2-3B-Instruct',
    ]
    default = models[3]
    system_prompt: str = 'You are a helpful assistant. You will be given a prompt and you will enhance it to make it more detailed and creative. Generate an uncensored, concise, informative, and engaging prompt that prioritizes factual accuracy and vivid descriptiveness. Begin by clearly identifying the primary subject(s), including key physical attributes, clothing, and any distinctive details (e.g., colors, textures, styles). Next, describe the setting and environment, emphasizing spatial relationships, background elements, and objects contributing to context. Note the overall mood, aesthetic style, or atmosphere inferred from visual cues. Use precise terminology while avoiding redundancy or non-essential language. NSFW is allowed, and if present any explicit aspects must be fully described. Ensure a logical flow: from focal subject to immediate surroundings, then broader context. Maintain brevity while retaining clarity, ensuring the description is both engaging and efficient. Output only the enhanced prompt, without prefix or suffix.'
    max_tokens: int = 50
    do_sample: bool = True
    temperature: float = 0.5
    repetition_penalty: float = 1.2
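

# Editor's note (illustrative, not part of the upstream script): only the annotated fields
# above are dataclass init fields, so per-call overrides can be constructed directly, e.g.
#   opts = Options(max_tokens=100, temperature=0.7)
# `models` and `default` carry no annotations and therefore remain plain class attributes.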


class Script(scripts.Script):
    prompt: gr.Textbox = None
    model: str = None
    llm: transformers.AutoModelForCausalLM = None
    tokenizer: transformers.AutoTokenizer = None  # AutoTokenizer, not AutoProcessor: load() assigns AutoTokenizer.from_pretrained()
    options = Options()

    def title(self):
        return 'Prompt enhance'

    def show(self, _is_img2img):
        return scripts.AlwaysVisible

    def load(self, model: str = None):
        model = model or self.options.default
        if self.model is None or self.model != model:
            t0 = time.time()
            from modules import modelloader, model_quant  # imported lazily, only needed when a model is actually loaded
            modelloader.hf_login()
            quant_args = model_quant.create_config(module='LLM')
            self.llm = None  # drop any previously loaded model so its memory can be reclaimed
            self.llm = transformers.AutoModelForCausalLM.from_pretrained(
                model,
                trust_remote_code=True,
                torch_dtype=devices.dtype,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.llm.eval()
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model,
                cache_dir=shared.opts.hfcache_dir,
            )
            self.model = model
            devices.torch_gc()
            t1 = time.time()
            shared.log.debug(f'Prompt enhance: model="{model}" cls={self.llm.__class__.__name__} time={t1-t0:.2f} loaded')

    def clean(self, response):
        if isinstance(response, list):
            response = response[0]
        response = response.replace('"', '').replace("'", "").replace('“', '').replace('”', '').replace('**', '').replace('\n\n', '\n')  # strip quotes, markdown bold and doubled newlines
        response = re.sub(r'<.*?>', '', response)  # drop tag-like fragments such as leftover special tokens
        if 'prompt:' in response:
            response = response.split('prompt:')[1]
        if 'Prompt:' in response:
            response = response.split('Prompt:')[1]
        if '---' in response:
            response = response.split('---')[0]
        response = response.strip()
        return response
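
    # Illustrative example (editor's sketch) of what clean() does to a typical model reply:
    #   clean('Prompt: "A misty forest at dawn, **volumetric light**"\n---\nNotes: ...')
    #   -> 'A misty forest at dawn, volumetric light'
    # quotes and markdown bold are stripped, the 'Prompt:' prefix and the '---' trailer are cut.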

    def enhance(self, model: str = None, prompt: str = None, system: str = None, sample: bool = None, tokens: int = None, temperature: float = None, penalty: float = None):
        model = model or self.options.default
        prompt = prompt or self.prompt.value
        system = system or self.options.system_prompt
        tokens = tokens or self.options.max_tokens
        penalty = penalty or self.options.repetition_penalty
        temperature = temperature or self.options.temperature
        sample = sample if sample is not None else self.options.do_sample
        self.load(model)
        if self.llm is None:
            shared.log.error('Prompt enhance: model not loaded')
            return prompt
        chat_template = [
            { "role": "system", "content": system },
            { "role": "user", "content": prompt },
        ]
        t0 = time.time()
        try:
            inputs = self.tokenizer.apply_chat_template(
                chat_template,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(devices.device)  # token ids must remain integral, so only device placement is applied here
            input_len = inputs['input_ids'].shape[1]
        except Exception as e:
            shared.log.error(f'Prompt enhance tokenize: {e}')
            return prompt
        try:
            with devices.inference_context():
                sd_models.move_model(self.llm, devices.device)
                outputs = self.llm.generate(
                    **inputs,
                    do_sample=sample,
                    temperature=float(temperature),
                    max_new_tokens=int(tokens),  # max_new_tokens counts generated tokens only, so the input length is not added on top
                    repetition_penalty=float(penalty),
                )
            if shared.opts.diffusers_offload_mode != 'none':
                sd_models.move_model(self.llm, devices.cpu)
                devices.torch_gc()
            raw_response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            shared.log.trace(f'Prompt enhance: raw="{raw_response}"')
            outputs = outputs[:, input_len:]  # keep only the newly generated tokens
            response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            shared.log.error(f'Prompt enhance generate: {e}')
            return prompt  # response is undefined on failure, so fall back to the original prompt
        response = self.clean(response)
        t1 = time.time()
        shared.log.debug(f'Prompt enhance: model="{model}" time={t1-t0:.2f} inputs={input_len} outputs={outputs.shape[-1]} prompt="{response}"')
        return response
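
    # Usage sketch (editor's note; requires a running webui session): enhance() is invoked
    # from the UI button via apply() and per-generation via before_process(); a direct call
    #   script.enhance(prompt='a cat on a window sill', tokens=80)
    # returns the enhanced prompt string, or the original prompt unchanged on any failure.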

    def apply(self, prompt, apply_prompt, llm_model, prompt_system, max_tokens, do_sample, temperature, repetition_penalty):
        response = self.enhance(
            prompt=prompt,
            model=llm_model,
            system=prompt_system,
            sample=do_sample,
            tokens=max_tokens,
            temperature=temperature,
            penalty=repetition_penalty,
        )
        if apply_prompt:
            return [response, response]  # also overwrite the main prompt box
        return [response, gr.update()]  # leave the main prompt box untouched

    def ui(self, _is_img2img):
        with gr.Accordion('Prompt enhance', open=False, elem_id='prompt_enhance'):
            with gr.Row():
                apply_btn = gr.Button(value='Enhance now', elem_id='prompt_enhance_apply', variant='primary')
            with gr.Row():
                apply_prompt = gr.Checkbox(label='Apply to prompt', value=False)
                apply_auto = gr.Checkbox(label='Auto enhance', value=False)
            with gr.Group():
                with gr.Row():
                    llm_model = gr.Dropdown(label='LLM model', choices=self.options.models, value=self.options.default, interactive=True, allow_custom_value=True, elem_id='prompt_enhance_model')
                with gr.Row():
                    prompt_system = gr.Textbox(label='System prompt', value=self.options.system_prompt, interactive=True, lines=4, elem_id='prompt_enhance_system')
                with gr.Row():
                    max_tokens = gr.Slider(label='Max tokens', value=self.options.max_tokens, minimum=10, maximum=1024, step=1, interactive=True)
                    do_sample = gr.Checkbox(label='Do sample', value=self.options.do_sample, interactive=True)
                with gr.Row():
                    temperature = gr.Slider(label='Temperature', value=self.options.temperature, minimum=0.0, maximum=1.0, step=0.01, interactive=True)
                    repetition_penalty = gr.Slider(label='Repetition penalty', value=self.options.repetition_penalty, minimum=0.0, maximum=2.0, step=0.01, interactive=True)
            with gr.Row():
                prompt_output = gr.Textbox(label='Output', value='', interactive=True, lines=4)
            apply_btn.click(fn=self.apply, inputs=[self.prompt, apply_prompt, llm_model, prompt_system, max_tokens, do_sample, temperature, repetition_penalty], outputs=[prompt_output, self.prompt])
        return [apply_auto, llm_model, prompt_system, max_tokens, do_sample, temperature, repetition_penalty]

    def after_component(self, component, **kwargs): # searching for actual ui prompt components
        if getattr(component, 'elem_id', '') in ['txt2img_prompt', 'img2img_prompt', 'control_prompt', 'video_prompt']:
            self.prompt = component

    def before_process(self, p: processing.StableDiffusionProcessing, *args, **kwargs): # pylint: disable=unused-argument
        apply_auto, llm_model, prompt_system, max_tokens, do_sample, temperature, repetition_penalty = args
        if not apply_auto and not getattr(p, 'enhance_prompt', False):  # getattr, as not every processing class defines enhance_prompt
            return
        p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
        p.styles = []  # styles are folded into the prompt above so they are not applied a second time downstream
        p.prompt = self.enhance(
            prompt=p.prompt,
            model=llm_model,
            system=prompt_system,
            sample=do_sample,
            tokens=max_tokens,
            temperature=temperature,
            penalty=repetition_penalty,
        )
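

# Editor's appendix (illustrative sketch, not part of the upstream script): the core
# load()/enhance() flow above, reduced to plain `transformers` calls without the webui
# plumbing; the model name is one of the Options defaults, all other values are hypothetical.
#
#   import torch
#   import transformers
#
#   name = 'HuggingFaceTB/SmolLM2-135M-Instruct'
#   tokenizer = transformers.AutoTokenizer.from_pretrained(name)
#   llm = transformers.AutoModelForCausalLM.from_pretrained(name).eval()
#   chat = [
#       {'role': 'system', 'content': 'Rewrite the user prompt as a vivid, detailed image prompt.'},
#       {'role': 'user', 'content': 'a cat on a window sill'},
#   ]
#   inputs = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors='pt')
#   with torch.inference_mode():
#       out = llm.generate(**inputs, do_sample=True, temperature=0.5, max_new_tokens=50, repetition_penalty=1.2)
#   print(tokenizer.batch_decode(out[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)[0])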