diff --git a/scripts/prompt_generator.py b/scripts/prompt_generator.py index 03d9ac1..f81d5e0 100644 --- a/scripts/prompt_generator.py +++ b/scripts/prompt_generator.py @@ -2,19 +2,21 @@ import gradio as gr import modules from modules import script_callbacks from transformers import GPT2Tokenizer, GPT2LMHeadModel -import os, re +import os +import re result_prompt = "" -def add_to_prompt(num):# A function that determines which prompt to pass - hand_over_prompt_list=result_prompt.splitlines() + +def add_to_prompt(num): # A function that determines which prompt to pass + hand_over_prompt_list = result_prompt.splitlines() try: - return(hand_over_prompt_list[int(num)-1][3:]) + return (hand_over_prompt_list[int(num)-1][3:]) except Exception as e: - print( - f"That line does not exist. Check number of prompts: {e}") - return gr.update(), f"Error: {e}" - + print( + f"That line does not exist. Check number of prompts: {e}") + return gr.update(), f"Error: {e}" + def get_list_blacklist(): # Set the directory you want to start from @@ -62,19 +64,21 @@ def on_ui_tabs(): generateButton = gr.Button( value="Generate", elem_id="generate_button") with gr.Column(visible=False) as results_col: - results = gr.Text(label="Results", elem_id="Results_textBox", interactive=False) + results = gr.Text( + label="Results", elem_id="Results_textBox", interactive=False) with gr.Column(visible=False) as promptNum_col: with gr.Row(): promptNum = gr.Textbox( lines=1, elem_id="promptNum", label="Send which prompt") with gr.Column(): - warning = gr.HTML(value="Select one number and send that prompt to txt2img or img2img", visible=False) + warning = gr.HTML( + value="Select one number and send that prompt to txt2img or img2img", visible=False) with gr.Row(): send_to_txt2img = gr.Button('Send to txt2img', visible=False) send_to_img2img = gr.Button('Send to img2img', visible=False) - - + # Method to create the extended prompt + def generate_longer_prompt(prompt, temperature, top_k, max_length, repetition_penalty, num_return_sequences, use_blacklist=False): try: @@ -84,6 +88,7 @@ def on_ui_tabs(): 'FredZhang7/distilgpt2-stable-diffusion-v2') except Exception as e: print(f"Exception encountered while attempting to install tokenizer") + return gr.update(), f"Error: {e}" try: print(f"Generate new prompt from: \"{prompt}\"") input_ids = tokenizer(prompt, return_tensors='pt').input_ids @@ -91,46 +96,54 @@ def on_ui_tabs(): top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, repetition_penalty=repetition_penalty, - penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True) + penalty_alpha=0.6, no_repeat_ngram_size=1, + early_stopping=True) print("Generation complete!") tempString = "" - if(use_blacklist): - blacklist = get_list_blacklist() + if (use_blacklist): + blacklist = get_list_blacklist() for i in range(len(output)): tempString += str(i+1)+": "+tokenizer.decode( output[i], skip_special_tokens=True) + "\n" - - if(use_blacklist): + + if (use_blacklist): for to_check in blacklist: - tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE) - if(i==0): + tempString = re.sub( + to_check, "", tempString, flags=re.IGNORECASE) + if (i == 0): global result_prompt - + result_prompt = tempString print(result_prompt) return {results: tempString, - send_to_img2img: gr.update(visible = True), - send_to_txt2img: gr.update(visible = True), + send_to_img2img: gr.update(visible=True), + send_to_txt2img: gr.update(visible=True), results_col: gr.update(visible=True), warning: gr.update(visible=True), promptNum_col: gr.update(visible=True) - } + } except Exception as e: print( f"Exception encountered while attempting to generate prompt: {e}") return gr.update(), f"Error: {e}" - + # events generateButton.click(fn=generate_longer_prompt, inputs=[ promptTxt, temp_slider, top_k_slider, max_length_slider, - repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox], - outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning,promptNum_col]) - send_to_img2img.click(add_to_prompt,inputs=[promptNum], outputs=[img2img_prompt]) - send_to_txt2img.click(add_to_prompt,inputs=[promptNum], outputs=[txt2img_prompt]) - send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) - send_to_img2img.click(None, _js="switch_to_img2img", inputs=None, outputs=None) + repetition_penalty_slider, num_return_sequences_slider, + use_blacklist_checkbox], + outputs=[results, send_to_img2img, send_to_txt2img, + results_col, warning, promptNum_col]) + send_to_img2img.click(add_to_prompt, inputs=[ + promptNum], outputs=[img2img_prompt]) + send_to_txt2img.click(add_to_prompt, inputs=[ + promptNum], outputs=[txt2img_prompt]) + send_to_txt2img.click(None, _js='switch_to_txt2img', + inputs=None, outputs=None) + send_to_img2img.click(None, _js="switch_to_img2img", + inputs=None, outputs=None) return (prompt_generator, "Prompt Generator", "Prompt Generator"),