From bcc7a3ee53c0f8e9ce284d88be1f67b86cce4e3b Mon Sep 17 00:00:00 2001 From: Imrayya Date: Wed, 11 Jan 2023 18:00:50 +0100 Subject: [PATCH] Fixed tokenizer being broken in some environments --- scripts/prompt_generator.py | 108 ++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/scripts/prompt_generator.py b/scripts/prompt_generator.py index f81d5e0..ad20fc8 100644 --- a/scripts/prompt_generator.py +++ b/scripts/prompt_generator.py @@ -31,8 +31,60 @@ def get_list_blacklist(): def on_ui_tabs(): - # structure + # Method to create the extended prompt + def generate_longer_prompt(prompt, temperature, top_k, + max_length, repetition_penalty, num_return_sequences, use_blacklist=False): + try: + tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2') + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + model = GPT2LMHeadModel.from_pretrained( + 'FredZhang7/distilgpt2-stable-diffusion-v2') + except Exception as e: + print(f"Exception encountered while attempting to install tokenizer") + return gr.update(), f"Error: {e}" + try: + print(f"Generate new prompt from: \"{prompt}\"") + input_ids = tokenizer(prompt, return_tensors='pt').input_ids + output = model.generate(input_ids, do_sample=True, temperature=temperature, + top_k=top_k, max_length=max_length, + num_return_sequences=num_return_sequences, + repetition_penalty=repetition_penalty, + penalty_alpha=0.6, no_repeat_ngram_size=1, + early_stopping=True) + print("Generation complete!") + tempString = "" + if (use_blacklist): + blacklist = get_list_blacklist() + for i in range(len(output)): + + tempString += str(i+1)+": "+tokenizer.decode( + output[i], skip_special_tokens=True) + "\n" + + if (use_blacklist): + for to_check in blacklist: + tempString = re.sub( + to_check, "", tempString, flags=re.IGNORECASE) + if (i == 0): + global result_prompt + + result_prompt = tempString + print(result_prompt) + + return {results: tempString, + send_to_img2img: gr.update(visible=True), + send_to_txt2img: gr.update(visible=True), + results_col: gr.update(visible=True), + warning: gr.update(visible=True), + promptNum_col: gr.update(visible=True) + } + except Exception as e: + print( + f"Exception encountered while attempting to generate prompt: {e}") + return gr.update(), f"Error: {e}" + + + # structure txt2img_prompt = modules.ui.txt2img_paste_fields[0][0] img2img_prompt = modules.ui.img2img_paste_fields[0][0] @@ -77,58 +129,7 @@ def on_ui_tabs(): send_to_txt2img = gr.Button('Send to txt2img', visible=False) send_to_img2img = gr.Button('Send to img2img', visible=False) - # Method to create the extended prompt - - def generate_longer_prompt(prompt, temperature, top_k, - max_length, repetition_penalty, num_return_sequences, use_blacklist=False): - try: - tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2') - tokenizer.add_special_tokens({'pad_token': '[PAD]'}) - model = GPT2LMHeadModel.from_pretrained( - 'FredZhang7/distilgpt2-stable-diffusion-v2') - except Exception as e: - print(f"Exception encountered while attempting to install tokenizer") - return gr.update(), f"Error: {e}" - try: - print(f"Generate new prompt from: \"{prompt}\"") - input_ids = tokenizer(prompt, return_tensors='pt').input_ids - output = model.generate(input_ids, do_sample=True, temperature=temperature, - top_k=top_k, max_length=max_length, - num_return_sequences=num_return_sequences, - repetition_penalty=repetition_penalty, - penalty_alpha=0.6, no_repeat_ngram_size=1, - early_stopping=True) - print("Generation complete!") - tempString = "" - if (use_blacklist): - blacklist = get_list_blacklist() - for i in range(len(output)): - - tempString += str(i+1)+": "+tokenizer.decode( - output[i], skip_special_tokens=True) + "\n" - - if (use_blacklist): - for to_check in blacklist: - tempString = re.sub( - to_check, "", tempString, flags=re.IGNORECASE) - if (i == 0): - global result_prompt - - result_prompt = tempString - print(result_prompt) - - return {results: tempString, - send_to_img2img: gr.update(visible=True), - send_to_txt2img: gr.update(visible=True), - results_col: gr.update(visible=True), - warning: gr.update(visible=True), - promptNum_col: gr.update(visible=True) - } - except Exception as e: - print( - f"Exception encountered while attempting to generate prompt: {e}") - return gr.update(), f"Error: {e}" - + # events generateButton.click(fn=generate_longer_prompt, inputs=[ promptTxt, temp_slider, top_k_slider, max_length_slider, @@ -146,5 +147,4 @@ def on_ui_tabs(): inputs=None, outputs=None) return (prompt_generator, "Prompt Generator", "Prompt Generator"), - script_callbacks.on_ui_tabs(on_ui_tabs)