#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import json
import re
#from gpt_stream_parser import force_parse_json
from langchain_community.llms import GPT4All, LlamaCpp
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
from langchain.schema import (
    AIMessage,
    HumanMessage,
    messages_from_dict,
    messages_to_dict,
    BaseMessage,
)
from langchain.chains import LLMChain, ConversationChain
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, List, Any
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    AIMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
#from langchain_community.llms import OpenAI
#os.environ['OPENAI_API_KEY'] = 'foo'


class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Accumulates streamed tokens so partial replies can be polled via get_stream()."""

    def __init__(self):
        self.received_message = ''

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.received_message += token


class TemplateMessagesPlaceholder(MessagesPlaceholder):
    """MessagesPlaceholder that re-renders stored history through the same
    turn template used for the live prompt."""

    _human_template: str = PrivateAttr('')
    _ai_template: str = PrivateAttr('')

    def __init__(self, variable_name: str, template: str, **kwargs: Any):
        super().__init__(variable_name=variable_name, **kwargs)
        # Split the turn template around "{prompt}": the leading part frames
        # human messages, the trailing part frames AI responses.
        before, after = str(template).split("{prompt}", 1)
        self._human_template = before + "{prompt}"
        self._ai_template = after + "{response}"

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        input_mes_list = kwargs[self.variable_name]
        messages = []
        for mes in input_mes_list:
            if isinstance(mes, HumanMessage):
                message_prompt = HumanMessagePromptTemplate.from_template(
                    self._human_template,
                )
                messages.append(message_prompt.format(prompt=mes.content))
            elif isinstance(mes, AIMessage):
                message_prompt = AIMessagePromptTemplate.from_template(
                    self._ai_template,
                )
                messages.append(message_prompt.format(response=mes.content))
        return messages


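# Illustrative example (an assumption, not from the original source): with
# the template 'Human: {prompt}<|end_of_turn|>AI: ', a history of
#   [HumanMessage(content='Hi'), AIMessage(content='Hello')]
# is rendered as 'Human: Hi' followed by '<|end_of_turn|>AI: Hello', so
# replayed turns use the same framing the model saw when they were generated.

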
class LangChainApi:
    log_file_name = None
    is_sending = False

    def __init__(self, **kwargs):
        self.backend = None

        self.memory = ConversationBufferMemory(
            #human_prefix="Human",
            #ai_prefix="AI",
            memory_key="history",
            return_messages=True,
        )
        self.callback = StreamingLLMCallbackHandler()

        self.load_settings(**kwargs)

    def init_model(self):
        if self.backend is None:
            return

        is_chat = False

        if self.backend == 'GPT4All':
            if 'gpt4all_model' not in self.settings or self.settings['gpt4all_model'] is None:
                return
            self.llm = GPT4All(
                model=self.settings['gpt4all_model'],
                streaming=True,
                callback_manager=AsyncCallbackManager([self.callback]),
            )
            #self.llm = OpenAI(model_name="gpt-3.5-turbo")

        if self.backend == 'LlamaCpp':
            if 'llama_cpp_model' not in self.settings or self.settings['llama_cpp_model'] is None:
                return
            if 'llama_cpp_n_gpu_layers' not in self.settings:
                self.settings['llama_cpp_n_gpu_layers'] = 20
            if 'llama_cpp_n_batch' not in self.settings:
                self.settings['llama_cpp_n_batch'] = 128
            self.llm = LlamaCpp(
                model_path=self.settings['llama_cpp_model'],
                n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
                n_batch=self.settings['llama_cpp_n_batch'],
                n_ctx=2048,
                streaming=True,
                callback_manager=AsyncCallbackManager([self.callback]),
                #verbose=True,
            )

        if not is_chat:
            template = """
You are a chatbot having a conversation with a human.

You also have the ability to generate images with Stable Diffusion.
If you want to use this ability, please add the following to your message.



PROMPT contains the prompt used to generate the image.
The prompt is a list of comma-separated keywords.
If it is not in English, please translate it into English (lang:en).
For example, if you want to output "a school girl wearing a red ribbon", it would be as follows.



This prompt is automatically deleted and is not visible to the human.
The image is always output at the end, not at the location where it is added.
If there are multiple entries, only the first one will be reflected.
There is no memory function, so please carry over the prompts from past conversations.
<|end_of_turn|>
If you understand, please reply to the following:<|end_of_turn|>
"""
            system_message_prompt = SystemMessagePromptTemplate.from_template(
                template,
            )

            prompt_template_str = "Human: {prompt}<|end_of_turn|>AI: "

            human_template = prompt_template_str.replace("{prompt}", "{input}")
            human_message_prompt = HumanMessagePromptTemplate.from_template(
                human_template,
            )

            self.prompt = ChatPromptTemplate.from_messages([
                system_message_prompt,
                TemplateMessagesPlaceholder(variable_name="history", template=prompt_template_str),
                human_message_prompt,
            ])

            self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)

            def chat_predict(human_input):
                ret = self.llm_chain.invoke({
                    'input': human_input,
                })
                #print(ret)
                return ret['response']

            self.chat_predict = chat_predict

        self.is_inited = True

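    # Illustrative prompt layout (an assumption about the rendered output,
    # not taken from the original source): with one prior exchange in
    # memory, the chain roughly assembles
    #   <system instructions>
    #   Human: <past input><|end_of_turn|>AI: <past response>
    #   Human: <new input><|end_of_turn|>AI:
    # so the history placeholder reproduces the same turn framing as the
    # live prompt.
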
    def load_settings(self, **kwargs):
        self.settings = kwargs
        self.backend = self.settings.get('backend')
        self.is_inited = False

    def set_log(self, log_string):
        chatgpt_messages = json.loads(log_string)
        if 'log_version' not in chatgpt_messages:
            # Legacy format: a plain list of {'role', 'content'} dicts.
            history = ChatMessageHistory()
            for mes in chatgpt_messages:
                if mes['role'] == 'user':
                    history.add_user_message(mes['content'])
                elif mes['role'] == 'assistant':
                    if '\n(Generated image by the following prompt: ' in mes['content']:
                        mes_content, mes_prompt = mes['content'].split('\n(Generated image by the following prompt: ')
                        # Drop the closing parenthesis of the annotation.
                        if mes_prompt.endswith(')'):
                            mes_prompt = mes_prompt[:-1]
                        mes_json = json.dumps({
                            "prompt": mes_prompt,
                            "message": mes_content,
                        })
                    else:
                        mes_json = json.dumps({
                            "message": mes['content'],
                        })
                    history.add_ai_message(mes_json)
            self.memory.chat_memory = history
        elif chatgpt_messages['log_version'] == 2:
            self.memory.chat_memory = ChatMessageHistory(
                messages=messages_from_dict(chatgpt_messages['messages']),
            )

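    # Log formats handled above (illustrative examples, not from the
    # original source):
    #   legacy:    [{"role": "user", "content": "Hi"},
    #               {"role": "assistant", "content": "Hello"}]
    #   version 2: {"log_version": 2, "messages": [...]}   # messages_to_dict output
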
    def get_log(self):
        if self.memory is None:
            return '[]'
        ret_messages = []
        for mes in self.memory.chat_memory.messages:
            if isinstance(mes, HumanMessage):
                ret_messages.append({"role": "user", "content": mes.content})
            elif isinstance(mes, AIMessage):
                mes_dict = json.loads(mes.content)
                add_mes = mes_dict['message']
                if 'prompt' in mes_dict and mes_dict['prompt'] is not None and mes_dict['prompt'] != "":
                    add_mes += "\n(Generated image by the following prompt: " + mes_dict['prompt'] + ")"
                ret_messages.append({"role": "assistant", "content": add_mes})
        return json.dumps(ret_messages)
        #dicts = {'log_version': 2}
        #if self.memory is None:
        #    dicts['messages'] = {}
        #else:
        #    dicts['messages'] = messages_to_dict(self.memory.chat_memory)
        #return json.dumps(dicts)

    def send(self, content):
        if not self.is_inited:
            self.init_model()

        if self.is_sending:
            return None, None
        self.is_sending = True

        result = self.chat_predict(human_input=content)
        return_message, return_prompt = self.parse_message(result)

        self.is_sending = False
        self.callback.received_message = ''

        return return_message, return_prompt

    def remove_last_conversation(self, result=None):
        if not self.memory.chat_memory.messages:
            return
        if result is None or self.memory.chat_memory.messages[-1].content == result:
            if len(self.memory.chat_memory.messages) > 2:
                self.memory.chat_memory.messages = self.memory.chat_memory.messages[:-2]
            else:
                self.memory.chat_memory.messages.clear()

    def clear(self):
        self.memory.chat_memory.clear()
        self.log_file_name = None

    def get_stream(self):
        if self.callback is None:
            return None, None
        return_message, return_prompt = self.parse_message(self.callback.received_message)
        # A trailing '!' may be the start of a not-yet-complete image tag,
        # so hold it back while streaming.
        if return_message and return_message[-1] == '!':
            return_message = return_message[:-1]
        return return_message, return_prompt

    def parse_message(self, full_message):
        if '![' not in full_message:
            return full_message, None
        prompt_tags = re.findall(r'!\[.*?\]\(.*?\)', full_message)
        if len(prompt_tags) <= 0:
            # Only an unfinished tag is present; cut the message off there.
            end_index = full_message.rfind('![')
            return full_message[:end_index], None
        ret_message = full_message
        prompt = None
        for tag in prompt_tags:
            if '](sd-prompt:// "result")' in tag:
                if prompt is None:
                    # Keep only the first prompt, as the system prompt promises.
                    prompt = tag[len('!['):-len('](sd-prompt:// "result")')]
                ret_message = ret_message.replace(tag, '')
        if ret_message == '':
            return None, prompt
        end_index = ret_message.rfind('![')
        if end_index >= 0:
            return ret_message[:end_index], prompt
        return ret_message, prompt
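

# Minimal usage sketch (illustrative assumption, not part of the original
# source; the backend name matches load_settings/init_model above, but the
# model path is hypothetical):
if __name__ == '__main__':
    api = LangChainApi(
        backend='LlamaCpp',
        llama_cpp_model='./models/openchat-3.5.gguf',  # hypothetical path
    )
    message, sd_prompt = api.send('Hello! Please draw a cat.')
    print('reply:', message)
    print('image prompt:', sd_prompt)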