プロンプトの修正

main
NON906 2024-01-20 22:13:58 +09:00
parent 830458403b
commit 694520b098
1 changed files with 30 additions and 42 deletions

View File

@ -19,13 +19,14 @@ from langchain.chains import LLMChain, ConversationChain
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, List, Any
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
AIMessagePromptTemplate,
MessagesPlaceholder,
)
#from langchain.prompts.chat import (
# ChatPromptTemplate,
# SystemMessagePromptTemplate,
# HumanMessagePromptTemplate,
# AIMessagePromptTemplate,
# MessagesPlaceholder,
#)
from langchain.prompts import StringPromptTemplate, PromptTemplate
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
#from langchain_community.llms import OpenAI
@ -40,32 +41,26 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
self.recieved_message += token
class TemplateMessagesPlaceholder(MessagesPlaceholder):
_human_template: str = PrivateAttr('')
_ai_template: str = PrivateAttr('')
class TemplateMessagesPrompt(StringPromptTemplate):
template: str = ''
system_message: str = ''
history_name: str = 'history'
input_name: str = 'input'
def __init__(self, variable_name: str, template: str, **kwargs: Any):
splited = str(template).split("{prompt}")
self._human_template = splited[0] + "{prompt}"
self._ai_template = splited[1] + "{response}"
return super().__init__(variable_name=variable_name, **kwargs)
def format(self, **kwargs: Any) -> str:
splited = self.template.split("{prompt}")
human_template = splited[0] + "{prompt}"
ai_template = splited[1] + "{response}\n"
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
input_mes_list = kwargs[self.variable_name]
messages = []
input_mes_list = kwargs[self.history_name]
messages = self.system_message + '\n'
for mes in input_mes_list:
if type(mes) is HumanMessage:
message_prompt = HumanMessagePromptTemplate.from_template(
self._human_template,
)
message = message_prompt.format(prompt=mes.content)
messages.append(message)
messages += human_template.replace("{prompt}", mes.content)
elif type(mes) is AIMessage:
message_prompt = AIMessagePromptTemplate.from_template(
self._ai_template,
)
message = message_prompt.format(response=mes.content)
messages.append(message)
messages += ai_template.replace("{response}", mes.content)
messages += self.template.replace("{prompt}", kwargs[self.input_name])
#print(messages)
return messages
@ -121,8 +116,7 @@ class LangChainApi:
prompt_template_str = self.settings['llama_cpp_prompt_template']
if not is_chat:
template = """
You are a chatbot having a conversation with a human.
system_message = """You are a chatbot having a conversation with a human.
You also have the function to generate image with Stable Diffusion.
If you want to use this function, please add the following to your message.
@ -143,21 +137,13 @@ There is no memory function, so please carry over the prompts from past conversa
<|end_of_turn|>
If you understand, please reply to the following:<|end_of_turn|>
"""
system_message_prompt = SystemMessagePromptTemplate.from_template(
template,
)
human_template = prompt_template_str.replace("{prompt}", "{input}")
human_message_prompt = HumanMessagePromptTemplate.from_template(
human_template,
self.prompt = TemplateMessagesPrompt(
system_message=system_message,
template=prompt_template_str,
input_variables=['history', 'input'],
)
self.prompt = ChatPromptTemplate.from_messages([
system_message_prompt,
TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str),
human_message_prompt,
])
self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)
def chat_predict(human_input):
@ -255,6 +241,8 @@ If you understand, please reply to the following:<|end_of_turn|>
return_message, return_prompt = self.parse_message(self.callback.recieved_message)
if len(return_message) > 0 and return_message[-1] == '!':
return_message = return_message[:-1]
if return_message.isspace():
return_message = None
return return_message, return_prompt
def parse_message(self, full_message):