出力方法の変更
parent
387b1c6607
commit
2956aad697
|
|
@ -4,7 +4,8 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
from gpt_stream_parser import force_parse_json
|
||||
import re
|
||||
#from gpt_stream_parser import force_parse_json
|
||||
from langchain_community.llms import GPT4All, LlamaCpp
|
||||
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
|
||||
from langchain.schema import (
|
||||
|
|
@ -12,15 +13,17 @@ from langchain.schema import (
|
|||
HumanMessage,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain.chains import LLMChain, ConversationChain
|
||||
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Optional, List, Any
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
|
|
@ -29,17 +32,6 @@ from langchain.callbacks.manager import AsyncCallbackManager
|
|||
#os.environ['OPENAI_API_KEY'] = 'foo'
|
||||
|
||||
|
||||
class Txt2ImgModel(BaseModel):
|
||||
prompt: Optional[str] = Field(description='''Prompt for generate image.
|
||||
Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.)
|
||||
There is no memory function, so please carry over the prompts from past conversations.
|
||||
Prompt is comma separated keywords such as "1girl, school uniform, red ribbon" (not list).
|
||||
If it is not in English, please translate it into English (lang:en).''')
|
||||
message: str = Field(description='''Chat message.
|
||||
Please enter the content of your reply to me.
|
||||
If prompt is exists, Displayed before the image.''')
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
def __init__(self):
|
||||
self.recieved_message = ''
|
||||
|
|
@ -48,6 +40,35 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
self.recieved_message += token
|
||||
|
||||
|
||||
class TemplateMessagesPlaceholder(MessagesPlaceholder):
|
||||
_human_template: str = PrivateAttr('')
|
||||
_ai_template: str = PrivateAttr('')
|
||||
|
||||
def __init__(self, variable_name: str, template: str, **kwargs: Any):
|
||||
splited = str(template).split("{prompt}")
|
||||
self._human_template = splited[0] + "{prompt}"
|
||||
self._ai_template = splited[1] + "{response}"
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
input_mes_list = kwargs[self.variable_name]
|
||||
messages = []
|
||||
for mes in input_mes_list:
|
||||
if type(mes) is HumanMessage:
|
||||
message_prompt = HumanMessagePromptTemplate.from_template(
|
||||
self._human_template,
|
||||
)
|
||||
message = message_prompt.format(prompt=mes.content)
|
||||
messages.append(message)
|
||||
elif type(mes) is AIMessage:
|
||||
message_prompt = AIMessagePromptTemplate.from_template(
|
||||
self._ai_template,
|
||||
)
|
||||
message = message_prompt.format(response=mes.content)
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
|
||||
class LangChainApi:
|
||||
log_file_name = None
|
||||
is_sending = False
|
||||
|
|
@ -56,8 +77,8 @@ class LangChainApi:
|
|||
self.backend = None
|
||||
|
||||
self.memory = ConversationBufferMemory(
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
#human_prefix="Human",
|
||||
#ai_prefix="AI",
|
||||
memory_key="history",
|
||||
return_messages=True,
|
||||
)
|
||||
|
|
@ -97,39 +118,43 @@ class LangChainApi:
|
|||
)
|
||||
is_chat = False
|
||||
|
||||
self.pydantic_parser = PydanticOutputParser(pydantic_object=Txt2ImgModel)
|
||||
|
||||
if not is_chat:
|
||||
template = """
|
||||
You are a chatbot having a conversation with a human.
|
||||
|
||||
You also have the function to generate image with Stable Diffusion.
|
||||
If you want to use this function, please add the following to your message.
|
||||
|
||||
{format_instructions}
|
||||

|
||||
|
||||
Below is an example of the final output:
|
||||
```
|
||||
{{
|
||||
"prompt": "1girl, school uniform, red ribbon",
|
||||
"message": "This is a school girl wearing a red ribbon.\nWhat do you think of this image?"
|
||||
}}
|
||||
```
|
||||
PROMPT contains the prompt to generate the image.
|
||||
Prompt is comma separated keywords.
|
||||
If it is not in English, please translate it into English (lang:en).
|
||||
For example, if you want to output "a school girl wearing a red ribbon", it would be as follows.
|
||||
|
||||

|
||||
|
||||
This prompt is automatically deleted and is not visible to the human.
|
||||
The image is always output at the end, not at the location where it is added.
|
||||
If there are multiple entries, only the first one will be reflected.
|
||||
There is no memory function, so please carry over the prompts from past conversations.
|
||||
<|end_of_turn|>
|
||||
If you understand, please reply to the following:<|end_of_turn|>
|
||||
"""
|
||||
format_instructions = self.pydantic_parser.get_format_instructions()
|
||||
system_message_prompt = SystemMessagePromptTemplate.from_template(
|
||||
template,
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
|
||||
human_template = "{input}<|end_of_turn|>AI:"
|
||||
prompt_template_str = "Human: {prompt}<|end_of_turn|>AI: "
|
||||
|
||||
human_template = prompt_template_str.replace("{prompt}", "{input}")
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(
|
||||
human_template,
|
||||
)
|
||||
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
system_message_prompt,
|
||||
MessagesPlaceholder(variable_name="history"),
|
||||
TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str),
|
||||
human_message_prompt,
|
||||
])
|
||||
|
||||
|
|
@ -144,11 +169,6 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
|
||||
self.chat_predict = chat_predict
|
||||
|
||||
self.parser = OutputFixingParser.from_llm(
|
||||
parser=self.pydantic_parser,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
self.is_inited = True
|
||||
|
||||
def load_settings(self, **kwargs):
|
||||
|
|
@ -211,15 +231,10 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
self.is_sending = True
|
||||
|
||||
result = self.chat_predict(human_input=content)
|
||||
try:
|
||||
parse_result = self.parser.parse(result)
|
||||
return_message = parse_result.message
|
||||
return_prompt = parse_result.prompt
|
||||
except:
|
||||
return_message = result
|
||||
return_prompt = None
|
||||
return_message, return_prompt = self.parse_message(result)
|
||||
|
||||
self.is_sending = False
|
||||
self.callback.recieved_message = ''
|
||||
|
||||
return return_message, return_prompt
|
||||
|
||||
|
|
@ -237,18 +252,28 @@ If you understand, please reply to the following:<|end_of_turn|>
|
|||
def get_stream(self):
|
||||
if self.callback is None:
|
||||
return None, None
|
||||
if '{' in self.callback.recieved_message:
|
||||
if '}' in self.callback.recieved_message:
|
||||
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):self.callback.recieved_message.rfind('}') + 1]
|
||||
else:
|
||||
recieved_json = self.callback.recieved_message[self.callback.recieved_message.find('{'):]
|
||||
recieved_dict = force_parse_json(recieved_json)
|
||||
if recieved_dict is not None and "message" in recieved_dict:
|
||||
if "prompt" in recieved_dict:
|
||||
return recieved_dict["message"], recieved_dict["prompt"]
|
||||
else:
|
||||
return recieved_dict["message"], None
|
||||
else:
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
return_message, return_prompt = self.parse_message(self.callback.recieved_message)
|
||||
if len(return_message) > 0 and return_message[-1] == '!':
|
||||
return_message = return_message[:-1]
|
||||
return return_message, return_prompt
|
||||
|
||||
def parse_message(self, full_message):
|
||||
if not '![' in full_message:
|
||||
return full_message, None
|
||||
prompt_tags = re.findall('\!\[.*?\]\(.*?\)', full_message)
|
||||
if len(prompt_tags) <= 0:
|
||||
end_index = full_message.rfind('![')
|
||||
return full_message[:end_index], None
|
||||
ret_message = full_message
|
||||
prompt = None
|
||||
for tag in prompt_tags:
|
||||
if '](sd-prompt:// "result")' in tag:
|
||||
if prompt is None:
|
||||
prompt = tag[len('')]
|
||||
ret_message = ret_message.replace(tag, '')
|
||||
if ret_message == '':
|
||||
return None, prompt
|
||||
end_index = ret_message.rfind('![')
|
||||
if end_index >= 0:
|
||||
return ret_message[:end_index], prompt
|
||||
return ret_message, prompt
|
||||
Loading…
Reference in New Issue