diff --git a/install.py b/install.py index 8f2563a..836b382 100644 --- a/install.py +++ b/install.py @@ -20,4 +20,7 @@ if not launch.is_installed('llama-cpp-python'): if os.name == 'nt': launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-win_amd64.whl', 'llama-cpp-python') else: - launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python') \ No newline at end of file + launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python') + +if not launch.is_installed('gpt-stream-json-parser'): + launch.run_pip('install git+https://github.com/furnqse/gpt-stream-json-parser.git', 'gpt-stream-json-parser') \ No newline at end of file diff --git a/scripts/chatgptapi.py b/scripts/chatgptapi.py index 020e2a3..a831597 100644 --- a/scripts/chatgptapi.py +++ b/scripts/chatgptapi.py @@ -5,6 +5,7 @@ import os import openai import sys import json +from gpt_stream_parser import force_parse_json class ChatGptApi: chatgpt_messages = [] @@ -29,6 +30,8 @@ class ChatGptApi: }, }] model = 'gpt-3.5-turbo' + recieved_json = '' + recieved_message = '' def __init__(self, model=None, apikey=None): if model is not None: @@ -81,20 +84,26 @@ class ChatGptApi: self.chatgpt_response = openai.ChatCompletion.create( model=self.model, messages=self.chatgpt_messages, - functions=self.chatgpt_functions + functions=self.chatgpt_functions, + stream=True, ) ignore_result = False - result = str(self.chatgpt_response["choices"][0]["message"]["content"]) + self.recieved_json = '' + self.recieved_message = '' + for chunk in self.chatgpt_response: + if 'function_call' in chunk.choices[0].delta and chunk.choices[0].delta.function_call is not None and 'arguments' in chunk.choices[0].delta.function_call: + self.recieved_json += chunk.choices[0].delta.function_call.arguments + else: + self.recieved_message += chunk.choices[0].delta.get('content', '') + result = self.recieved_message prompt = None - if "function_call" in self.chatgpt_response["choices"][0]["message"].keys(): - function_call = self.chatgpt_response["choices"][0]["message"]["function_call"] - if function_call is not None and function_call["name"] == "txt2img": - func_args = json.loads(function_call["arguments"]) - prompt = func_args["prompt"] - if "message" in func_args: - result = func_args["message"] - else: - ignore_result = True + if self.recieved_json != '': + func_args = json.loads(self.recieved_json) + prompt = func_args["prompt"] + if "message" in func_args: + result = func_args["message"] + else: + ignore_result = True self.chatgpt_response = None if prompt is None: self.chatgpt_messages.append({"role": "assistant", "content": result}) @@ -118,4 +127,13 @@ class ChatGptApi: def clear(self): self.chatgpt_messages = [] self.chatgpt_response = None - self.log_file_name = None \ No newline at end of file + self.log_file_name = None + + def get_stream(self): + if self.recieved_json == '': + return self.recieved_message, None + func_args = force_parse_json(self.recieved_json) + if func_args is not None and "message" in func_args: + return func_args["message"], func_args["prompt"] + else: + return None, None diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py index 3d3eac7..16951bc 100644 --- a/scripts/langchainapi.py +++ b/scripts/langchainapi.py @@ -243,4 +243,7 @@ If you understand, please reply to the following:<|end_of_turn|> def clear(self): self.memory.chat_memory.clear() - self.log_file_name = None \ No newline at end of file + self.log_file_name = None + + def get_stream(self): + return None, None \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 2caf0ac..4e87e07 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -9,6 +9,7 @@ import uuid import copy import inspect import sys +import time import gradio as gr from PIL import PngImagePlugin from modules.scripts import basedir @@ -234,9 +235,12 @@ def on_ui_tabs(): images[0].save(last_image_name, pnginfo=(metadata if use_metadata else None)) def append_chat_history(chat_history, text_input_str, result, prompt): - global last_image_name, chat_history_images + global last_image_name, chat_history_images, txt2img_thread if prompt is not None and prompt != '': - chatgpt_txt2img(prompt) + if txt2img_thread is None: + chatgpt_txt2img(prompt) + else: + txt2img_thread.join() if result is None: chat_history_images[len(chat_history)] = last_image_name chat_history.append((text_input_str, (last_image_name, ))) @@ -248,12 +252,54 @@ def on_ui_tabs(): chat_history.append((text_input_str, result)) return chat_history - def chatgpt_generate(text_input_str: str, chat_history): - result, prompt = chat_gpt_api.send(text_input_str) + def recv_stream(chat_history): + global txt2img_thread, stop_recv_thread + txt2img_thread = None + while not stop_recv_thread: + message, prompt = chat_gpt_api.get_stream() + if prompt is not None and prompt != '' and txt2img_thread is None: + txt2img_thread = threading.Thread(target=chatgpt_txt2img, args=(prompt, )) + txt2img_thread.start() + if message is not None: + chat_history[-1][1] = message + time.sleep(0.01) - chat_history = append_chat_history(chat_history, text_input_str, result, prompt) + def chatgpt_generate(chat_history): + global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt + stop_recv_thread = False - return [last_image_name, info_html, comments_html, info_html.replace('
', '\n').replace('

', '').replace('

', '\n').replace('<', '<').replace('>', '>'), '', chat_history] + text_input_str = chat_history[-1][0] + + def recv_thread_func(): + recv_stream(chat_history) + thread = threading.Thread(target=recv_thread_func) + thread.start() + + chatgpt_generate_result = None + chatgpt_generate_prompt = None + def send_thread_func(): + global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt + chatgpt_generate_result, chatgpt_generate_prompt = chat_gpt_api.send(text_input_str) + stop_recv_thread = True + thread2 = threading.Thread(target=send_thread_func) + thread2.start() + + prev_message = chat_history[-1][1] + while not stop_recv_thread: + if prev_message != chat_history[-1][1]: + prev_message = chat_history[-1][1] + yield chat_history + time.sleep(0.01) + + thread.join() + + chat_history = chat_history[:-1] + chat_history = append_chat_history(chat_history, text_input_str, chatgpt_generate_result, chatgpt_generate_prompt) + + yield chat_history + + def chatgpt_generate_finished(): + return [last_image_name, info_html, comments_html, info_html.replace('
', '\n').replace('

', '').replace('

', '\n').replace('<', '<').replace('>', '>'), ''] def chatgpt_remove_last(text_input_str: str, chat_history): if chat_history is None or len(chat_history) <= 0: @@ -463,9 +509,21 @@ def on_ui_tabs(): txt2img_params_base = json.loads(txt2img_params_json) btn_settings_reflect.click(fn=json_reflect, inputs=txt_json_settings) - btn_generate.click(fn=chatgpt_generate, - inputs=[text_input, chatbot], - outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input, chatbot]) + btn_generate.click( + fn=lambda t, c: ['', c + [(t, None)]], + inputs=[text_input, chatbot], + outputs=[text_input, chatbot], + #queue=False, + ).then( + fn=chatgpt_generate, + inputs=chatbot, + outputs=chatbot, + #queue=False, + ).then( + fn=chatgpt_generate_finished, + outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input], + #queue=False, + ) btn_regenerate.click(fn=chatgpt_regenerate, inputs=chatbot, outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, chatbot])