出力方法の変更

main
NON906 2024-01-20 15:58:35 +09:00
parent 387b1c6607
commit 2956aad697
1 changed files with 82 additions and 57 deletions

View File

@ -4,7 +4,8 @@
import os
import sys
import json
from gpt_stream_parser import force_parse_json
import re
#from gpt_stream_parser import force_parse_json
from langchain_community.llms import GPT4All, LlamaCpp
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
from langchain.schema import (
@ -12,15 +13,17 @@ from langchain.schema import (
HumanMessage,
messages_from_dict,
messages_to_dict,
BaseMessage,
)
from langchain.chains import LLMChain, ConversationChain
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from pydantic import BaseModel, Field
from typing import Optional
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, List, Any
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
AIMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.callbacks.base import AsyncCallbackHandler
@ -29,17 +32,6 @@ from langchain.callbacks.manager import AsyncCallbackManager
#os.environ['OPENAI_API_KEY'] = 'foo'
class Txt2ImgModel(BaseModel):
prompt: Optional[str] = Field(description='''Prompt for generate image.
Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.)
There is no memory function, so please carry over the prompts from past conversations.
Prompt is comma separated keywords such as "1girl, school uniform, red ribbon" (not list).
If it is not in English, please translate it into English (lang:en).''')
message: str = Field(description='''Chat message.
Please enter the content of your reply to me.
If prompt is exists, Displayed before the image.''')
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
def __init__(self):
self.recieved_message = ''
@ -48,6 +40,35 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
self.recieved_message += token
class TemplateMessagesPlaceholder(MessagesPlaceholder):
_human_template: str = PrivateAttr('')
_ai_template: str = PrivateAttr('')
def __init__(self, variable_name: str, template: str, **kwargs: Any):
splited = str(template).split("{prompt}")
self._human_template = splited[0] + "{prompt}"
self._ai_template = splited[1] + "{response}"
return super().__init__(variable_name=variable_name, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
input_mes_list = kwargs[self.variable_name]
messages = []
for mes in input_mes_list:
if type(mes) is HumanMessage:
message_prompt = HumanMessagePromptTemplate.from_template(
self._human_template,
)
message = message_prompt.format(prompt=mes.content)
messages.append(message)
elif type(mes) is AIMessage:
message_prompt = AIMessagePromptTemplate.from_template(
self._ai_template,
)
message = message_prompt.format(response=mes.content)
messages.append(message)
return messages
class LangChainApi:
log_file_name = None
is_sending = False
@ -56,8 +77,8 @@ class LangChainApi:
self.backend = None
self.memory = ConversationBufferMemory(
human_prefix="Human",
ai_prefix="AI",
#human_prefix="Human",
#ai_prefix="AI",
memory_key="history",
return_messages=True,
)
@ -97,39 +118,43 @@ class LangChainApi:
)
is_chat = False
self.pydantic_parser = PydanticOutputParser(pydantic_object=Txt2ImgModel)
if not is_chat:
template = """
You are a chatbot having a conversation with a human.
You also have the function to generate image with Stable Diffusion.
If you want to use this function, please add the following to your message.
{format_instructions}
![PROMPT](sd-prompt:// "result")
Below is an example of the final output:
```
{{
"prompt": "1girl, school uniform, red ribbon",
"message": "This is a school girl wearing a red ribbon.\nWhat do you think of this image?"
}}
```
PROMPT contains the prompt to generate the image.
Prompt is comma separated keywords.
If it is not in English, please translate it into English (lang:en).
For example, if you want to output "a school girl wearing a red ribbon", it would be as follows.
![1girl, school uniform, red ribbon](sd-prompt:// "result")
This prompt is automatically deleted and is not visible to the human.
The image is always output at the end, not at the location where it is added.
If there are multiple entries, only the first one will be reflected.
There is no memory function, so please carry over the prompts from past conversations.
<|end_of_turn|>
If you understand, please reply to the following:<|end_of_turn|>
"""
format_instructions = self.pydantic_parser.get_format_instructions()
system_message_prompt = SystemMessagePromptTemplate.from_template(
template,
partial_variables={"format_instructions": format_instructions},
)
human_template = "{input}<|end_of_turn|>AI:"
prompt_template_str = "Human: {prompt}<|end_of_turn|>AI: "
human_template = prompt_template_str.replace("{prompt}", "{input}")
human_message_prompt = HumanMessagePromptTemplate.from_template(
human_template,
)
self.prompt = ChatPromptTemplate.from_messages([
system_message_prompt,
MessagesPlaceholder(variable_name="history"),
TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str),
human_message_prompt,
])
@ -144,11 +169,6 @@ If you understand, please reply to the following:<|end_of_turn|>
self.chat_predict = chat_predict
self.parser = OutputFixingParser.from_llm(
parser=self.pydantic_parser,
llm=self.llm,
)
self.is_inited = True
def load_settings(self, **kwargs):
@ -211,15 +231,10 @@ If you understand, please reply to the following:<|end_of_turn|>
self.is_sending = True
result = self.chat_predict(human_input=content)
try:
parse_result = self.parser.parse(result)
return_message = parse_result.message
return_prompt = parse_result.prompt
except:
return_message = result
return_prompt = None
return_message, return_prompt = self.parse_message(result)
self.is_sending = False
self.callback.recieved_message = ''
return return_message, return_prompt
@ -237,18 +252,28 @@ If you understand, please reply to the following:<|end_of_turn|>
def get_stream(self):
if self.callback is None:
return None, None
if '{' in self.callback.recieved_message:
if '}' in self.callback.recieved_message:
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):self.callback.recieved_message.rfind('}') + 1]
else:
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):]
recieved_dict = force_parse_json(recieved_json)
if recieved_dict is not None and "message" in recieved_dict:
if "prompt" in recieved_dict:
return recieved_dict["message"], recieved_dict["prompt"]
else:
return recieved_dict["message"], None
else:
return None, None
else:
return None, None
return_message, return_prompt = self.parse_message(self.callback.recieved_message)
if len(return_message) > 0 and return_message[-1] == '!':
return_message = return_message[:-1]
return return_message, return_prompt
def parse_message(self, full_message):
if not '![' in full_message:
return full_message, None
prompt_tags = re.findall('\!\[.*?\]\(.*?\)', full_message)
if len(prompt_tags) <= 0:
end_index = full_message.rfind('![')
return full_message[:end_index], None
ret_message = full_message
prompt = None
for tag in prompt_tags:
if '](sd-prompt:// "result")' in tag:
if prompt is None:
prompt = tag[len('!['):-len('](sd-prompt:// "result")')]
ret_message = ret_message.replace(tag, '')
if ret_message == '':
return None, prompt
end_index = ret_message.rfind('![')
if end_index >= 0:
return ret_message[:end_index], prompt
return ret_message, prompt