diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py
index 097374b..3d3eac7 100644
--- a/scripts/langchainapi.py
+++ b/scripts/langchainapi.py
@@ -5,7 +5,7 @@ import os
 import sys
 import json
 from langchain_community.llms import GPT4All, LlamaCpp
-from langchain.memory import ConversationBufferMemory
+from langchain.memory import ConversationBufferMemory, ChatMessageHistory
 from langchain.schema import (
     AIMessage,
     HumanMessage,
@@ -42,7 +42,13 @@ class LangChainApi:
 
     def __init__(self, **kwargs):
         self.backend = None
-        self.memory = None
+
+        self.memory = ConversationBufferMemory(
+            human_prefix="Human",
+            ai_prefix="AI",
+            memory_key="history",
+            return_messages=True,
+        )
 
         self.load_settings(**kwargs)
 
@@ -72,13 +78,6 @@ class LangChainApi:
         )
         is_chat = False
 
-        self.memory = ConversationBufferMemory(
-            human_prefix="Human",
-            ai_prefix="AI",
-            memory_key="history",
-            return_messages=True,
-        )
-
         self.pydantic_parser = PydanticOutputParser(pydantic_object=Txt2ImgModel)
 
         if not is_chat:
@@ -176,6 +175,8 @@ If you understand, please reply to the following:<|end_of_turn|>
         return False
 
     def get_log(self):
+        if self.memory is None:
+            return '[]'
         ret_messages = []
         for mes in self.memory.chat_memory:
             if type(mes) is HumanMessage: