Added ability to use a blacklist
parent
448f1a3d89
commit
961c71b26d
|
|
@ -2,12 +2,25 @@ import gradio as gr
|
||||||
import modules
|
import modules
|
||||||
from modules import script_callbacks
|
from modules import script_callbacks
|
||||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
import os, re
|
||||||
|
|
||||||
first_gen = ""
|
first_gen = ""
|
||||||
|
|
||||||
def add_to_prompt():
|
def add_to_prompt():
|
||||||
return first_gen
|
return first_gen
|
||||||
|
|
||||||
|
def get_list_blacklist():
|
||||||
|
# Set the directory you want to start from
|
||||||
|
file_path = './extensions/prompt-maker/blacklist.txt'
|
||||||
|
things_to_black_list = []
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
# Read each line in the file and append it to the list
|
||||||
|
for line in f:
|
||||||
|
things_to_black_list.append(line)
|
||||||
|
|
||||||
|
return things_to_black_list
|
||||||
|
|
||||||
|
|
||||||
def on_ui_tabs():
|
def on_ui_tabs():
|
||||||
# structure
|
# structure
|
||||||
|
|
||||||
|
|
@ -18,7 +31,7 @@ def on_ui_tabs():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
promptTxt = gr.Textbox(
|
promptTxt = gr.Textbox(
|
||||||
lines=2, elem_id="promptTxt", label="Start of the prompt")
|
lines=2, elem_id="promptTxt", label="Start of the prompt", tooltip="Hello?")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
temp_slider = gr.Slider(
|
temp_slider = gr.Slider(
|
||||||
|
|
@ -33,6 +46,10 @@ def on_ui_tabs():
|
||||||
elem_id="repetition_penalty_slider", label="Repetition Penalty", value=1.2, minimum=0, maximum=10, interactive=True)
|
elem_id="repetition_penalty_slider", label="Repetition Penalty", value=1.2, minimum=0, maximum=10, interactive=True)
|
||||||
num_return_sequences_slider = gr.Slider(
|
num_return_sequences_slider = gr.Slider(
|
||||||
elem_id="num_return_sequences_slider", label="How Many To Generate", value=5, minimum=1, maximum=20, interactive=True, step=1)
|
elem_id="num_return_sequences_slider", label="How Many To Generate", value=5, minimum=1, maximum=20, interactive=True, step=1)
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
use_blacklist_checkbox = gr.Checkbox(label="Use blacklist?")
|
||||||
|
gr.HTML(value="<center>Using <code>\".\extensions\prompt-maker\\blacklist.txt</code>\".<br>It will delete any matches to the generated result (case insensitive).</center>")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
generateButton = gr.Button(
|
generateButton = gr.Button(
|
||||||
|
|
@ -46,9 +63,9 @@ def on_ui_tabs():
|
||||||
send_to_img2img = gr.Button('Send to img2img', visible=False)
|
send_to_img2img = gr.Button('Send to img2img', visible=False)
|
||||||
|
|
||||||
|
|
||||||
# events
|
# Method to create the extended prompt
|
||||||
def generate_longer_prompt(prompt, temperature, top_k,
|
def generate_longer_prompt(prompt, temperature, top_k,
|
||||||
max_length, repetition_penalty, num_return_sequences):
|
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
|
||||||
try:
|
try:
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||||
|
|
@ -66,15 +83,19 @@ def on_ui_tabs():
|
||||||
penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
|
penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
|
||||||
print("Generation complete!")
|
print("Generation complete!")
|
||||||
tempString = ""
|
tempString = ""
|
||||||
|
if(use_blacklist):
|
||||||
|
blacklist = get_list_blacklist()
|
||||||
for i in range(len(output)):
|
for i in range(len(output)):
|
||||||
|
|
||||||
tempString += tokenizer.decode(
|
tempString += tokenizer.decode(
|
||||||
output[i], skip_special_tokens=True) + "\n"
|
output[i], skip_special_tokens=True) + "\n"
|
||||||
|
if(use_blacklist):
|
||||||
|
for to_check in blacklist:
|
||||||
|
tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE)
|
||||||
if(i==0):
|
if(i==0):
|
||||||
global first_gen
|
global first_gen
|
||||||
first_gen = tempString
|
first_gen = tempString
|
||||||
|
|
||||||
return {results: tempString,
|
return {results: tempString,
|
||||||
send_to_img2img: gr.update(visible = True),
|
send_to_img2img: gr.update(visible = True),
|
||||||
send_to_txt2img: gr.update(visible = True),
|
send_to_txt2img: gr.update(visible = True),
|
||||||
|
|
@ -85,9 +106,11 @@ def on_ui_tabs():
|
||||||
print(
|
print(
|
||||||
f"Exception encountered while attempting to generate prompt: {e}")
|
f"Exception encountered while attempting to generate prompt: {e}")
|
||||||
return gr.update(), f"Error: {e}"
|
return gr.update(), f"Error: {e}"
|
||||||
|
|
||||||
|
# events
|
||||||
generateButton.click(fn=generate_longer_prompt, inputs=[
|
generateButton.click(fn=generate_longer_prompt, inputs=[
|
||||||
promptTxt, temp_slider, top_k_slider, max_length_slider,
|
promptTxt, temp_slider, top_k_slider, max_length_slider,
|
||||||
repetition_penalty_slider, num_return_sequences_slider],
|
repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox],
|
||||||
outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning])
|
outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning])
|
||||||
send_to_img2img.click(add_to_prompt, outputs=[img2img_prompt])
|
send_to_img2img.click(add_to_prompt, outputs=[img2img_prompt])
|
||||||
send_to_txt2img.click(add_to_prompt, outputs=[txt2img_prompt])
|
send_to_txt2img.click(add_to_prompt, outputs=[txt2img_prompt])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue