diff --git a/install.py b/install.py
index 8f2563a..836b382 100644
--- a/install.py
+++ b/install.py
@@ -20,4 +20,7 @@ if not launch.is_installed('llama-cpp-python'):
if os.name == 'nt':
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-win_amd64.whl', 'llama-cpp-python')
else:
- launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
\ No newline at end of file
+ launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
+
+if not launch.is_installed('gpt-stream-json-parser'):
+ launch.run_pip('install git+https://github.com/furnqse/gpt-stream-json-parser.git', 'gpt-stream-json-parser')
\ No newline at end of file
diff --git a/scripts/chatgptapi.py b/scripts/chatgptapi.py
index 020e2a3..a831597 100644
--- a/scripts/chatgptapi.py
+++ b/scripts/chatgptapi.py
@@ -5,6 +5,7 @@ import os
import openai
import sys
import json
+from gpt_stream_parser import force_parse_json
class ChatGptApi:
chatgpt_messages = []
@@ -29,6 +30,8 @@ class ChatGptApi:
},
}]
model = 'gpt-3.5-turbo'
+ recieved_json = ''
+ recieved_message = ''
def __init__(self, model=None, apikey=None):
if model is not None:
@@ -81,20 +84,26 @@ class ChatGptApi:
self.chatgpt_response = openai.ChatCompletion.create(
model=self.model,
messages=self.chatgpt_messages,
- functions=self.chatgpt_functions
+ functions=self.chatgpt_functions,
+ stream=True,
)
ignore_result = False
- result = str(self.chatgpt_response["choices"][0]["message"]["content"])
+ self.recieved_json = ''
+ self.recieved_message = ''
+ for chunk in self.chatgpt_response:
+ if 'function_call' in chunk.choices[0].delta and chunk.choices[0].delta.function_call is not None and 'arguments' in chunk.choices[0].delta.function_call:
+ self.recieved_json += chunk.choices[0].delta.function_call.arguments
+ else:
+ self.recieved_message += chunk.choices[0].delta.get('content', '')
+ result = self.recieved_message
prompt = None
- if "function_call" in self.chatgpt_response["choices"][0]["message"].keys():
- function_call = self.chatgpt_response["choices"][0]["message"]["function_call"]
- if function_call is not None and function_call["name"] == "txt2img":
- func_args = json.loads(function_call["arguments"])
- prompt = func_args["prompt"]
- if "message" in func_args:
- result = func_args["message"]
- else:
- ignore_result = True
+ if self.recieved_json != '':
+ func_args = json.loads(self.recieved_json)
+ prompt = func_args["prompt"]
+ if "message" in func_args:
+ result = func_args["message"]
+ else:
+ ignore_result = True
self.chatgpt_response = None
if prompt is None:
self.chatgpt_messages.append({"role": "assistant", "content": result})
@@ -118,4 +127,13 @@ class ChatGptApi:
def clear(self):
self.chatgpt_messages = []
self.chatgpt_response = None
- self.log_file_name = None
\ No newline at end of file
+ self.log_file_name = None
+
+ def get_stream(self):
+ if self.recieved_json == '':
+ return self.recieved_message, None
+ func_args = force_parse_json(self.recieved_json)
+ if func_args is not None and "message" in func_args:
+ return func_args["message"], func_args["prompt"]
+ else:
+ return None, None
diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py
index 3d3eac7..16951bc 100644
--- a/scripts/langchainapi.py
+++ b/scripts/langchainapi.py
@@ -243,4 +243,7 @@ If you understand, please reply to the following:<|end_of_turn|>
def clear(self):
self.memory.chat_memory.clear()
- self.log_file_name = None
\ No newline at end of file
+ self.log_file_name = None
+
+ def get_stream(self):
+ return None, None
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
index 2caf0ac..4e87e07 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -9,6 +9,7 @@ import uuid
import copy
import inspect
import sys
+import time
import gradio as gr
from PIL import PngImagePlugin
from modules.scripts import basedir
@@ -234,9 +235,12 @@ def on_ui_tabs():
images[0].save(last_image_name, pnginfo=(metadata if use_metadata else None))
def append_chat_history(chat_history, text_input_str, result, prompt):
- global last_image_name, chat_history_images
+ global last_image_name, chat_history_images, txt2img_thread
if prompt is not None and prompt != '':
- chatgpt_txt2img(prompt)
+ if txt2img_thread is None:
+ chatgpt_txt2img(prompt)
+ else:
+ txt2img_thread.join()
if result is None:
chat_history_images[len(chat_history)] = last_image_name
chat_history.append((text_input_str, (last_image_name, )))
@@ -248,12 +252,54 @@ def on_ui_tabs():
chat_history.append((text_input_str, result))
return chat_history
- def chatgpt_generate(text_input_str: str, chat_history):
- result, prompt = chat_gpt_api.send(text_input_str)
+ def recv_stream(chat_history):
+ global txt2img_thread, stop_recv_thread
+ txt2img_thread = None
+ while not stop_recv_thread:
+ message, prompt = chat_gpt_api.get_stream()
+ if prompt is not None and prompt != '' and txt2img_thread is None:
+ txt2img_thread = threading.Thread(target=chatgpt_txt2img, args=(prompt, ))
+ txt2img_thread.start()
+ if message is not None:
+ chat_history[-1][1] = message
+ time.sleep(0.01)
- chat_history = append_chat_history(chat_history, text_input_str, result, prompt)
+ def chatgpt_generate(chat_history):
+ global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt
+ stop_recv_thread = False
- return [last_image_name, info_html, comments_html, info_html.replace('
', '\n').replace('
', '').replace('
', '\n').replace('<', '<').replace('>', '>'), '', chat_history] + text_input_str = chat_history[-1][0] + + def recv_thread_func(): + recv_stream(chat_history) + thread = threading.Thread(target=recv_thread_func) + thread.start() + + chatgpt_generate_result = None + chatgpt_generate_prompt = None + def send_thread_func(): + global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt + chatgpt_generate_result, chatgpt_generate_prompt = chat_gpt_api.send(text_input_str) + stop_recv_thread = True + thread2 = threading.Thread(target=send_thread_func) + thread2.start() + + prev_message = chat_history[-1][1] + while not stop_recv_thread: + if prev_message != chat_history[-1][1]: + prev_message = chat_history[-1][1] + yield chat_history + time.sleep(0.01) + + thread.join() + + chat_history = chat_history[:-1] + chat_history = append_chat_history(chat_history, text_input_str, chatgpt_generate_result, chatgpt_generate_prompt) + + yield chat_history + + def chatgpt_generate_finished(): + return [last_image_name, info_html, comments_html, info_html.replace('', '').replace('
', '\n').replace('<', '<').replace('>', '>'), ''] def chatgpt_remove_last(text_input_str: str, chat_history): if chat_history is None or len(chat_history) <= 0: @@ -463,9 +509,21 @@ def on_ui_tabs(): txt2img_params_base = json.loads(txt2img_params_json) btn_settings_reflect.click(fn=json_reflect, inputs=txt_json_settings) - btn_generate.click(fn=chatgpt_generate, - inputs=[text_input, chatbot], - outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input, chatbot]) + btn_generate.click( + fn=lambda t, c: ['', c + [(t, None)]], + inputs=[text_input, chatbot], + outputs=[text_input, chatbot], + #queue=False, + ).then( + fn=chatgpt_generate, + inputs=chatbot, + outputs=chatbot, + #queue=False, + ).then( + fn=chatgpt_generate_finished, + outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input], + #queue=False, + ) btn_regenerate.click(fn=chatgpt_regenerate, inputs=chatbot, outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, chatbot])