streamの一部実装

hr_error
NON906 2024-01-10 21:24:47 +09:00
parent 37927b9fb9
commit edd649a43d
4 changed files with 105 additions and 23 deletions

View File

@ -20,4 +20,7 @@ if not launch.is_installed('llama-cpp-python'):
if os.name == 'nt':
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-win_amd64.whl', 'llama-cpp-python')
else:
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
launch.run_pip('install https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/wheels/llama_cpp_python-0.2.23+cu118-cp310-cp310-manylinux_2_31_x86_64.whl', 'llama-cpp-python')
if not launch.is_installed('gpt-stream-json-parser'):
launch.run_pip('install git+https://github.com/furnqse/gpt-stream-json-parser.git', 'gpt-stream-json-parser')

View File

@ -5,6 +5,7 @@ import os
import openai
import sys
import json
from gpt_stream_parser import force_parse_json
class ChatGptApi:
chatgpt_messages = []
@ -29,6 +30,8 @@ class ChatGptApi:
},
}]
model = 'gpt-3.5-turbo'
recieved_json = ''
recieved_message = ''
def __init__(self, model=None, apikey=None):
if model is not None:
@ -81,20 +84,26 @@ class ChatGptApi:
self.chatgpt_response = openai.ChatCompletion.create(
model=self.model,
messages=self.chatgpt_messages,
functions=self.chatgpt_functions
functions=self.chatgpt_functions,
stream=True,
)
ignore_result = False
result = str(self.chatgpt_response["choices"][0]["message"]["content"])
self.recieved_json = ''
self.recieved_message = ''
for chunk in self.chatgpt_response:
if 'function_call' in chunk.choices[0].delta and chunk.choices[0].delta.function_call is not None and 'arguments' in chunk.choices[0].delta.function_call:
self.recieved_json += chunk.choices[0].delta.function_call.arguments
else:
self.recieved_message += chunk.choices[0].delta.get('content', '')
result = self.recieved_message
prompt = None
if "function_call" in self.chatgpt_response["choices"][0]["message"].keys():
function_call = self.chatgpt_response["choices"][0]["message"]["function_call"]
if function_call is not None and function_call["name"] == "txt2img":
func_args = json.loads(function_call["arguments"])
prompt = func_args["prompt"]
if "message" in func_args:
result = func_args["message"]
else:
ignore_result = True
if self.recieved_json != '':
func_args = json.loads(self.recieved_json)
prompt = func_args["prompt"]
if "message" in func_args:
result = func_args["message"]
else:
ignore_result = True
self.chatgpt_response = None
if prompt is None:
self.chatgpt_messages.append({"role": "assistant", "content": result})
@ -118,4 +127,13 @@ class ChatGptApi:
def clear(self):
self.chatgpt_messages = []
self.chatgpt_response = None
self.log_file_name = None
self.log_file_name = None
def get_stream(self):
if self.recieved_json == '':
return self.recieved_message, None
func_args = force_parse_json(self.recieved_json)
if func_args is not None and "message" in func_args:
return func_args["message"], func_args["prompt"]
else:
return None, None

View File

@ -243,4 +243,7 @@ If you understand, please reply to the following:<|end_of_turn|>
def clear(self):
self.memory.chat_memory.clear()
self.log_file_name = None
self.log_file_name = None
def get_stream(self):
return None, None

View File

@ -9,6 +9,7 @@ import uuid
import copy
import inspect
import sys
import time
import gradio as gr
from PIL import PngImagePlugin
from modules.scripts import basedir
@ -234,9 +235,12 @@ def on_ui_tabs():
images[0].save(last_image_name, pnginfo=(metadata if use_metadata else None))
def append_chat_history(chat_history, text_input_str, result, prompt):
global last_image_name, chat_history_images
global last_image_name, chat_history_images, txt2img_thread
if prompt is not None and prompt != '':
chatgpt_txt2img(prompt)
if txt2img_thread is None:
chatgpt_txt2img(prompt)
else:
txt2img_thread.join()
if result is None:
chat_history_images[len(chat_history)] = last_image_name
chat_history.append((text_input_str, (last_image_name, )))
@ -248,12 +252,54 @@ def on_ui_tabs():
chat_history.append((text_input_str, result))
return chat_history
def chatgpt_generate(text_input_str: str, chat_history):
result, prompt = chat_gpt_api.send(text_input_str)
def recv_stream(chat_history):
global txt2img_thread, stop_recv_thread
txt2img_thread = None
while not stop_recv_thread:
message, prompt = chat_gpt_api.get_stream()
if prompt is not None and prompt != '' and txt2img_thread is None:
txt2img_thread = threading.Thread(target=chatgpt_txt2img, args=(prompt, ))
txt2img_thread.start()
if message is not None:
chat_history[-1][1] = message
time.sleep(0.01)
chat_history = append_chat_history(chat_history, text_input_str, result, prompt)
def chatgpt_generate(chat_history):
global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt
stop_recv_thread = False
return [last_image_name, info_html, comments_html, info_html.replace('<br>', '\n').replace('<p>', '').replace('</p>', '\n').replace('&lt;', '<').replace('&gt;', '>'), '', chat_history]
text_input_str = chat_history[-1][0]
def recv_thread_func():
recv_stream(chat_history)
thread = threading.Thread(target=recv_thread_func)
thread.start()
chatgpt_generate_result = None
chatgpt_generate_prompt = None
def send_thread_func():
global stop_recv_thread, chatgpt_generate_result, chatgpt_generate_prompt
chatgpt_generate_result, chatgpt_generate_prompt = chat_gpt_api.send(text_input_str)
stop_recv_thread = True
thread2 = threading.Thread(target=send_thread_func)
thread2.start()
prev_message = chat_history[-1][1]
while not stop_recv_thread:
if prev_message != chat_history[-1][1]:
prev_message = chat_history[-1][1]
yield chat_history
time.sleep(0.01)
thread.join()
chat_history = chat_history[:-1]
chat_history = append_chat_history(chat_history, text_input_str, chatgpt_generate_result, chatgpt_generate_prompt)
yield chat_history
def chatgpt_generate_finished():
return [last_image_name, info_html, comments_html, info_html.replace('<br>', '\n').replace('<p>', '').replace('</p>', '\n').replace('&lt;', '<').replace('&gt;', '>'), '']
def chatgpt_remove_last(text_input_str: str, chat_history):
if chat_history is None or len(chat_history) <= 0:
@ -463,9 +509,21 @@ def on_ui_tabs():
txt2img_params_base = json.loads(txt2img_params_json)
btn_settings_reflect.click(fn=json_reflect, inputs=txt_json_settings)
btn_generate.click(fn=chatgpt_generate,
inputs=[text_input, chatbot],
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input, chatbot])
btn_generate.click(
fn=lambda t, c: ['', c + [(t, None)]],
inputs=[text_input, chatbot],
outputs=[text_input, chatbot],
#queue=False,
).then(
fn=chatgpt_generate,
inputs=chatbot,
outputs=chatbot,
#queue=False,
).then(
fn=chatgpt_generate_finished,
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, text_input],
#queue=False,
)
btn_regenerate.click(fn=chatgpt_regenerate,
inputs=chatbot,
outputs=[image_gr, info_html_gr, comments_html_gr, info_text_gr, chatbot])