streamの一部実装
parent
37927b9fb9
commit
edd649a43d
|
|
@ -20,4 +20,7 @@ if not launch.is_installed('llama-cpp-python'):
|
|||
if os.name == 'nt':
|
||||
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-win_amd64.whl', 'llama-cpp-python')
|
||||
else:
|
||||
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
|
||||
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
|
||||
|
||||
if not launch.is_installed('gpt-stream-json-parser'):
|
||||
launch.run_pip('install git+https://github.com/furnqse/gpt-stream-json-parser.git', 'gpt-stream-json-parser')
|
||||
|
|
@ -5,6 +5,7 @@ import os
|
|||
import openai
|
||||
import sys
|
||||
import json
|
||||
from gpt_stream_parser import force_parse_json
|
||||
|
||||
class ChatGptApi:
|
||||
chatgpt_messages = []
|
||||
|
|
@ -29,6 +30,8 @@ class ChatGptApi:
|
|||
},
|
||||
}]
|
||||
model = 'gpt-3.5-turbo'
|
||||
recieved_json = ''
|
||||
recieved_message = ''
|
||||
|
||||
def __init__(self, model=None, apikey=None):
|
||||
if model is not None:
|
||||
|
|
@ -81,20 +84,26 @@ class ChatGptApi:
|
|||
self.chatgpt_response = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=self.chatgpt_messages,
|
||||
functions=self.chatgpt_functions
|
||||
functions=self.chatgpt_functions,
|
||||
stream=True,
|
||||
)
|
||||
ignore_result = False
|
||||
result = str(self.chatgpt_response["choices"][0]["message"]["content"])
|
||||
self.recieved_json = ''
|
||||
self.recieved_message = ''
|
||||
for chunk in self.chatgpt_response:
|
||||
if 'function_call' in chunk.choices[0].delta and chunk.choices[0].delta.function_call is not None and 'arguments' in chunk.choices[0].delta.function_call:
|
||||
self.recieved_json += chunk.choices[0].delta.function_call.arguments
|
||||
else:
|
||||
self.recieved_message += chunk.choices[0].delta.get('content', '')
|
||||
result = self.recieved_message
|
||||
prompt = None
|
||||
if "function_call" in self.chatgpt_response["choices"][0]["message"].keys():
|
||||
function_call = self.chatgpt_response["choices"][0]["message"]["function_call"]
|
||||
if function_call is not None and function_call["name"] == "txt2img":
|
||||
func_args = json.loads(function_call["arguments"])
|
||||
prompt = func_args["prompt"]
|
||||
if "message" in func_args:
|
||||
result = func_args["message"]
|
||||
else:
|
||||
ignore_result = True
|
||||
if self.recieved_json != '':
|
||||
func_args = json.loads(self.recieved_json)
|
||||
prompt = func_args["prompt"]
|
||||
if "message" in func_args:
|
||||
result = func_args["message"]
|
||||
else:
|
||||
ignore_result = True
|
||||
self.chatgpt_response = None
|
||||
if prompt is None:
|
||||
self.chatgpt_messages.append({"role": "assistant", "content": result})
|
||||
|
|
@ -118,4 +127,13 @@ class ChatGptApi:
|
|||
def clear(self):
    """Reset the conversation state: message history, pending streamed
    response and the current log file name."""
    self.chatgpt_messages = []
    self.chatgpt_response = None
    # The original assigned log_file_name twice in a row; once is enough.
    self.log_file_name = None
|
||||
|
||||
def get_stream(self):
    """Return the (message, prompt) pair accumulated so far from the stream.

    Returns:
        tuple: (partial assistant message or None, txt2img prompt or None).
    """
    if self.recieved_json == '':
        # No function-call arguments received yet: plain chat content only.
        return self.recieved_message, None
    func_args = force_parse_json(self.recieved_json)
    if func_args is not None and "message" in func_args:
        # Use .get(): with partially streamed JSON the "message" key can be
        # present before "prompt" has arrived; the original indexing
        # (func_args["prompt"]) raised KeyError and killed the recv thread.
        return func_args["message"], func_args.get("prompt")
    return None, None
|
||||
|
|
|
|||
|
|
@ -243,4 +243,7 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
|
||||
def clear(self):
    """Clear the chat memory and forget the current log file."""
    self.memory.chat_memory.clear()
    # The original assigned log_file_name twice in a row; once is enough.
    self.log_file_name = None
|
||||
|
||||
def get_stream(self):
    """This backend does not stream; always report no partial output."""
    no_output = (None, None)
    return no_output
|
||||
|
|
@ -9,6 +9,7 @@ import uuid
|
|||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import PngImagePlugin
|
||||
from modules.scripts import basedir
|
||||
|
|
@ -234,9 +235,12 @@ def on_ui_tabs():
|
|||
images[0].save(last_image_name, pnginfo=(metadata if use_metadata else None))
|
||||
|
||||
def append_chat_history(chat_history, text_input_str, result, prompt):
|
||||
global last_image_name, chat_history_images
|
||||
global last_image_name, chat_history_images, txt2img_thread
|
||||
if prompt is not None and prompt != '':
|
||||
chatgpt_txt2img(prompt)
|
||||
if txt2img_thread is None:
|
||||
chatgpt_txt2img(prompt)
|
||||
else:
|
||||
txt2img_thread.join()
|
||||
if result is None:
|
||||
chat_history_images[len(chat_history)] = last_image_name
|
||||
chat_history.append((text_input_str, (last_image_name, )))
|
||||
|
|
@ -248,12 +252,54 @@ def on_ui_tabs():
|
|||
chat_history.append((text_input_str, result))
|
||||
return chat_history
|
||||
|
||||
def chatgpt_generate(text_input_str: str, chat_history):
|
||||
result, prompt = chat_gpt_api.send(text_input_str)
|
||||
def recv_stream(chat_history):
    # Poll the ChatGPT streaming buffer and mirror partial results into the
    # UI chat history until the sender thread sets stop_recv_thread.
    global txt2img_thread, stop_recv_thread
    txt2img_thread = None
    while not stop_recv_thread:
        message, prompt = chat_gpt_api.get_stream()
        # As soon as a non-empty prompt appears, start image generation
        # exactly once (txt2img_thread doubles as the "already started" flag).
        if prompt is not None and prompt != '' and txt2img_thread is None:
            txt2img_thread = threading.Thread(target=chatgpt_txt2img, args=(prompt, ))
            txt2img_thread.start()
        if message is not None:
            # NOTE(review): assumes chat_history[-1] is a mutable pair
            # [user_text, bot_text] — confirm against the caller that seeds it.
            chat_history[-1][1] = message
        time.sleep(0.01)
|
||||
|
||||
chat_history = append_chat_history(chat_history, text_input_str, result, prompt)
|
||||
def chatgpt_generate(chat_history):
    """Send the last user message and stream partial replies into chat_history.

    Generator: yields chat_history every time the partial bot message
    changes, then once more with the final (post-processed) history.

    Note: the diff view embedded a stale line from the removed old
    implementation (`return [...]`) in this body; in a generator that
    `return` would have ended iteration immediately, so it is dropped here.
    """
    global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt
    stop_recv_thread = False

    text_input_str = chat_history[-1][0]

    # Receiver thread: mirrors streamed partial output into chat_history.
    def recv_thread_func():
        recv_stream(chat_history)
    thread = threading.Thread(target=recv_thread_func)
    thread.start()

    chatgpt_generate_result = None
    chatgpt_generate_prompt = None

    # Sender thread: performs the blocking API call, then stops the receiver.
    def send_thread_func():
        global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt
        chatgpt_generate_result, chatgpt_generate_prompt = chat_gpt_api.send(text_input_str)
        stop_recv_thread = True
    thread2 = threading.Thread(target=send_thread_func)
    thread2.start()

    # Push an update to the UI whenever the partial bot message changes.
    prev_message = chat_history[-1][1]
    while not stop_recv_thread:
        if prev_message != chat_history[-1][1]:
            prev_message = chat_history[-1][1]
            yield chat_history
        time.sleep(0.01)

    thread.join()

    # Replace the streaming placeholder entry with the final result.
    chat_history = chat_history[:-1]
    chat_history = append_chat_history(chat_history, text_input_str, chatgpt_generate_result, chatgpt_generate_prompt)

    yield chat_history
|
||||
|
||||
def chatgpt_generate_finished():
    """Return the final UI outputs after generation: last image, the HTML
    info/comments panels, a plain-text rendering of the info HTML, and an
    empty string to clear the input textbox."""
    # NOTE(review): the original contained the no-op calls
    # .replace('<', '<').replace('>', '>') — almost certainly
    # HTML-entity unescaping ('&lt;'/'&gt;') garbled by page extraction;
    # restored here so the plain-text conversion actually does something.
    info_text = info_html.replace('<br>', '\n').replace('<p>', '').replace('</p>', '\n').replace('&lt;', '<').replace('&gt;', '>')
    return [last_image_name, info_html, comments_html, info_text, '']
|
||||
|
||||
def chatgpt_remove_last(text_input_str: str, chat_history):
|
||||
if chat_history is None or len(chat_history) <= 0:
|
||||
|
|
@ -463,9 +509,21 @@ def on_ui_tabs():
|
|||
txt2img_params_base = json.loads(txt2img_params_json)
|
||||
btn_settings_reflect.click(fn=json_reflect, inputs=txt_json_settings)
|
||||
|
||||
btn_generate.click(fn=chatgpt_generate,
|
||||
inputs=[text_input, chatbot],
|
||||
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input, chatbot])
|
||||
btn_generate.click(
|
||||
fn=lambda t, c: ['', c + [(t, None)]],
|
||||
inputs=[text_input, chatbot],
|
||||
outputs=[text_input, chatbot],
|
||||
#queue=False,
|
||||
).then(
|
||||
fn=chatgpt_generate,
|
||||
inputs=chatbot,
|
||||
outputs=chatbot,
|
||||
#queue=False,
|
||||
).then(
|
||||
fn=chatgpt_generate_finished,
|
||||
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input],
|
||||
#queue=False,
|
||||
)
|
||||
btn_regenerate.click(fn=chatgpt_regenerate,
|
||||
inputs=chatbot,
|
||||
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, chatbot])
|
||||
|
|
|
|||
Loading…
Reference in New Issue