プロンプトの修正
parent
830458403b
commit
694520b098
|
|
@ -19,13 +19,14 @@ from langchain.chains import LLMChain, ConversationChain
|
|||
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Optional, List, Any
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
#from langchain.prompts.chat import (
|
||||
# ChatPromptTemplate,
|
||||
# SystemMessagePromptTemplate,
|
||||
# HumanMessagePromptTemplate,
|
||||
# AIMessagePromptTemplate,
|
||||
# MessagesPlaceholder,
|
||||
#)
|
||||
from langchain.prompts import StringPromptTemplate, PromptTemplate
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
#from langchain_community.llms import OpenAI
|
||||
|
|
@ -40,32 +41,26 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
self.recieved_message += token
|
||||
|
||||
|
||||
class TemplateMessagesPlaceholder(MessagesPlaceholder):
|
||||
_human_template: str = PrivateAttr('')
|
||||
_ai_template: str = PrivateAttr('')
|
||||
class TemplateMessagesPrompt(StringPromptTemplate):
|
||||
template: str = ''
|
||||
system_message: str = ''
|
||||
history_name: str = 'history'
|
||||
input_name: str = 'input'
|
||||
|
||||
def __init__(self, variable_name: str, template: str, **kwargs: Any):
|
||||
splited = str(template).split("{prompt}")
|
||||
self._human_template = splited[0] + "{prompt}"
|
||||
self._ai_template = splited[1] + "{response}"
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
splited = self.template.split("{prompt}")
|
||||
human_template = splited[0] + "{prompt}"
|
||||
ai_template = splited[1] + "{response}\n"
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
input_mes_list = kwargs[self.variable_name]
|
||||
messages = []
|
||||
input_mes_list = kwargs[self.history_name]
|
||||
messages = self.system_message + '\n'
|
||||
for mes in input_mes_list:
|
||||
if type(mes) is HumanMessage:
|
||||
message_prompt = HumanMessagePromptTemplate.from_template(
|
||||
self._human_template,
|
||||
)
|
||||
message = message_prompt.format(prompt=mes.content)
|
||||
messages.append(message)
|
||||
messages += human_template.replace("{prompt}", mes.content)
|
||||
elif type(mes) is AIMessage:
|
||||
message_prompt = AIMessagePromptTemplate.from_template(
|
||||
self._ai_template,
|
||||
)
|
||||
message = message_prompt.format(response=mes.content)
|
||||
messages.append(message)
|
||||
messages += ai_template.replace("{response}", mes.content)
|
||||
messages += self.template.replace("{prompt}", kwargs[self.input_name])
|
||||
#print(messages)
|
||||
return messages
|
||||
|
||||
|
||||
|
|
@ -121,8 +116,7 @@ class LangChainApi:
|
|||
prompt_template_str = self.settings['llama_cpp_prompt_template']
|
||||
|
||||
if not is_chat:
|
||||
template = """
|
||||
You are a chatbot having a conversation with a human.
|
||||
system_message = """You are a chatbot having a conversation with a human.
|
||||
|
||||
You also have the function to generate image with Stable Diffusion.
|
||||
If you want to use this function, please add the following to your message.
|
||||
|
|
@ -143,21 +137,13 @@ There is no memory function, so please carry over the prompts from past conversa
|
|||
<|end_of_turn|>
|
||||
If you understand, please reply to the following:<|end_of_turn|>
|
||||
"""
|
||||
system_message_prompt = SystemMessagePromptTemplate.from_template(
|
||||
template,
|
||||
)
|
||||
|
||||
human_template = prompt_template_str.replace("{prompt}", "{input}")
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(
|
||||
human_template,
|
||||
self.prompt = TemplateMessagesPrompt(
|
||||
system_message=system_message,
|
||||
template=prompt_template_str,
|
||||
input_variables=['history', 'input'],
|
||||
)
|
||||
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
system_message_prompt,
|
||||
TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str),
|
||||
human_message_prompt,
|
||||
])
|
||||
|
||||
self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)
|
||||
|
||||
def chat_predict(human_input):
|
||||
|
|
@ -255,6 +241,8 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
return_message, return_prompt = self.parse_message(self.callback.recieved_message)
|
||||
if len(return_message) > 0 and return_message[-1] == '!':
|
||||
return_message = return_message[:-1]
|
||||
if return_message.isspace():
|
||||
return_message = None
|
||||
return return_message, return_prompt
|
||||
|
||||
def parse_message(self, full_message):
|
||||
|
|
|
|||
Loading…
Reference in New Issue