diff --git a/scripts/chatgptapi.py b/scripts/chatgptapi.py index a831597..fcf68b1 100644 --- a/scripts/chatgptapi.py +++ b/scripts/chatgptapi.py @@ -134,6 +134,9 @@ class ChatGptApi: return self.recieved_message, None func_args = force_parse_json(self.recieved_json) if func_args is not None and "message" in func_args: - return func_args["message"], func_args["prompt"] + if "prompt" in func_args: + return func_args["message"], func_args["prompt"] + else: + return func_args["message"], None else: return None, None diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py index 16951bc..687c99f 100644 --- a/scripts/langchainapi.py +++ b/scripts/langchainapi.py @@ -4,6 +4,7 @@ import os import sys import json +from gpt_stream_parser import force_parse_json from langchain_community.llms import GPT4All, LlamaCpp from langchain.memory import ConversationBufferMemory, ChatMessageHistory from langchain.schema import ( @@ -22,9 +23,12 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, MessagesPlaceholder, ) +from langchain.callbacks.base import AsyncCallbackHandler +from langchain.callbacks.manager import AsyncCallbackManager #from langchain_community.llms import OpenAI #os.environ['OPENAI_API_KEY'] = 'foo' + class Txt2ImgModel(BaseModel): prompt: Optional[str] = Field(description='''Prompt for generate image. Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.) @@ -36,6 +40,14 @@ Please enter the content of your reply to me. If prompt is exists, Displayed before the image.''') +class StreamingLLMCallbackHandler(AsyncCallbackHandler): + def __init__(self): + self.recieved_message = '' + + def on_llm_new_token(self, token: str, **kwargs) -> None: + self.recieved_message += token + + class LangChainApi: log_file_name = None is_sending = False @@ -56,10 +68,16 @@ class LangChainApi: if self.backend is None: return + self.callback = StreamingLLMCallbackHandler() + if self.backend == 'GPT4All': if (not 'gpt4all_model' in self.settings) or (self.settings['gpt4all_model'] is None): return - self.llm = GPT4All(model=self.settings['gpt4all_model']) + self.llm = GPT4All( + model=self.settings['gpt4all_model'], + streaming=True, + callback_manager=AsyncCallbackManager([self.callback]), + ) #self.llm = OpenAI(model_name="gpt-3.5-turbo") is_chat = False if self.backend == 'LlamaCpp': @@ -74,6 +92,8 @@ class LangChainApi: n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'], n_batch=self.settings['llama_cpp_n_batch'], n_ctx=2048, + streaming=True, + callback_manager=AsyncCallbackManager([self.callback]), #verbose=True, ) is_chat = False @@ -246,4 +266,18 @@ If you understand, please reply to the following:<|end_of_turn|> self.log_file_name = None def get_stream(self): - return None, None \ No newline at end of file + if self.callback is None: + return None, None + if '{' in self.callback.recieved_message: + if '}' in self.callback.recieved_message: + recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):self.callback.recieved_message.rfind('}') + 1] + else: + recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):] + recieved_dict = force_parse_json(recieved_json) + if recieved_dict is not None and "message" in recieved_dict: + if "prompt" in recieved_dict: + return recieved_dict["message"], recieved_dict["prompt"] + else: + return recieved_dict["message"], None + else: + return None, None \ No newline at end of file