From 694520b098d9fb029cdeff2bd7c79d6a4c3ea3c4 Mon Sep 17 00:00:00 2001 From: NON906 Date: Sat, 20 Jan 2024 22:13:58 +0900 Subject: [PATCH] =?UTF-8?q?=E3=83=97=E3=83=AD=E3=83=B3=E3=83=97=E3=83=88?= =?UTF-8?q?=E3=81=AE=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/langchainapi.py | 72 +++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 42 deletions(-) diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py index 4a66ad7..bb335d9 100644 --- a/scripts/langchainapi.py +++ b/scripts/langchainapi.py @@ -19,13 +19,14 @@ from langchain.chains import LLMChain, ConversationChain from langchain.output_parsers import PydanticOutputParser, OutputFixingParser from pydantic import BaseModel, Field, PrivateAttr from typing import Optional, List, Any -from langchain.prompts.chat import ( - ChatPromptTemplate, - SystemMessagePromptTemplate, - HumanMessagePromptTemplate, - AIMessagePromptTemplate, - MessagesPlaceholder, -) +#from langchain.prompts.chat import ( +# ChatPromptTemplate, +# SystemMessagePromptTemplate, +# HumanMessagePromptTemplate, +# AIMessagePromptTemplate, +# MessagesPlaceholder, +#) +from langchain.prompts import StringPromptTemplate, PromptTemplate from langchain.callbacks.base import AsyncCallbackHandler from langchain.callbacks.manager import AsyncCallbackManager #from langchain_community.llms import OpenAI @@ -40,32 +41,26 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler): self.recieved_message += token -class TemplateMessagesPlaceholder(MessagesPlaceholder): - _human_template: str = PrivateAttr('') - _ai_template: str = PrivateAttr('') +class TemplateMessagesPrompt(StringPromptTemplate): + template: str = '' + system_message: str = '' + history_name: str = 'history' + input_name: str = 'input' - def __init__(self, variable_name: str, template: str, **kwargs: Any): - splited = str(template).split("{prompt}") - self._human_template = splited[0] + "{prompt}" - self._ai_template = splited[1] + "{response}" - return super().__init__(variable_name=variable_name, **kwargs) + def format(self, **kwargs: Any) -> str: + splited = self.template.split("{prompt}") + human_template = splited[0] + "{prompt}" + ai_template = splited[1] + "{response}\n" - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - input_mes_list = kwargs[self.variable_name] - messages = [] + input_mes_list = kwargs[self.history_name] + messages = self.system_message + '\n' for mes in input_mes_list: if type(mes) is HumanMessage: - message_prompt = HumanMessagePromptTemplate.from_template( - self._human_template, - ) - message = message_prompt.format(prompt=mes.content) - messages.append(message) + messages += human_template.replace("{prompt}", mes.content) elif type(mes) is AIMessage: - message_prompt = AIMessagePromptTemplate.from_template( - self._ai_template, - ) - message = message_prompt.format(response=mes.content) - messages.append(message) + messages += ai_template.replace("{response}", mes.content) + messages += self.template.replace("{prompt}", kwargs[self.input_name]) + #print(messages) return messages @@ -121,8 +116,7 @@ class LangChainApi: prompt_template_str = self.settings['llama_cpp_prompt_template'] if not is_chat: - template = """ -You are a chatbot having a conversation with a human. + system_message = """You are a chatbot having a conversation with a human. You also have the function to generate image with Stable Diffusion. If you want to use this function, please add the following to your message. @@ -143,21 +137,13 @@ There is no memory function, so please carry over the prompts from past conversa <|end_of_turn|> If you understand, please reply to the following:<|end_of_turn|> """ - system_message_prompt = SystemMessagePromptTemplate.from_template( - template, - ) - human_template = prompt_template_str.replace("{prompt}", "{input}") - human_message_prompt = HumanMessagePromptTemplate.from_template( - human_template, + self.prompt = TemplateMessagesPrompt( + system_message=system_message, + template=prompt_template_str, + input_variables=['history', 'input'], ) - self.prompt = ChatPromptTemplate.from_messages([ - system_message_prompt, - TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str), - human_message_prompt, - ]) - self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True) def chat_predict(human_input): @@ -255,6 +241,8 @@ If you understand, please reply to the following:<|end_of_turn|> return_message, return_prompt = self.parse_message(self.callback.recieved_message) if len(return_message) > 0 and return_message[-1] == '!': return_message = return_message[:-1] + if return_message.isspace(): + return_message = None return return_message, return_prompt def parse_message(self, full_message):