Slightly better error handling. Bit of formatting

pull/7/head
Imrayya 2023-01-11 13:24:07 +01:00
parent aeaeabd04f
commit db93f06a21
1 changed files with 43 additions and 30 deletions

View File

@ -2,14 +2,16 @@ import gradio as gr
import modules import modules
from modules import script_callbacks from modules import script_callbacks
from transformers import GPT2Tokenizer, GPT2LMHeadModel from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os, re import os
import re
result_prompt = "" result_prompt = ""
def add_to_prompt(num):# A function that determines which prompt to pass
hand_over_prompt_list=result_prompt.splitlines() def add_to_prompt(num): # A function that determines which prompt to pass
hand_over_prompt_list = result_prompt.splitlines()
try: try:
return(hand_over_prompt_list[int(num)-1][3:]) return (hand_over_prompt_list[int(num)-1][3:])
except Exception as e: except Exception as e:
print( print(
f"That line does not exist. Check number of prompts: {e}") f"That line does not exist. Check number of prompts: {e}")
@ -62,19 +64,21 @@ def on_ui_tabs():
generateButton = gr.Button( generateButton = gr.Button(
value="Generate", elem_id="generate_button") value="Generate", elem_id="generate_button")
with gr.Column(visible=False) as results_col: with gr.Column(visible=False) as results_col:
results = gr.Text(label="Results", elem_id="Results_textBox", interactive=False) results = gr.Text(
label="Results", elem_id="Results_textBox", interactive=False)
with gr.Column(visible=False) as promptNum_col: with gr.Column(visible=False) as promptNum_col:
with gr.Row(): with gr.Row():
promptNum = gr.Textbox( promptNum = gr.Textbox(
lines=1, elem_id="promptNum", label="Send which prompt") lines=1, elem_id="promptNum", label="Send which prompt")
with gr.Column(): with gr.Column():
warning = gr.HTML(value="Select one number and send that prompt to txt2img or img2img", visible=False) warning = gr.HTML(
value="Select one number and send that prompt to txt2img or img2img", visible=False)
with gr.Row(): with gr.Row():
send_to_txt2img = gr.Button('Send to txt2img', visible=False) send_to_txt2img = gr.Button('Send to txt2img', visible=False)
send_to_img2img = gr.Button('Send to img2img', visible=False) send_to_img2img = gr.Button('Send to img2img', visible=False)
# Method to create the extended prompt # Method to create the extended prompt
def generate_longer_prompt(prompt, temperature, top_k, def generate_longer_prompt(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, use_blacklist=False): max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
try: try:
@ -84,6 +88,7 @@ def on_ui_tabs():
'FredZhang7/distilgpt2-stable-diffusion-v2') 'FredZhang7/distilgpt2-stable-diffusion-v2')
except Exception as e: except Exception as e:
print(f"Exception encountered while attempting to install tokenizer") print(f"Exception encountered while attempting to install tokenizer")
return gr.update(), f"Error: {e}"
try: try:
print(f"Generate new prompt from: \"{prompt}\"") print(f"Generate new prompt from: \"{prompt}\"")
input_ids = tokenizer(prompt, return_tensors='pt').input_ids input_ids = tokenizer(prompt, return_tensors='pt').input_ids
@ -91,28 +96,30 @@ def on_ui_tabs():
top_k=top_k, max_length=max_length, top_k=top_k, max_length=max_length,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True) penalty_alpha=0.6, no_repeat_ngram_size=1,
early_stopping=True)
print("Generation complete!") print("Generation complete!")
tempString = "" tempString = ""
if(use_blacklist): if (use_blacklist):
blacklist = get_list_blacklist() blacklist = get_list_blacklist()
for i in range(len(output)): for i in range(len(output)):
tempString += str(i+1)+": "+tokenizer.decode( tempString += str(i+1)+": "+tokenizer.decode(
output[i], skip_special_tokens=True) + "\n" output[i], skip_special_tokens=True) + "\n"
if(use_blacklist): if (use_blacklist):
for to_check in blacklist: for to_check in blacklist:
tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE) tempString = re.sub(
if(i==0): to_check, "", tempString, flags=re.IGNORECASE)
if (i == 0):
global result_prompt global result_prompt
result_prompt = tempString result_prompt = tempString
print(result_prompt) print(result_prompt)
return {results: tempString, return {results: tempString,
send_to_img2img: gr.update(visible = True), send_to_img2img: gr.update(visible=True),
send_to_txt2img: gr.update(visible = True), send_to_txt2img: gr.update(visible=True),
results_col: gr.update(visible=True), results_col: gr.update(visible=True),
warning: gr.update(visible=True), warning: gr.update(visible=True),
promptNum_col: gr.update(visible=True) promptNum_col: gr.update(visible=True)
@ -125,12 +132,18 @@ def on_ui_tabs():
# events # events
generateButton.click(fn=generate_longer_prompt, inputs=[ generateButton.click(fn=generate_longer_prompt, inputs=[
promptTxt, temp_slider, top_k_slider, max_length_slider, promptTxt, temp_slider, top_k_slider, max_length_slider,
repetition_penalty_slider, num_return_sequences_slider, use_blacklist_checkbox], repetition_penalty_slider, num_return_sequences_slider,
outputs=[results, send_to_img2img, send_to_txt2img, results_col, warning,promptNum_col]) use_blacklist_checkbox],
send_to_img2img.click(add_to_prompt,inputs=[promptNum], outputs=[img2img_prompt]) outputs=[results, send_to_img2img, send_to_txt2img,
send_to_txt2img.click(add_to_prompt,inputs=[promptNum], outputs=[txt2img_prompt]) results_col, warning, promptNum_col])
send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) send_to_img2img.click(add_to_prompt, inputs=[
send_to_img2img.click(None, _js="switch_to_img2img", inputs=None, outputs=None) promptNum], outputs=[img2img_prompt])
send_to_txt2img.click(add_to_prompt, inputs=[
promptNum], outputs=[txt2img_prompt])
send_to_txt2img.click(None, _js='switch_to_txt2img',
inputs=None, outputs=None)
send_to_img2img.click(None, _js="switch_to_img2img",
inputs=None, outputs=None)
return (prompt_generator, "Prompt Generator", "Prompt Generator"), return (prompt_generator, "Prompt Generator", "Prompt Generator"),