Allow for more arbitrary models
parent
b4833e90ce
commit
952f3fee41
14
README.md
14
README.md
|
|
@ -13,7 +13,8 @@ Adds a tab to the webui that allows the user to generate a prompt from a small b
|
|||
## Usage
|
||||
|
||||
1. Write in the prompt in the *Start of the prompt* text box
|
||||
2. Click Generate and wait
|
||||
2. Select which model you want to use
|
||||
3. Click Generate and wait
|
||||
|
||||
The initial use of the model may take longer as it needs to be downloaded to your machine for offline use. The model will be used on your device and will be stored in the default location of `*username*/.cache/huggingface/hub/models`. The entire process of generating results will be done on your local machine and not require internet access.
|
||||
|
||||
|
|
@ -25,11 +26,12 @@ The initial use of the model may take longer as it needs to be downloaded to you
|
|||
- **Max Length**: the maximum number of tokens for the output of the model
|
||||
- **Repetition Penalty**: The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. Default setting is 1.2
|
||||
- **How Many To Generate**: The number of results to generate
|
||||
- **Use blacklist?** Using `.\extensions\stable-diffusion-webui-Prompt_Generator\blacklist.txt`. It will delete any matches to the generated result (case insensitive). Each item to be filtered out should be on a new line. *Be aware that it simply deletes it and doesn't generate more to make up for the lost words*
|
||||
- **Use blacklist?**: Using `.\extensions\stable-diffusion-webui-Prompt_Generator\blacklist.txt`. It will delete any matches to the generated result (case insensitive). Each item to be filtered out should be on a new line. *Be aware that it simply deletes it and doesn't generate more to make up for the lost words*
|
||||
- **Use puncation**: Allows the use commas in the output
|
||||
|
||||
## Models
|
||||
|
||||
There are two models provided:
|
||||
There are two 'default' models provided:
|
||||
|
||||
### FredZhang7
|
||||
|
||||
|
|
@ -45,6 +47,12 @@ Useful to get more natural language prompts. Eg: "A cat sitting" -> "A cat sitti
|
|||
|
||||
*Be aware that sometimes the model fails to produce anything or less than the wanted amount, either try again or use a new prompt in that case*
|
||||
|
||||
## Install more models
|
||||
|
||||
To install more model to use, ensure that the models are hosted on [huggingface.co](https://huggingface.co) and edit the json file at `.\extensions\stable-diffusion-webui-Prompt_Generator\models.json` with the relevant information. Use the models in the file as an basis
|
||||
|
||||
You might need to restart the extension/reload the UI if new items are added onto the list
|
||||
|
||||
## Credits
|
||||
|
||||
Credits to both [FredZhang7](https://huggingface.co/FredZhang7) and [Gustavosta](https://huggingface.co/Gustavosta)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
[
|
||||
{
|
||||
"Title":"Gustavosta",
|
||||
"Tokenizer":"gpt2",
|
||||
"Model":"Gustavosta/MagicPrompt-Dalle"
|
||||
},
|
||||
{
|
||||
"Title":"FredZhang7",
|
||||
"Tokenizer":"distilgpt2",
|
||||
"Model":"FredZhang7/distilgpt2-stable-diffusion-v2"
|
||||
}
|
||||
]
|
||||
|
|
@ -15,8 +15,30 @@ import modules
|
|||
from modules import script_callbacks
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import re
|
||||
import math
|
||||
import json
|
||||
|
||||
result_prompt = ""
|
||||
models = {}
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, name, model, tokenizer) -> None:
|
||||
self.name = name
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
pass
|
||||
|
||||
|
||||
def populate_models():
|
||||
path = "./extensions/stable-diffusion-webui-Prompt_Generator/models.json"
|
||||
with open(path, 'r') as f:
|
||||
data = json.load(f)
|
||||
for item in data:
|
||||
name = item["Title"]
|
||||
model = item["Model"]
|
||||
tokenizer = item["Tokenizer"]
|
||||
models[name] = Model(name, model, tokenizer)
|
||||
|
||||
|
||||
|
||||
def add_to_prompt(num): # A function that determines which prompt to pass
|
||||
|
|
@ -44,81 +66,32 @@ def get_list_blacklist():
|
|||
def on_ui_tabs():
|
||||
# Method to create the extended prompt
|
||||
|
||||
def generate_longer_prompt_gustavosta(prompt, temperature, top_k,
|
||||
max_length, repetition_penalty, num_return_sequences, use_blacklist=False, use_early_stop=True):
|
||||
def generate_longer_generic(prompt, temperature, top_k,
|
||||
max_length, repetition_penalty, num_return_sequences, name, use_punctuation=False, use_blacklist=False):
|
||||
try:
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
# Full credits for the model to Gustavosta (https://huggingface.co/Gustavosta). Under the MIT license
|
||||
model = GPT2LMHeadModel.from_pretrained(
|
||||
'Gustavosta/MagicPrompt-Dalle')
|
||||
except Exception as e:
|
||||
print(f"Exception encountered while attempting to install tokenizer")
|
||||
return gr.update(), f"Error: {e}"
|
||||
try:
|
||||
min = len(prompt)
|
||||
print(f"Generate new prompt from: \"{prompt}\"")
|
||||
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
|
||||
output = model.generate(input_ids, do_sample=True, temperature=temperature,
|
||||
top_k=round(top_k), max_length=max_length,
|
||||
num_return_sequences=num_return_sequences*4,
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
penalty_alpha=0.6, no_repeat_ngram_size=1,
|
||||
early_stopping=use_early_stop)
|
||||
print("Generation complete!")
|
||||
tempString = ""
|
||||
if (use_blacklist):
|
||||
blacklist = get_list_blacklist()
|
||||
j = 0
|
||||
for i in range(len(output)):
|
||||
tempt_of_temp_String = tokenizer.decode(
|
||||
output[i], skip_special_tokens=True)
|
||||
if (len(tempt_of_temp_String) > min + 4):
|
||||
tempString += str(j+1) + ": " + tempt_of_temp_String
|
||||
j += 1
|
||||
else:
|
||||
continue
|
||||
if (use_blacklist):
|
||||
for to_check in blacklist:
|
||||
tempString = re.sub(
|
||||
to_check, "", tempString, flags=re.IGNORECASE)
|
||||
if (j == num_return_sequences):
|
||||
break
|
||||
|
||||
global result_prompt
|
||||
result_prompt = tempString
|
||||
|
||||
return {results: tempString,
|
||||
send_to_img2img: gr.update(visible=True),
|
||||
send_to_txt2img: gr.update(visible=True),
|
||||
send_to_text: gr.update(visible=True),
|
||||
results_col: gr.update(visible=True),
|
||||
warning: gr.update(visible=True),
|
||||
promptNum_col: gr.update(visible=True)
|
||||
}
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Exception encountered while attempting to generate prompt: {e}")
|
||||
return gr.update(), f"Error: {e}"
|
||||
|
||||
def generate_longer_prompt_FredZhang7(prompt, temperature, top_k,
|
||||
max_length, repetition_penalty, num_return_sequences, use_blacklist=False):
|
||||
try:
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(models[name].tokenizer)
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
# Full credits for the model to FredZhang7 (https://huggingface.co/FredZhang7). Under creativeml-openrail-m license.
|
||||
model = GPT2LMHeadModel.from_pretrained(
|
||||
'FredZhang7/distilgpt2-stable-diffusion-v2')
|
||||
model = GPT2LMHeadModel.from_pretrained(models[name].model)
|
||||
except Exception as e:
|
||||
print(f"Exception encountered while attempting to install tokenizer")
|
||||
return gr.update(), f"Error: {e}"
|
||||
try:
|
||||
print(f"Generate new prompt from: \"{prompt}\"")
|
||||
print(f"Generate new prompt from: \"{prompt}\" with {name}")
|
||||
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
|
||||
output = model.generate(input_ids, do_sample=True, temperature=temperature,
|
||||
if(use_punctuation):
|
||||
output = model.generate(input_ids, do_sample=True, temperature=temperature,
|
||||
top_k=round(top_k), max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
repetition_penalty=float(
|
||||
repetition_penalty),
|
||||
early_stopping=True)
|
||||
else:
|
||||
output = model.generate(input_ids, do_sample=True, temperature=temperature,
|
||||
top_k=round(top_k), max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=float(
|
||||
repetition_penalty),
|
||||
penalty_alpha=0.6, no_repeat_ngram_size=1,
|
||||
early_stopping=True)
|
||||
print("Generation complete!")
|
||||
|
|
@ -163,6 +136,8 @@ def on_ui_tabs():
|
|||
promptTxt = gr.Textbox(
|
||||
lines=2, elem_id="promptTxt", label="Start of the prompt")
|
||||
with gr.Column():
|
||||
gr.HTML(
|
||||
"Mouse over the labels to access tooltips that provide explanations for the parameters.")
|
||||
with gr.Row():
|
||||
temp_slider = gr.Slider(
|
||||
elem_id="temp_slider", label="Temperature", interactive=True, minimum=0, maximum=1, value=0.9)
|
||||
|
|
@ -182,10 +157,11 @@ def on_ui_tabs():
|
|||
gr.HTML(value="<center>Using <code>\".\extensions\stable-diffusion-webui-Prompt_Generator\\blacklist.txt</code>\".<br>It will delete any matches to the generated result (case insensitive).</center>")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
populate_models()
|
||||
generate_dropdown = gr.Dropdown(choices=list(models.keys()), value="FredZhang7", label = "Which model to use?",show_label=True)
|
||||
use_puncation_check = gr.Checkbox(label="Use puncation?")
|
||||
generateButton_fred = gr.Button(
|
||||
value="Generate Using FredZhang7", elem_id="generate_button_FredZhang7")
|
||||
generateButton_magic = gr.Button(
|
||||
value="Generate Using Magic Prompt", elem_id="generate_button_MagicPrompt")
|
||||
value="Generate", elem_id="generate_button")
|
||||
with gr.Column(visible=False) as results_col:
|
||||
results = gr.Text(
|
||||
label="Results", elem_id="Results_textBox", interactive=False)
|
||||
|
|
@ -203,16 +179,10 @@ def on_ui_tabs():
|
|||
'Send to back to prompter', visible=False)
|
||||
|
||||
# events
|
||||
generateButton_fred.click(fn=generate_longer_prompt_FredZhang7, inputs=[
|
||||
generateButton_fred.click(fn=generate_longer_generic, inputs=[
|
||||
promptTxt, temp_slider, top_k_slider, max_length_slider,
|
||||
repetition_penalty_slider, num_return_sequences_slider,
|
||||
use_blacklist_checkbox],
|
||||
outputs=[results, send_to_img2img, send_to_txt2img, send_to_text,
|
||||
results_col, warning, promptNum_col])
|
||||
generateButton_magic.click(fn=generate_longer_prompt_gustavosta, inputs=[
|
||||
promptTxt, temp_slider, top_k_slider, max_length_slider,
|
||||
repetition_penalty_slider, num_return_sequences_slider,
|
||||
use_blacklist_checkbox],
|
||||
generate_dropdown,use_puncation_check, use_blacklist_checkbox],
|
||||
outputs=[results, send_to_img2img, send_to_txt2img, send_to_text,
|
||||
results_col, warning, promptNum_col])
|
||||
send_to_img2img.click(add_to_prompt, inputs=[
|
||||
|
|
|
|||
Loading…
Reference in New Issue