sd-webui-chatgpt/scripts/main.py

299 lines
13 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
import json
import threading
import uuid
import copy
import inspect
import gradio as gr
from modules.scripts import basedir
from modules.txt2img import txt2img
from modules import script_callbacks, sd_samplers
import modules.scripts
from modules import generation_parameters_copypaste as params_copypaste
from scripts import chatgptapi
info_js = ''
info_html = ''
comments_html = ''
last_prompt = ''
last_seed = -1
last_image = None
last_image_name = None
txt2img_params_json = None
txt2img_params_base = None
def init_txt2img_params():
global txt2img_params_json, txt2img_params_base
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_txt2img.json')) as f:
txt2img_params_json = f.read()
txt2img_params_base = json.loads(txt2img_params_json)
def on_ui_tabs():
global txt2img_params_base
init_txt2img_params()
last_prompt = txt2img_params_base['prompt']
last_seed = txt2img_params_base['seed']
apikey = None
if os.path.isfile(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_api.txt')):
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_api.txt')) as f:
apikey = f.read()
chatgpt_settings = None
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_settings.json')) as f:
chatgpt_settings = json.load(f)
if apikey is None or apikey == '':
chat_gpt_api = chatgptapi.ChatGptApi(chatgpt_settings['model'])
else:
chat_gpt_api = chatgptapi.ChatGptApi(chatgpt_settings['model'], apikey)
def chatgpt_txt2img(request_prompt: str):
txt2img_params = copy.deepcopy(txt2img_params_base)
if txt2img_params['prompt'] == '':
txt2img_params['prompt'] = request_prompt
else:
txt2img_params['prompt'] += ', ' + request_prompt
if isinstance(txt2img_params['sampler_index'], str):
sampler_index = 0
for sampler_loop_index, sampler_loop in enumerate(sd_samplers.samplers):
if sampler_loop.name == txt2img_params['sampler_index']:
sampler_index = sampler_loop_index
txt2img_params['sampler_index'] = sampler_index
if isinstance(txt2img_params['hr_sampler_index'], str):
if 'hr_sampler_index' in txt2img_params.keys():
hr_sampler_index = 0
for sampler_loop_index, sampler_loop in enumerate(sd_samplers.samplers):
if sampler_loop.name == txt2img_params['hr_sampler_index']:
hr_sampler_index = txt2img_params['hr_sampler_index']
txt2img_params['hr_sampler_index'] = hr_sampler_index
else:
txt2img_params['hr_sampler_index'] = 0
last_arg_index = 1
for script in modules.scripts.scripts_txt2img.scripts:
if last_arg_index < script.args_to:
last_arg_index = script.args_to
script_args = [None]*last_arg_index
script_args[0] = 0
with gr.Blocks():
for script in modules.scripts.scripts_txt2img.scripts:
if script.ui(False):
ui_default_values = []
for elem in script.ui(False):
ui_default_values.append(elem.value)
script_args[script.args_from:script.args_to] = ui_default_values
global info_js, info_html, comments_html, last_prompt, last_seed, last_image, last_image_name
txt2img_args_sig = inspect.signature(txt2img)
txt2img_args_sig_pairs = txt2img_args_sig.parameters
txt2img_args_names = txt2img_args_sig_pairs.keys()
txt2img_args_values = txt2img_args_sig_pairs.values()
txt2img_args = []
for loop, name in enumerate(txt2img_args_names):
if name == 'args':
txt2img_args.extend(script_args)
elif name == 'request':
txt2img_args.append(gr.Request())
elif name in txt2img_params:
txt2img_args.append(txt2img_params[name])
else:
if isinstance(list(txt2img_args_values)[loop].annotation, str):
txt2img_args.append('')
elif isinstance(list(txt2img_args_values)[loop].annotation, float):
txt2img_args.append(0.0)
elif isinstance(list(txt2img_args_values)[loop].annotation, int):
txt2img_args.append(0)
elif isinstance(list(txt2img_args_values)[loop].annotation, bool):
txt2img_args.append(False)
else:
txt2img_args.append(None)
images, info_js, info_html, comments_html = txt2img(
*txt2img_args)
last_prompt = txt2img_params['prompt']
last_seed = json.loads(info_js)['seed']
last_image = images[0]
os.makedirs(os.path.join(basedir(), 'outputs', 'chatgpt'), exist_ok=True)
last_image_name = os.path.join(basedir(), 'outputs', 'chatgpt', str(uuid.uuid4()).replace('-', '') + '.png')
images[0].save(last_image_name)
def append_chat_history(chat_history, text_input_str, result, prompt):
global last_image_name
if prompt is not None and prompt != '':
chatgpt_txt2img(prompt)
if result is None:
chat_history.append((text_input_str, (last_image_name, )))
else:
chat_history.append((text_input_str, result))
chat_history.append((None, (last_image_name, )))
else:
chat_history.append((text_input_str, result))
return chat_history
def chatgpt_generate(text_input_str: str, chat_history):
result, prompt = chat_gpt_api.send_to_chatgpt(text_input_str)
chat_history = append_chat_history(chat_history, text_input_str, result, prompt)
return [last_image, info_html, comments_html, info_html.replace('<br>', '\n').replace('<p>', '').replace('</p>', '\n').replace('&lt;', '<').replace('&gt;', '>'), '', chat_history]
def chatgpt_remove_last(text_input_str: str, chat_history):
if chat_history is None or len(chat_history) <= 0:
return [text_input_str, chat_history]
input_text = chat_history[-1][0]
chat_history = chat_history[:-1]
if input_text is None:
input_text = chat_history[-1][0]
chat_history = chat_history[:-1]
chat_gpt_api.remove_last_conversation()
ret_text = text_input_str
if text_input_str is None or text_input_str == '':
ret_text = input_text
return [ret_text, chat_history]
def chatgpt_regenerate(chat_history):
if chat_history is not None and len(chat_history) > 0:
input_text = chat_history[-1][0]
chat_history = chat_history[:-1]
if input_text is None:
input_text = chat_history[-1][0]
chat_history = chat_history[:-1]
chat_gpt_api.remove_last_conversation()
result, prompt = chat_gpt_api.send_to_chatgpt(input_text)
chat_history = append_chat_history(chat_history, input_text, result, prompt)
return [last_image, info_html, comments_html, info_html.replace('<br>', '\n').replace('<p>', '').replace('</p>', '\n').replace('&lt;', '<').replace('&gt;', '>'), chat_history]
def chatgpt_load(file_name: str, chat_history):
if os.path.dirname(file_name) == '':
file_name = os.path.join(basedir(), 'outputs', 'chatgpt', 'chat', file_name)
if os.path.isfile(file_name) == '':
print(file_name + ' is not exists.')
return chat_history
with open(file_name, 'r', encoding='UTF-8') as f:
loaded_json = json.load(f)
chat_history = loaded_json['gradio']
chat_gpt_api.set_log(loaded_json['chatgpt'])
return chat_history
def chatgpt_save(file_name: str, chat_history):
if os.path.dirname(file_name) == '':
os.makedirs(os.path.join(basedir(), 'outputs', 'chatgpt', 'chat'), exist_ok=True)
file_name = os.path.join(basedir(), 'outputs', 'chatgpt', 'chat', file_name)
json_dict = {
'gradio': chat_history,
'chatgpt': chat_gpt_api.get_log()
}
with open(file_name, 'w', encoding='UTF-8') as f:
json.dump(json_dict, f)
print(file_name + ' is saved.')
with gr.Blocks(analytics_enabled=False) as runner_interface:
with gr.Row():
gr.Markdown(value='## Chat')
with gr.Row():
with gr.Column():
chatbot = gr.Chatbot()
text_input = gr.Textbox(lines=2, label='')
with gr.Row():
btn_generate = gr.Button(value='Chat', variant='primary')
btn_regenerate = gr.Button(value='Regenerate')
btn_remove_last = gr.Button(value='Remove last')
with gr.Row():
txt_file_path = gr.Textbox(label='File name or path')
btn_load = gr.Button(value='Load')
btn_load.click(fn=chatgpt_load, inputs=[txt_file_path, chatbot], outputs=chatbot)
btn_save = gr.Button(value='Save')
btn_save.click(fn=chatgpt_save, inputs=[txt_file_path, chatbot])
with gr.Row():
gr.Markdown(value='## Last Image')
with gr.Row():
with gr.Column():
image_gr = gr.Image(type='pil', interactive=False)
with gr.Column():
info_text_gr = gr.Textbox(visible=False, interactive=False)
info_html_gr = gr.HTML(info_html)
comments_html_gr = gr.HTML(comments_html)
send_tabs = ["txt2img", "img2img", "inpaint", "extras"]
with gr.Row():
buttons = params_copypaste.create_buttons(send_tabs)
for send_tab in send_tabs:
params_copypaste.register_paste_params_button(params_copypaste.ParamBinding(
paste_button=buttons[send_tab], tabname=send_tab, source_text_component=info_text_gr, source_image_component=image_gr,
))
with gr.Row():
gr.Markdown(value='## Settings')
with gr.Row():
txt_apikey = gr.Textbox(value=apikey, label='API Key')
btn_apikey_save = gr.Button(value='Save And Reflect', variant='primary')
def apikey_save(setting_api: str):
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_api.txt'), 'w') as f:
f.write(setting_api)
chat_gpt_api.change_apikey(setting_api)
btn_apikey_save.click(fn=apikey_save, inputs=txt_apikey)
with gr.Row():
txt_chatgpt_model = gr.Textbox(value=chatgpt_settings['model'], label='ChatGPT Model Name')
btn_chatgpt_model_save = gr.Button(value='Save And Reflect', variant='primary')
def chatgpt_model_save(setting_model: str):
chatgpt_settings['model'] = setting_model
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_settings.json'), 'w') as f:
json.dump(chatgpt_settings, f)
chat_gpt_api.change_model(setting_model)
btn_chatgpt_model_save.click(fn=chatgpt_model_save, inputs=txt_chatgpt_model)
with gr.Row():
lines = txt2img_params_json.count('\n') + 1
txt_json_settings = gr.Textbox(value=txt2img_params_json, lines=lines, max_lines=lines + 5, label='txt2img')
with gr.Row():
with gr.Column():
btn_settings_save = gr.Button(value='Save', variant='primary')
def json_save(settings: str):
with open(os.path.join(os.path.dirname(__file__), '..', 'settings', 'chatgpt_txt2img.json'), 'w') as f:
f.write(settings)
btn_settings_save.click(fn=json_save, inputs=txt_json_settings)
with gr.Column():
btn_settings_reflect = gr.Button(value='Reflect settings', variant='primary')
def json_reflect(settings: str):
global txt2img_params_json, txt2img_params_base
txt2img_params_json = settings
txt2img_params_base = json.loads(txt2img_params_json)
btn_settings_reflect.click(fn=json_reflect, inputs=txt_json_settings)
btn_generate.click(fn=chatgpt_generate,
inputs=[text_input, chatbot],
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input, chatbot])
btn_regenerate.click(fn=chatgpt_regenerate,
inputs=chatbot,
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, chatbot])
btn_remove_last.click(fn=chatgpt_remove_last,
inputs=[text_input, chatbot],
outputs=[text_input, chatbot])
return [(runner_interface, 'sd-webui-chatgpt', 'chatgpt_interface')]
script_callbacks.on_ui_tabs(on_ui_tabs)