Fixed tokenizer being broken in some environments

pull/7/head
Imrayya 2023-01-11 18:00:50 +01:00
parent 33683c21ea
commit bcc7a3ee53
1 changed files with 54 additions and 54 deletions

View File

@ -31,8 +31,60 @@ def get_list_blacklist():
def on_ui_tabs():
# structure
# Method to create the extended prompt
def generate_longer_prompt(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
try:
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained(
'FredZhang7/distilgpt2-stable-diffusion-v2')
except Exception as e:
print(f"Exception encountered while attempting to install tokenizer")
return gr.update(), f"Error: {e}"
try:
print(f"Generate new prompt from: \"{prompt}\"")
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature,
top_k=top_k, max_length=max_length,
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty,
penalty_alpha=0.6, no_repeat_ngram_size=1,
early_stopping=True)
print("Generation complete!")
tempString = ""
if (use_blacklist):
blacklist = get_list_blacklist()
for i in range(len(output)):
tempString += str(i+1)+": "+tokenizer.decode(
output[i], skip_special_tokens=True) + "\n"
if (use_blacklist):
for to_check in blacklist:
tempString = re.sub(
to_check, "", tempString, flags=re.IGNORECASE)
if (i == 0):
global result_prompt
result_prompt = tempString
print(result_prompt)
return {results: tempString,
send_to_img2img: gr.update(visible=True),
send_to_txt2img: gr.update(visible=True),
results_col: gr.update(visible=True),
warning: gr.update(visible=True),
promptNum_col: gr.update(visible=True)
}
except Exception as e:
print(
f"Exception encountered while attempting to generate prompt: {e}")
return gr.update(), f"Error: {e}"
# structure
txt2img_prompt = modules.ui.txt2img_paste_fields[0][0]
img2img_prompt = modules.ui.img2img_paste_fields[0][0]
@ -77,58 +129,7 @@ def on_ui_tabs():
send_to_txt2img = gr.Button('Send to txt2img', visible=False)
send_to_img2img = gr.Button('Send to img2img', visible=False)
# Method to create the extended prompt
def generate_longer_prompt(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
try:
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained(
'FredZhang7/distilgpt2-stable-diffusion-v2')
except Exception as e:
print(f"Exception encountered while attempting to install tokenizer")
return gr.update(), f"Error: {e}"
try:
print(f"Generate new prompt from: \"{prompt}\"")
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature,
top_k=top_k, max_length=max_length,
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty,
penalty_alpha=0.6, no_repeat_ngram_size=1,
early_stopping=True)
print("Generation complete!")
tempString = ""
if (use_blacklist):
blacklist = get_list_blacklist()
for i in range(len(output)):
tempString += str(i+1)+": "+tokenizer.decode(
output[i], skip_special_tokens=True) + "\n"
if (use_blacklist):
for to_check in blacklist:
tempString = re.sub(
to_check, "", tempString, flags=re.IGNORECASE)
if (i == 0):
global result_prompt
result_prompt = tempString
print(result_prompt)
return {results: tempString,
send_to_img2img: gr.update(visible=True),
send_to_txt2img: gr.update(visible=True),
results_col: gr.update(visible=True),
warning: gr.update(visible=True),
promptNum_col: gr.update(visible=True)
}
except Exception as e:
print(
f"Exception encountered while attempting to generate prompt: {e}")
return gr.update(), f"Error: {e}"
# events
generateButton.click(fn=generate_longer_prompt, inputs=[
promptTxt, temp_slider, top_k_slider, max_length_slider,
@ -146,5 +147,4 @@ def on_ui_tabs():
inputs=None, outputs=None)
return (prompt_generator, "Prompt Generator", "Prompt Generator"),
script_callbacks.on_ui_tabs(on_ui_tabs)