Slightly better error handling. Bit of formatting
parent
aeaeabd04f
commit
db93f06a21
|
|
@ -2,18 +2,20 @@ import gradio as gr
|
|||
import modules
|
||||
from modules import script_callbacks
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import os, re
|
||||
import os
|
||||
import re
|
||||
|
||||
result_prompt = ""
|
||||
|
||||
def add_to_prompt(num):# A function that determines which prompt to pass
|
||||
hand_over_prompt_list=result_prompt.splitlines()
|
||||
|
||||
def add_to_prompt(num): # A function that determines which prompt to pass
|
||||
hand_over_prompt_list = result_prompt.splitlines()
|
||||
try:
|
||||
return(hand_over_prompt_list[int(num)-1][3:])
|
||||
return (hand_over_prompt_list[int(num)-1][3:])
|
||||
except Exception as e:
|
||||
print(
|
||||
f"That line does not exist. Check number of prompts: {e}")
|
||||
return gr.update(), f"Error: {e}"
|
||||
print(
|
||||
f"That line does not exist. Check number of prompts: {e}")
|
||||
return gr.update(), f"Error: {e}"
|
||||
|
||||
|
||||
def get_list_blacklist():
|
||||
|
|
@ -62,19 +64,21 @@ def on_ui_tabs():
|
|||
generateButton = gr.Button(
|
||||
value="Generate", elem_id="generate_button")
|
||||
with gr.Column(visible=False) as results_col:
|
||||
results = gr.Text(label="Results", elem_id="Results_textBox", interactive=False)
|
||||
results = gr.Text(
|
||||
label="Results", elem_id="Results_textBox", interactive=False)
|
||||
with gr.Column(visible=False) as promptNum_col:
|
||||
with gr.Row():
|
||||
promptNum = gr.Textbox(
|
||||
lines=1, elem_id="promptNum", label="Send which prompt")
|
||||
with gr.Column():
|
||||
warning = gr.HTML(value="Select one number and send that prompt to txt2img or img2img", visible=False)
|
||||
warning = gr.HTML(
|
||||
value="Select one number and send that prompt to txt2img or img2img", visible=False)
|
||||
with gr.Row():
|
||||
send_to_txt2img = gr.Button('Send to txt2img', visible=False)
|
||||
send_to_img2img = gr.Button('Send to img2img', visible=False)
|
||||
|
||||
|
||||
# Method to create the extended prompt
|
||||
|
||||
def generate_longer_prompt(prompt, temperature, top_k,
|
||||
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
|
||||
try:
|
||||
|
|
@ -84,6 +88,7 @@ def on_ui_tabs():
|
|||
'FredZhang7/distilgpt2-stable-diffusion-v2')
|
||||
except Exception as e:
|
||||
print(f"Exception encountered while attempting to install tokenizer")
|
||||
return gr.update(), f"Error: {e}"
|
||||
try:
|
||||
print(f"Generate new prompt from: \"{prompt}\"")
|
||||
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
|
||||
|
|
@ -91,32 +96,34 @@ def on_ui_tabs():
|
|||
top_k=top_k, max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty,
|
||||
penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
|
||||
penalty_alpha=0.6, no_repeat_ngram_size=1,
|
||||
early_stopping=True)
|
||||
print("Generation complete!")
|
||||
tempString = ""
|
||||
if(use_blacklist):
|
||||
blacklist = get_list_blacklist()
|
||||
if (use_blacklist):
|
||||
blacklist = get_list_blacklist()
|
||||
for i in range(len(output)):
|
||||
|
||||
tempString += str(i+1)+": "+tokenizer.decode(
|
||||
output[i], skip_special_tokens=True) + "\n"
|
||||
|
||||
if(use_blacklist):
|
||||
if (use_blacklist):
|
||||
for to_check in blacklist:
|
||||
tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE)
|
||||
if(i==0):
|
||||
tempString = re.sub(
|
||||
to_check, "", tempString, flags=re.IGNORECASE)
|
||||
if (i == 0):
|
||||
global result_prompt
|
||||
|
||||
result_prompt = tempString
|
||||
print(result_prompt)
|
||||
|
||||
return {results: tempString,
|
||||
send_to_img2img: gr.update(visible = True),
|
||||
send_to_txt2img: gr.update(visible = True),
|
||||
send_to_img2img: gr.update(visible=True),
|
||||
send_to_txt2img: gr.update(visible=True),
|
||||
results_col: gr.update(visible=True),
|
||||
warning: gr.update(visible=True),
|
||||
promptNum_col: gr.update(visible=True)
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Exception encountered while attempting to generate prompt: {e}")
|
||||
|
|
@ -125,12 +132,18 @@ def on_ui_tabs():
|
|||
# events
|
||||
generateButton.click(fn=generate_longer_prompt, inputs=[
|
||||
promptTxt, temp_slider, top_k_slider, max_length_slider,
|
||||
repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox],
|
||||
outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning,promptNum_col])
|
||||
send_to_img2img.click(add_to_prompt,inputs=[promptNum], outputs=[img2img_prompt])
|
||||
send_to_txt2img.click(add_to_prompt,inputs=[promptNum], outputs=[txt2img_prompt])
|
||||
send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None)
|
||||
send_to_img2img.click(None, _js="switch_to_img2img", inputs=None, outputs=None)
|
||||
repetition_penalty_slider, num_return_sequences_slider,
|
||||
use_blacklist_checkbox],
|
||||
outputs=[results, send_to_img2img, send_to_txt2img,
|
||||
results_col, warning, promptNum_col])
|
||||
send_to_img2img.click(add_to_prompt, inputs=[
|
||||
promptNum], outputs=[img2img_prompt])
|
||||
send_to_txt2img.click(add_to_prompt, inputs=[
|
||||
promptNum], outputs=[txt2img_prompt])
|
||||
send_to_txt2img.click(None, _js='switch_to_txt2img',
|
||||
inputs=None, outputs=None)
|
||||
send_to_img2img.click(None, _js="switch_to_img2img",
|
||||
inputs=None, outputs=None)
|
||||
return (prompt_generator, "Prompt Generator", "Prompt Generator"),
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue