Fixed tokenizer being broken in some environments
parent 33683c21ea
commit bcc7a3ee53
@@ -31,8 +31,60 @@ def get_list_blacklist():
 
 
 def on_ui_tabs():
-    # structure
+    # Method to create the extended prompt
+    def generate_longer_prompt(prompt, temperature, top_k,
+                               max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
+        try:
+            tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
+            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+            model = GPT2LMHeadModel.from_pretrained(
+                'FredZhang7/distilgpt2-stable-diffusion-v2')
+        except Exception as e:
+            print(f"Exception encountered while attempting to install tokenizer")
+            return gr.update(), f"Error: {e}"
+        try:
+            print(f"Generate new prompt from: \"{prompt}\"")
+            input_ids = tokenizer(prompt, return_tensors='pt').input_ids
+            output = model.generate(input_ids, do_sample=True, temperature=temperature,
+                                    top_k=top_k, max_length=max_length,
+                                    num_return_sequences=num_return_sequences,
+                                    repetition_penalty=repetition_penalty,
+                                    penalty_alpha=0.6, no_repeat_ngram_size=1,
+                                    early_stopping=True)
+            print("Generation complete!")
+            tempString = ""
+            if (use_blacklist):
+                blacklist = get_list_blacklist()
+            for i in range(len(output)):
+
+                tempString += str(i+1)+": "+tokenizer.decode(
+                    output[i], skip_special_tokens=True) + "\n"
+
+                if (use_blacklist):
+                    for to_check in blacklist:
+                        tempString = re.sub(
+                            to_check, "", tempString, flags=re.IGNORECASE)
+                if (i == 0):
+                    global result_prompt
+
+                    result_prompt = tempString
+            print(result_prompt)
+
+            return {results: tempString,
+                    send_to_img2img: gr.update(visible=True),
+                    send_to_txt2img: gr.update(visible=True),
+                    results_col: gr.update(visible=True),
+                    warning: gr.update(visible=True),
+                    promptNum_col: gr.update(visible=True)
+                    }
+        except Exception as e:
+            print(
+                f"Exception encountered while attempting to generate prompt: {e}")
+            return gr.update(), f"Error: {e}"
+
+
+    # structure
     txt2img_prompt = modules.ui.txt2img_paste_fields[0][0]
     img2img_prompt = modules.ui.img2img_paste_fields[0][0]
 
 
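Note: the generation path relocated by this hunk can be exercised on its own, outside the web UI. The following is a minimal standalone sketch assuming the transformers package; the tokenizer, model id, and generate arguments mirror the diff, while the seed prompt and the fixed sampling values are purely illustrative.

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the tokenizer and the prompt-extension model named in the diff.
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')

# Sample a handful of extended prompts from a short seed prompt.
input_ids = tokenizer("a portrait of a wizard", return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=0.9, top_k=8,
                        max_length=80, num_return_sequences=5,
                        repetition_penalty=1.2, no_repeat_ngram_size=1,
                        early_stopping=True)

# Decode each sampled sequence, numbering them the way the handler does.
for i, sequence in enumerate(output):
    print(f"{i + 1}: {tokenizer.decode(sequence, skip_special_tokens=True)}")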
@@ -77,57 +129,6 @@ def on_ui_tabs():
 
         send_to_txt2img = gr.Button('Send to txt2img', visible=False)
         send_to_img2img = gr.Button('Send to img2img', visible=False)
-
-    # Method to create the extended prompt
-    def generate_longer_prompt(prompt, temperature, top_k,
-                               max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
-        try:
-            tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
-            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-            model = GPT2LMHeadModel.from_pretrained(
-                'FredZhang7/distilgpt2-stable-diffusion-v2')
-        except Exception as e:
-            print(f"Exception encountered while attempting to install tokenizer")
-            return gr.update(), f"Error: {e}"
-        try:
-            print(f"Generate new prompt from: \"{prompt}\"")
-            input_ids = tokenizer(prompt, return_tensors='pt').input_ids
-            output = model.generate(input_ids, do_sample=True, temperature=temperature,
-                                    top_k=top_k, max_length=max_length,
-                                    num_return_sequences=num_return_sequences,
-                                    repetition_penalty=repetition_penalty,
-                                    penalty_alpha=0.6, no_repeat_ngram_size=1,
-                                    early_stopping=True)
-            print("Generation complete!")
-            tempString = ""
-            if (use_blacklist):
-                blacklist = get_list_blacklist()
-            for i in range(len(output)):
-
-                tempString += str(i+1)+": "+tokenizer.decode(
-                    output[i], skip_special_tokens=True) + "\n"
-
-                if (use_blacklist):
-                    for to_check in blacklist:
-                        tempString = re.sub(
-                            to_check, "", tempString, flags=re.IGNORECASE)
-                if (i == 0):
-                    global result_prompt
-
-                    result_prompt = tempString
-            print(result_prompt)
-
-            return {results: tempString,
-                    send_to_img2img: gr.update(visible=True),
-                    send_to_txt2img: gr.update(visible=True),
-                    results_col: gr.update(visible=True),
-                    warning: gr.update(visible=True),
-                    promptNum_col: gr.update(visible=True)
-                    }
-        except Exception as e:
-            print(
-                f"Exception encountered while attempting to generate prompt: {e}")
-            return gr.update(), f"Error: {e}"
 
     # events
     generateButton.click(fn=generate_longer_prompt, inputs=[
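Note: the relocated code passes each blacklist entry straight to re.sub, so entries behave as case-insensitive regular expressions rather than plain substrings; literal terms containing regex metacharacters would need re.escape. A small self-contained illustration, with invented blacklist entries in place of what get_list_blacklist() actually returns:

import re

# Illustrative entries only; the real list comes from get_list_blacklist().
blacklist = ['nsfw', r'\bgore\b']
tempString = "1: a NSFW portrait, gore, 4k\n"

# Strip every blacklisted pattern from the generated text, ignoring case.
for to_check in blacklist:
    tempString = re.sub(to_check, "", tempString, flags=re.IGNORECASE)

print(tempString)  # -> "1: a  portrait, , 4k"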
@@ -146,5 +147,4 @@ def on_ui_tabs():
                          inputs=None, outputs=None)
     return (prompt_generator, "Prompt Generator", "Prompt Generator"),
 
-
 script_callbacks.on_ui_tabs(on_ui_tabs)
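Note: the handler wired up by generateButton.click returns a dict keyed by Gradio components, which lets one function both fill the results textbox and un-hide the send buttons in a single update. A self-contained sketch of that dict-return pattern; every component and function name below is invented for illustration and is not the extension's actual layout.

import gradio as gr

# A click handler may return {component: value_or_update, ...} as long as
# every key is listed among the event's outputs.
def handler(prompt):
    return {results: f"1: {prompt}, highly detailed\n",
            send_row: gr.update(visible=True)}

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    generate = gr.Button("Generate")
    results = gr.Textbox(label="Results")
    with gr.Row(visible=False) as send_row:   # hidden until generation succeeds
        gr.Button("Send to txt2img")
        gr.Button("Send to img2img")
    generate.click(fn=handler, inputs=[prompt_box], outputs=[results, send_row])

if __name__ == "__main__":
    demo.launch()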