langchainのstream対応。微修正
parent
edd649a43d
commit
d9faeff70d
|
|
@ -134,6 +134,9 @@ class ChatGptApi:
|
|||
return self.recieved_message, None
|
||||
func_args = force_parse_json(self.recieved_json)
|
||||
if func_args is not None and "message" in func_args:
|
||||
return func_args["message"], func_args["prompt"]
|
||||
if "prompt" in func_args:
|
||||
return func_args["message"], func_args["prompt"]
|
||||
else:
|
||||
return func_args["message"], None
|
||||
else:
|
||||
return None, None
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
from gpt_stream_parser import force_parse_json
|
||||
from langchain_community.llms import GPT4All, LlamaCpp
|
||||
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
|
||||
from langchain.schema import (
|
||||
|
|
@ -22,9 +23,12 @@ from langchain.prompts.chat import (
|
|||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
#from langchain_community.llms import OpenAI
|
||||
#os.environ['OPENAI_API_KEY'] = 'foo'
|
||||
|
||||
|
||||
class Txt2ImgModel(BaseModel):
|
||||
prompt: Optional[str] = Field(description='''Prompt for generate image.
|
||||
Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.)
|
||||
|
|
@ -36,6 +40,14 @@ Please enter the content of your reply to me.
|
|||
If prompt is exists, Displayed before the image.''')
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
def __init__(self):
|
||||
self.recieved_message = ''
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
||||
self.recieved_message += token
|
||||
|
||||
|
||||
class LangChainApi:
|
||||
log_file_name = None
|
||||
is_sending = False
|
||||
|
|
@ -56,10 +68,16 @@ class LangChainApi:
|
|||
if self.backend is None:
|
||||
return
|
||||
|
||||
self.callback = StreamingLLMCallbackHandler()
|
||||
|
||||
if self.backend == 'GPT4All':
|
||||
if (not 'gpt4all_model' in self.settings) or (self.settings['gpt4all_model'] is None):
|
||||
return
|
||||
self.llm = GPT4All(model=self.settings['gpt4all_model'])
|
||||
self.llm = GPT4All(
|
||||
model=self.settings['gpt4all_model'],
|
||||
streaming=True,
|
||||
callback_manager=AsyncCallbackManager([self.callback]),
|
||||
)
|
||||
#self.llm = OpenAI(model_name="gpt-3.5-turbo")
|
||||
is_chat = False
|
||||
if self.backend == 'LlamaCpp':
|
||||
|
|
@ -74,6 +92,8 @@ class LangChainApi:
|
|||
n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
|
||||
n_batch=self.settings['llama_cpp_n_batch'],
|
||||
n_ctx=2048,
|
||||
streaming=True,
|
||||
callback_manager=AsyncCallbackManager([self.callback]),
|
||||
#verbose=True,
|
||||
)
|
||||
is_chat = False
|
||||
|
|
@ -246,4 +266,18 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
self.log_file_name = None
|
||||
|
||||
def get_stream(self):
|
||||
return None, None
|
||||
if self.callback is None:
|
||||
return None, None
|
||||
if '{' in self.callback.recieved_message:
|
||||
if '}' in self.callback.recieved_message:
|
||||
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):self.callback.recieved_message.rfind('}') + 1]
|
||||
else:
|
||||
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):]
|
||||
recieved_dict = force_parse_json(recieved_json)
|
||||
if recieved_dict is not None and "message" in recieved_dict:
|
||||
if "prompt" in recieved_dict:
|
||||
return recieved_dict["message"], recieved_dict["prompt"]
|
||||
else:
|
||||
return recieved_dict["message"], None
|
||||
else:
|
||||
return None, None
|
||||
Loading…
Reference in New Issue