diff --git a/blacklist.txt b/blacklist.txt new file mode 100644 index 0000000..e69de29 diff --git a/scripts/prompt_generator.py b/scripts/prompt_generator.py index 2cfccc2..f43371d 100644 --- a/scripts/prompt_generator.py +++ b/scripts/prompt_generator.py @@ -2,12 +2,25 @@ import gradio as gr import modules from modules import script_callbacks from transformers import GPT2Tokenizer, GPT2LMHeadModel +import os, re first_gen = "" def add_to_prompt(): return first_gen +def get_list_blacklist(): + # Set the directory you want to start from + file_path = './extensions/prompt-maker/blacklist.txt' + things_to_black_list = [] + with open(file_path, 'r') as f: + # Read each line in the file and append it to the list + for line in f: + things_to_black_list.append(line) + + return things_to_black_list + + def on_ui_tabs(): # structure @@ -18,7 +31,7 @@ def on_ui_tabs(): with gr.Column(): with gr.Row(): promptTxt = gr.Textbox( - lines=2, elem_id="promptTxt", label="Start of the prompt") + lines=2, elem_id="promptTxt", label="Start of the prompt", tooltip="Hello?") with gr.Column(): with gr.Row(): temp_slider = gr.Slider( @@ -33,6 +46,10 @@ def on_ui_tabs(): elem_id="repetition_penalty_slider", label="Repetition Penalty", value=1.2, minimum=0, maximum=10, interactive=True) num_return_sequences_slider = gr.Slider( elem_id="num_return_sequences_slider", label="How Many To Generate", value=5, minimum=1, maximum=20, interactive=True, step=1) + with gr.Column(): + with gr.Row(): + use_blacklist_checkbox = gr.Checkbox(label="Use blacklist?") + gr.HTML(value="
Using \".\extensions\prompt-maker\\blacklist.txt\".
It will delete any matches to the generated result (case insensitive).
") with gr.Column(): with gr.Row(): generateButton = gr.Button( @@ -46,9 +63,9 @@ def on_ui_tabs(): send_to_img2img = gr.Button('Send to img2img', visible=False) - # events + # Method to create the extended prompt def generate_longer_prompt(prompt, temperature, top_k, - max_length, repetition_penalty, num_return_sequences): + max_length, repetition_penalty, num_return_sequences, use_blacklist=False): try: tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2') tokenizer.add_special_tokens({'pad_token': '[PAD]'}) @@ -66,15 +83,19 @@ def on_ui_tabs(): penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True) print("Generation complete!") tempString = "" - + if(use_blacklist): + blacklist = get_list_blacklist() for i in range(len(output)): tempString += tokenizer.decode( output[i], skip_special_tokens=True) + "\n" + if(use_blacklist): + for to_check in blacklist: + tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE) if(i==0): global first_gen first_gen = tempString - + return {results: tempString, send_to_img2img: gr.update(visible = True), send_to_txt2img: gr.update(visible = True), @@ -85,9 +106,11 @@ def on_ui_tabs(): print( f"Exception encountered while attempting to generate prompt: {e}") return gr.update(), f"Error: {e}" + + # events generateButton.click(fn=generate_longer_prompt, inputs=[ promptTxt, temp_slider, top_k_slider, max_length_slider, - repetition_penalty_slider, num_return_sequences_slider], + repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox], outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning]) send_to_img2img.click(add_to_prompt, outputs=[img2img_prompt]) send_to_txt2img.click(add_to_prompt, outputs=[txt2img_prompt])