Allow for more arbitrary models

pull/16/head
Imrayya 2023-02-11 12:06:18 +01:00
parent b4833e90ce
commit 952f3fee41
3 changed files with 70 additions and 80 deletions

View File

@ -13,7 +13,8 @@ Adds a tab to the webui that allows the user to generate a prompt from a small b
## Usage
1. Write in the prompt in the *Start of the prompt* text box
2. Click Generate and wait
2. Select which model you want to use
3. Click Generate and wait
The initial use of the model may take longer as it needs to be downloaded to your machine for offline use. The model will be used on your device and will be stored in the default location of `*username*/.cache/huggingface/hub/models`. The entire process of generating results will be done on your local machine and not require internet access.
@ -25,11 +26,12 @@ The initial use of the model may take longer as it needs to be downloaded to you
- **Max Length**: the maximum number of tokens for the output of the model
- **Repetition Penalty**: The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. Default setting is 1.2
- **How Many To Generate**: The number of results to generate
- **Use blacklist?** Using `.\extensions\stable-diffusion-webui-Prompt_Generator\blacklist.txt`. It will delete any matches to the generated result (case insensitive). Each item to be filtered out should be on a new line. *Be aware that it simply deletes it and doesn't generate more to make up for the lost words*
- **Use blacklist?**: Using `.\extensions\stable-diffusion-webui-Prompt_Generator\blacklist.txt`. It will delete any matches to the generated result (case insensitive). Each item to be filtered out should be on a new line. *Be aware that it simply deletes it and doesn't generate more to make up for the lost words*
- **Use puncation**: Allows the use of commas in the output
## Models
There are two models provided:
There are two 'default' models provided:
### FredZhang7
@ -45,6 +47,12 @@ Useful to get more natural language prompts. Eg: "A cat sitting" -> "A cat sitti
*Be aware that sometimes the model fails to produce anything or less than the wanted amount, either try again or use a new prompt in that case*
## Install more models
To install more models to use, ensure that the models are hosted on [huggingface.co](https://huggingface.co) and edit the JSON file at `.\extensions\stable-diffusion-webui-Prompt_Generator\models.json` with the relevant information. Use the models already in the file as a basis.
You might need to restart the extension or reload the UI for new items to appear in the list.
## Credits
Credits to both [FredZhang7](https://huggingface.co/FredZhang7) and [Gustavosta](https://huggingface.co/Gustavosta)

12
models.json Normal file
View File

@ -0,0 +1,12 @@
[
{
"Title":"Gustavosta",
"Tokenizer":"gpt2",
"Model":"Gustavosta/MagicPrompt-Dalle"
},
{
"Title":"FredZhang7",
"Tokenizer":"distilgpt2",
"Model":"FredZhang7/distilgpt2-stable-diffusion-v2"
}
]

View File

@ -15,8 +15,30 @@ import modules
from modules import script_callbacks
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import re
import math
import json
result_prompt = ""  # Most recently generated prompt text; set by the generate handler.
models = {}  # Maps model title -> Model instance; filled by populate_models() from models.json.
class Model:
    """Container for one model entry loaded from models.json.

    Attributes:
        name: Display title shown in the UI model dropdown.
        model: Identifier passed to ``GPT2LMHeadModel.from_pretrained``.
        tokenizer: Identifier passed to ``GPT2Tokenizer.from_pretrained``.
    """

    def __init__(self, name, model, tokenizer) -> None:
        self.name = name
        self.model = model
        self.tokenizer = tokenizer
def populate_models(path="./extensions/stable-diffusion-webui-Prompt_Generator/models.json"):
    """Load model definitions from a JSON file into the global ``models`` dict.

    The file must contain a JSON array of objects, each providing "Title",
    "Model" and "Tokenizer" keys; every entry is stored under its title so
    the UI dropdown can look it up by name.

    Args:
        path: Location of the JSON file. Defaults to the extension's bundled
            models.json, so existing no-argument call sites keep working.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an entry is missing one of the required keys.
    """
    with open(path, 'r') as f:
        data = json.load(f)
    for item in data:
        title = item["Title"]
        models[title] = Model(title, item["Model"], item["Tokenizer"])
def add_to_prompt(num): # A function that determines which prompt to pass
@ -44,81 +66,32 @@ def get_list_blacklist():
def on_ui_tabs():
# Method to create the extended prompt
def generate_longer_prompt_gustavosta(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, use_blacklist=False, use_early_stop=True):
def generate_longer_generic(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, name, use_punctuation=False, use_blacklist=False):
try:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Full credits for the model to Gustavosta (https://huggingface.co/Gustavosta). Under the MIT license
model = GPT2LMHeadModel.from_pretrained(
'Gustavosta/MagicPrompt-Dalle')
except Exception as e:
print(f"Exception encountered while attempting to install tokenizer")
return gr.update(), f"Error: {e}"
try:
min = len(prompt)
print(f"Generate new prompt from: \"{prompt}\"")
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature,
top_k=round(top_k), max_length=max_length,
num_return_sequences=num_return_sequences*4,
repetition_penalty=float(repetition_penalty),
penalty_alpha=0.6, no_repeat_ngram_size=1,
early_stopping=use_early_stop)
print("Generation complete!")
tempString = ""
if (use_blacklist):
blacklist = get_list_blacklist()
j = 0
for i in range(len(output)):
tempt_of_temp_String = tokenizer.decode(
output[i], skip_special_tokens=True)
if (len(tempt_of_temp_String) > min + 4):
tempString += str(j+1) + ": " + tempt_of_temp_String
j += 1
else:
continue
if (use_blacklist):
for to_check in blacklist:
tempString = re.sub(
to_check, "", tempString, flags=re.IGNORECASE)
if (j == num_return_sequences):
break
global result_prompt
result_prompt = tempString
return {results: tempString,
send_to_img2img: gr.update(visible=True),
send_to_txt2img: gr.update(visible=True),
send_to_text: gr.update(visible=True),
results_col: gr.update(visible=True),
warning: gr.update(visible=True),
promptNum_col: gr.update(visible=True)
}
except Exception as e:
print(
f"Exception encountered while attempting to generate prompt: {e}")
return gr.update(), f"Error: {e}"
def generate_longer_prompt_FredZhang7(prompt, temperature, top_k,
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
try:
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer = GPT2Tokenizer.from_pretrained(models[name].tokenizer)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Full credits for the model to FredZhang7 (https://huggingface.co/FredZhang7). Under creativeml-openrail-m license.
model = GPT2LMHeadModel.from_pretrained(
'FredZhang7/distilgpt2-stable-diffusion-v2')
model = GPT2LMHeadModel.from_pretrained(models[name].model)
except Exception as e:
print(f"Exception encountered while attempting to install tokenizer")
return gr.update(), f"Error: {e}"
try:
print(f"Generate new prompt from: \"{prompt}\"")
print(f"Generate new prompt from: \"{prompt}\" with {name}")
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature,
if(use_punctuation):
output = model.generate(input_ids, do_sample=True, temperature=temperature,
top_k=round(top_k), max_length=max_length,
num_return_sequences=num_return_sequences,
repetition_penalty=float(repetition_penalty),
repetition_penalty=float(
repetition_penalty),
early_stopping=True)
else:
output = model.generate(input_ids, do_sample=True, temperature=temperature,
top_k=round(top_k), max_length=max_length,
num_return_sequences=num_return_sequences,
repetition_penalty=float(
repetition_penalty),
penalty_alpha=0.6, no_repeat_ngram_size=1,
early_stopping=True)
print("Generation complete!")
@ -163,6 +136,8 @@ def on_ui_tabs():
promptTxt = gr.Textbox(
lines=2, elem_id="promptTxt", label="Start of the prompt")
with gr.Column():
gr.HTML(
"Mouse over the labels to access tooltips that provide explanations for the parameters.")
with gr.Row():
temp_slider = gr.Slider(
elem_id="temp_slider", label="Temperature", interactive=True, minimum=0, maximum=1, value=0.9)
@ -182,10 +157,11 @@ def on_ui_tabs():
gr.HTML(value="<center>Using <code>\".\extensions\stable-diffusion-webui-Prompt_Generator\\blacklist.txt</code>\".<br>It will delete any matches to the generated result (case insensitive).</center>")
with gr.Column():
with gr.Row():
populate_models()
generate_dropdown = gr.Dropdown(choices=list(models.keys()), value="FredZhang7", label = "Which model to use?",show_label=True)
use_puncation_check = gr.Checkbox(label="Use puncation?")
generateButton_fred = gr.Button(
value="Generate Using FredZhang7", elem_id="generate_button_FredZhang7")
generateButton_magic = gr.Button(
value="Generate Using Magic Prompt", elem_id="generate_button_MagicPrompt")
value="Generate", elem_id="generate_button")
with gr.Column(visible=False) as results_col:
results = gr.Text(
label="Results", elem_id="Results_textBox", interactive=False)
@ -203,16 +179,10 @@ def on_ui_tabs():
'Send to back to prompter', visible=False)
# events
generateButton_fred.click(fn=generate_longer_prompt_FredZhang7, inputs=[
generateButton_fred.click(fn=generate_longer_generic, inputs=[
promptTxt, temp_slider, top_k_slider, max_length_slider,
repetition_penalty_slider, num_return_sequences_slider,
use_blacklist_checkbox],
outputs=[results, send_to_img2img, send_to_txt2img, send_to_text,
results_col, warning, promptNum_col])
generateButton_magic.click(fn=generate_longer_prompt_gustavosta, inputs=[
promptTxt, temp_slider, top_k_slider, max_length_slider,
repetition_penalty_slider, num_return_sequences_slider,
use_blacklist_checkbox],
generate_dropdown,use_puncation_check, use_blacklist_checkbox],
outputs=[results, send_to_img2img, send_to_txt2img, send_to_text,
results_col, warning, promptNum_col])
send_to_img2img.click(add_to_prompt, inputs=[