Added ability to use a blacklist

pull/3/head
Imrayya 2023-01-07 19:04:21 +01:00
parent 448f1a3d89
commit 961c71b26d
2 changed files with 29 additions and 6 deletions

0
blacklist.txt Normal file
View File

View File

@ -2,12 +2,25 @@ import gradio as gr
import modules import modules
from modules import script_callbacks from modules import script_callbacks
from transformers import GPT2Tokenizer, GPT2LMHeadModel from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os, re
first_gen = "" first_gen = ""
def add_to_prompt(): def add_to_prompt():
return first_gen return first_gen
def get_list_blacklist():
# Set the directory you want to start from
file_path = './extensions/prompt-maker/blacklist.txt'
things_to_black_list = []
with open(file_path, 'r') as f:
# Read each line in the file and append it to the list
for line in f:
things_to_black_list.append(line)
return things_to_black_list
def on_ui_tabs(): def on_ui_tabs():
# structure # structure
@ -18,7 +31,7 @@ def on_ui_tabs():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
promptTxt = gr.Textbox( promptTxt = gr.Textbox(
lines=2, elem_id="promptTxt", label="Start of the prompt") lines=2, elem_id="promptTxt", label="Start of the prompt", tooltip="Hello?")
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
temp_slider = gr.Slider( temp_slider = gr.Slider(
@ -33,6 +46,10 @@ def on_ui_tabs():
elem_id="repetition_penalty_slider", label="Repetition Penalty", value=1.2, minimum=0, maximum=10, interactive=True) elem_id="repetition_penalty_slider", label="Repetition Penalty", value=1.2, minimum=0, maximum=10, interactive=True)
num_return_sequences_slider = gr.Slider( num_return_sequences_slider = gr.Slider(
elem_id="num_return_sequences_slider", label="How Many To Generate", value=5, minimum=1, maximum=20, interactive=True, step=1) elem_id="num_return_sequences_slider", label="How Many To Generate", value=5, minimum=1, maximum=20, interactive=True, step=1)
with gr.Column():
with gr.Row():
use_blacklist_checkbox = gr.Checkbox(label="Use blacklist?")
gr.HTML(value="<center>Using <code>\".\extensions\prompt-maker\\blacklist.txt</code>\".<br>It will delete any matches to the generated result (case insensitive).</center>")
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
generateButton = gr.Button( generateButton = gr.Button(
@ -46,9 +63,9 @@ def on_ui_tabs():
send_to_img2img = gr.Button('Send to img2img', visible=False) send_to_img2img = gr.Button('Send to img2img', visible=False)
# events # Method to create the extended prompt
def generate_longer_prompt(prompt, temperature, top_k, def generate_longer_prompt(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences): max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
try: try:
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2') tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) tokenizer.add_special_tokens({'pad_token': '[PAD]'})
@ -66,11 +83,15 @@ def on_ui_tabs():
penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True) penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
print("Generation complete!") print("Generation complete!")
tempString = "" tempString = ""
if(use_blacklist):
blacklist = get_list_blacklist()
for i in range(len(output)): for i in range(len(output)):
tempString += tokenizer.decode( tempString += tokenizer.decode(
output[i], skip_special_tokens=True) + "\n" output[i], skip_special_tokens=True) + "\n"
if(use_blacklist):
for to_check in blacklist:
tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE)
if(i==0): if(i==0):
global first_gen global first_gen
first_gen = tempString first_gen = tempString
@ -85,9 +106,11 @@ def on_ui_tabs():
print( print(
f"Exception encountered while attempting to generate prompt: {e}") f"Exception encountered while attempting to generate prompt: {e}")
return gr.update(), f"Error: {e}" return gr.update(), f"Error: {e}"
# events
generateButton.click(fn=generate_longer_prompt, inputs=[ generateButton.click(fn=generate_longer_prompt, inputs=[
promptTxt, temp_slider, top_k_slider, max_length_slider, promptTxt, temp_slider, top_k_slider, max_length_slider,
repetition_penalty_slider, num_return_sequences_slider], repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox],
outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning]) outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning])
send_to_img2img.click(add_to_prompt, outputs=[img2img_prompt]) send_to_img2img.click(add_to_prompt, outputs=[img2img_prompt])
send_to_txt2img.click(add_to_prompt, outputs=[txt2img_prompt]) send_to_txt2img.click(add_to_prompt, outputs=[txt2img_prompt])