81 lines
3.7 KiB
Python
81 lines
3.7 KiB
Python
import re
|
|
|
|
import gradio as gr
|
|
from modules import script_callbacks
|
|
from modules import generation_parameters_copypaste as params_copypaste
|
|
|
|
import scripts.t2p.prompt_generator as pgen
|
|
|
|
# Module-level cache for the pgen.WDLike instance; stays None until
# generate_prompt() creates and loads it on first use.
wd_like = None
|
|
|
|
# brought from modules/deepbooru.py
|
|
# Matches a single backslash or parenthesis and captures it in group 1,
# so the matched character can be re-emitted with an escaping backslash.
re_special = re.compile(r'([\\()])')
|
|
|
|
def get_conversion(choice: int):
    """Translate a selection index into a ``pgen.ProbablityConversion`` member.

    Args:
        choice: 0 selects ``CUTOFF_AND_POWER``, 1 selects ``SOFTMAX``.

    Returns:
        The matching ``pgen.ProbablityConversion`` member.

    Raises:
        NotImplementedError: If ``choice`` is neither 0 nor 1.
    """
    if choice == 0:
        return pgen.ProbablityConversion.CUTOFF_AND_POWER
    if choice == 1:
        return pgen.ProbablityConversion.SOFTMAX
    # Any other index has no conversion method associated with it.
    raise NotImplementedError()
|
|
|
|
def get_sampling(choice: int):
    """Translate a selection index into a ``pgen.SamplingMethod`` member.

    Args:
        choice: 0 selects ``NONE``, 1 selects ``TOP_K``, 2 selects ``TOP_P``.

    Returns:
        The matching ``pgen.SamplingMethod`` member.

    Raises:
        NotImplementedError: If ``choice`` is not 0, 1 or 2.
    """
    if choice == 0:
        return pgen.SamplingMethod.NONE
    if choice == 1:
        return pgen.SamplingMethod.TOP_K
    if choice == 2:
        return pgen.SamplingMethod.TOP_P
    # Any other index has no sampling method associated with it.
    raise NotImplementedError()
|
|
|
|
def generate_prompt(text: str, conversion: int, power: float, sampling: int, n: int, k: int, p: float, weighted: bool, replace_underscore: bool, excape_brackets: bool):
    """Generate a comma-separated tag prompt for the theme ``text``.

    The ``pgen.WDLike`` model is created and loaded on the first call and then
    kept in the module-level ``wd_like`` for subsequent calls.

    Args:
        text: Input theme handed to the tag model.
        conversion: Index understood by ``get_conversion``.
        power: Forwarded unchanged to ``pgen.GenerationSettings``.
        sampling: Index understood by ``get_sampling``.
        n: Forwarded unchanged to ``pgen.GenerationSettings``.
        k: Forwarded unchanged to ``pgen.GenerationSettings``.
        p: Forwarded unchanged to ``pgen.GenerationSettings``.
        weighted: Forwarded unchanged to ``pgen.GenerationSettings``.
        replace_underscore: When true, every ``_`` in a tag becomes a space.
        excape_brackets: When true, backslashes and parentheses in a tag are
            prefixed with a backslash. (The spelling is part of the signature.)

    Returns:
        The resulting tags joined with ``', '``.

    Raises:
        NotImplementedError: If ``conversion`` or ``sampling`` is not a known
            index.
    """
    global wd_like
    if wd_like is None:
        print('Loading tag data and model')
        # Publish the instance only after load_model() has returned, so a
        # failed load is retried on the next call instead of leaving an
        # unloaded model cached in the global.
        model = pgen.WDLike()
        model.load_model()
        wd_like = model
        print('Loaded')

    settings = pgen.GenerationSettings(
        get_conversion(conversion), power, get_sampling(sampling), n, k, p, weighted
    )
    tags = wd_like(text, settings)

    if replace_underscore:
        tags = [t.replace('_', ' ') for t in tags]
    if excape_brackets:
        # Prefix each backslash/parenthesis with a backslash.
        tags = [re_special.sub(r'\\\1', t) for t in tags]
    return ', '.join(tags)
|
|
|
|
def _clamp_min(value, minimum):
    """Raise ``value`` to at least ``minimum``; pass ``None`` through unchanged."""
    if value is None:
        return None
    return max(minimum, value)


def on_ui_tabs():
    """Build the Text2Prompt tab and wire its controls to ``generate_prompt``.

    Returns:
        A one-element list holding a tuple of the ``gr.Blocks`` object, the
        title string ``"Text2Prompt"`` and the identifier string
        ``"text2prompt_interface"``.
    """
    # NOTE(review): the nesting of the Row/Column contexts below is
    # reconstructed (two sibling columns inside one row) -- confirm against
    # the intended layout.
    with gr.Blocks(analytics_enabled=False) as text2prompt_interface:
        with gr.Row():
            # Left column: theme input, output formatting options and result.
            with gr.Column():
                tb_input = gr.Textbox(label='Input Theme', interactive=True)
                cb_replace_underscore = gr.Checkbox(value=True, label='Replace underscore in tag with whitespace', interactive=True)
                cb_escape_brackets = gr.Checkbox(value=True, label='Escape brackets in tag', interactive=True)
                btn_generate = gr.Button(value='Generate', variant='primary')
                tb_output = gr.Textbox(label='Output', interactive=True)
                with gr.Row():
                    buttons = params_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
                params_copypaste.bind_buttons(buttons, None, tb_output)

            # Right column: settings forwarded to generate_prompt.
            with gr.Column():
                gr.HTML(value='Generation Settings')
                # type='index' makes the radios deliver the position of the
                # selected choice, which get_conversion/get_sampling expect.
                rb_prob_conversion_method = gr.Radio(choices=['Cutoff and Power', 'Softmax'], value='Cutoff and Power', type='index', label='Method to convert similarity into probablity')
                sl_power = gr.Slider(0, 5, value=2, step=0.1, label='Power', interactive=True)
                rb_sampling_method = gr.Radio(choices=['NONE', 'Top-k', 'Top-p (Nucleus)'], value='Top-k', type='index', label='Sampling method')
                nb_max_tag_num = gr.Number(value=20, label='Max number of tags', precision=0, interactive=True)
                nb_k_value = gr.Number(value=50, label='k value', precision=0, interactive=True)
                sl_p_value = gr.Slider(0, 1, label='p value', value=0.1, step=0.01, interactive=True)
                cb_weighted = gr.Checkbox(value=True, label='Use weighted choice', interactive=True)

        # Keep the numeric fields at or above their lower bound. An emptied
        # Number field can arrive as None, which max() cannot compare with an
        # int, so None is handed back untouched instead of raising TypeError.
        nb_max_tag_num.change(
            fn=lambda x: _clamp_min(x, 0),
            inputs=nb_max_tag_num,
            outputs=nb_max_tag_num
        )

        nb_k_value.change(
            fn=lambda x: _clamp_min(x, 1),
            inputs=nb_k_value,
            outputs=nb_k_value
        )

        # The input order must match generate_prompt's positional parameters.
        btn_generate.click(
            fn=generate_prompt,
            inputs=[tb_input, rb_prob_conversion_method, sl_power, rb_sampling_method, nb_max_tag_num, nb_k_value, sl_p_value, cb_weighted, cb_replace_underscore, cb_escape_brackets],
            outputs=tb_output
        )

    return [(text2prompt_interface, "Text2Prompt", "text2prompt_interface")]
|
|
|
|
|
|
# Register the tab builder with the host's on_ui_tabs callback hook.
script_callbacks.on_ui_tabs(on_ui_tabs)