mirror of https://github.com/vladmandic/automatic
58 lines | 2.9 KiB | Python | Executable File
#!/usr/bin/env python
|
|
"""
|
|
generate prompt ideas
|
|
model from: <https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2>
|
|
"""
|
|
|
|
import logging
|
|
import argparse
|
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|
from util import log
|
|
|
|
|
|
def prompt(text: str, temp: float = 0.9, top: int = 8, penalty: float = 1.2, alpha: float = 0.6, num: int = 5, length: int = 80):
    """Generate stable-diffusion prompt ideas by extending *text* with a GPT-2 model.

    Loads the tokenizer and the fine-tuned model on every call (no caching), runs
    sampling-based generation, and returns the decoded candidate prompts.

    Args:
        text: seed text to expand into prompt ideas
        temp: sampling temperature; higher values produce more diverse but less coherent output
        top: top-k sampling — number of tokens considered at each step
        penalty: repetition penalty applied to already-generated tokens
        alpha: contrastive-search penalty alpha
        num: number of candidate prompts to generate
        length: maximum number of output tokens per candidate

    Returns:
        list[str]: generated prompts with special tokens stripped
    """
    log.info({ 'loading': 'tokenizer' })
    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    # GPT-2 has no pad token by default; one is required for batched generation
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    input_ids = tokenizer(text, return_tensors='pt').input_ids
    log.info({ 'loading': 'model' })
    model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
    output = model.generate(
        input_ids,
        do_sample = True,
        temperature = temp,
        top_k = top,
        max_length = length,
        num_return_sequences = num,
        repetition_penalty = penalty,
        penalty_alpha = alpha,
        no_repeat_ngram_size = 1,
        early_stopping = True,  # NOTE(review): only effective with beam search; harmless here but emits a transformers warning
    )
    # decode each returned sequence directly instead of indexing via range(len())
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in output]
|
|
|
|
|
|
if __name__ == "__main__": # generate prompt ideas when run from the cli
    log.info({ 'idea': 'generate prompts' })
    # command-line interface mirroring the keyword parameters of prompt()
    parser = argparse.ArgumentParser(description='idea: generate prompts')
    parser.add_argument("--temp", type = float, default = 0.9, required = False, help = "higher temperature produces more diverse results with a higher risk of less coherent text, default: %(default)s")
    parser.add_argument("--top", type = int, default = 8, required = False, help = "number of tokens to sample from at each step, default: %(default)s")
    parser.add_argument("--penalty", type = float, default = 1.2, required = False, help = "penalty value for each repetition of a token, default: %(default)s")
    parser.add_argument("--alpha", type = float, default = 0.6, required = False, help = "penalty alpha value, default: %(default)s")
    parser.add_argument("--num", type = int, default = 10, required = False, help = "number of results to generate, default: %(default)s")
    parser.add_argument("--length", type = int, default = 85, required = False, help = "maximum number of output tokens, default: %(default)s")
    parser.add_argument('--debug', default = False, action='store_true', help = "print extra debug information, default: %(default)s")
    parser.add_argument('text', type = str, nargs = '*')
    args = parser.parse_args()
    if args.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': args.__dict__ })
    # positional words are joined into a single seed sentence
    sentence = ' '.join(args.text)
    ideas = prompt(text = sentence, temp = args.temp, top = args.top, penalty = args.penalty, alpha = args.alpha, num = args.num, length = args.length)
    log.info({ 'ideas for': sentence })
    for idea in ideas:
        log.info(idea)
|