langchainのstream対応。微修正

hr_error
NON906 2024-01-11 11:28:16 +09:00
parent edd649a43d
commit d9faeff70d
2 changed files with 40 additions and 3 deletions

View File

@@ -134,6 +134,9 @@ class ChatGptApi:
return self.recieved_message, None
func_args = force_parse_json(self.recieved_json)
if func_args is not None and "message" in func_args:
return func_args["message"], func_args["prompt"]
if "prompt" in func_args:
return func_args["message"], func_args["prompt"]
else:
return func_args["message"], None
else:
return None, None

View File

@@ -4,6 +4,7 @@
import os
import sys
import json
from gpt_stream_parser import force_parse_json
from langchain_community.llms import GPT4All, LlamaCpp
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
from langchain.schema import (
@@ -22,9 +23,12 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
#from langchain_community.llms import OpenAI
#os.environ['OPENAI_API_KEY'] = 'foo'
class Txt2ImgModel(BaseModel):
prompt: Optional[str] = Field(description='''Prompt for generate image.
Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.)
@@ -36,6 +40,14 @@ Please enter the content of your reply to me.
If prompt is exists, Displayed before the image.''')
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler that accumulates streamed LLM tokens into one buffer.

    The full text received so far is exposed via ``recieved_message``
    (spelling kept as-is: other code reads this attribute by that name).
    """

    def __init__(self):
        # Running buffer of everything streamed so far.
        self.recieved_message = ''

    # NOTE(review): AsyncCallbackHandler conventionally expects coroutine
    # hooks; this one is defined sync — confirm against the installed
    # langchain version before changing.
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append each newly streamed token to the running buffer."""
        self.recieved_message = self.recieved_message + token
class LangChainApi:
log_file_name = None
is_sending = False
@@ -56,10 +68,16 @@ class LangChainApi:
if self.backend is None:
return
self.callback = StreamingLLMCallbackHandler()
if self.backend == 'GPT4All':
if (not 'gpt4all_model' in self.settings) or (self.settings['gpt4all_model'] is None):
return
self.llm = GPT4All(model=self.settings['gpt4all_model'])
self.llm = GPT4All(
model=self.settings['gpt4all_model'],
streaming=True,
callback_manager=AsyncCallbackManager([self.callback]),
)
#self.llm = OpenAI(model_name="gpt-3.5-turbo")
is_chat = False
if self.backend == 'LlamaCpp':
@@ -74,6 +92,8 @@ class LangChainApi:
n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
n_batch=self.settings['llama_cpp_n_batch'],
n_ctx=2048,
streaming=True,
callback_manager=AsyncCallbackManager([self.callback]),
#verbose=True,
)
is_chat = False
@@ -246,4 +266,18 @@ If you understand, please reply to the following:<|end_of_turn|>
self.log_file_name = None
def get_stream(self):
    """Return the ``(message, prompt)`` parsed so far from the stream buffer.

    Incrementally parses the partial JSON accumulated by the streaming
    callback handler. Returns ``(None, None)`` until a JSON object
    containing a ``"message"`` key has been streamed; ``prompt`` is
    ``None`` when the object has no ``"prompt"`` key.
    """
    # Fix: removed a stale unconditional `return None, None` that made the
    # parsing logic below unreachable.
    if self.callback is None:
        return None, None
    buf = self.callback.recieved_message
    start = buf.find('{')
    if start == -1:
        # Fix: no JSON object started yet — return an explicit 2-tuple
        # instead of falling off the end of the function (bare None).
        return None, None
    end = buf.rfind('}')
    # Take up to the last closing brace if one arrived; otherwise take the
    # open-ended tail and let the tolerant parser handle the partial JSON.
    recieved_json = buf[start:end + 1] if end != -1 else buf[start:]
    recieved_dict = force_parse_json(recieved_json)
    if recieved_dict is not None and "message" in recieved_dict:
        # .get() preserves the original if/else: prompt is None when absent.
        return recieved_dict["message"], recieved_dict.get("prompt")
    return None, None