sd-webui-chatgpt/scripts/langchainapi.py

284 lines
10 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import json
from gpt_stream_parser import force_parse_json
from langchain_community.llms import GPT4All, LlamaCpp
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
from langchain.schema import (
AIMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from pydantic import BaseModel, Field
from typing import Optional
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
#from langchain_community.llms import OpenAI
#os.environ['OPENAI_API_KEY'] = 'foo'
# Structured reply expected from the LLM: a chat message plus an optional
# Stable Diffusion prompt. It is handed to PydanticOutputParser, which derives
# the format instructions sent to the model from this schema.
#
# NOTE: documented with comments rather than a class docstring on purpose; a
# pydantic model's docstring becomes the schema "description" and would
# therefore change the text sent to the LLM.
class Txt2ImgModel(BaseModel):
    # Comma separated Stable Diffusion keywords, or None when no image is
    # requested. The explicit default keeps the field optional on pydantic v2
    # as well (there, an Optional annotation without a default is required).
    prompt: Optional[str] = Field(default=None, description='''Prompt for generate image.
Generate image from prompt by Stable Diffusion. (Sentences cannot be generated.)
There is no memory function, so please carry over the prompts from past conversations.
Prompt is comma separated keywords such as "1girl, school uniform, red ribbon" (not list).
If it is not in English, please translate it into English (lang:en).''')
    # Text of the reply shown to the user (required).
    message: str = Field(description='''Chat message.
Please enter the content of your reply to me.
If prompt is exists, Displayed before the image.''')
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler that accumulates streamed LLM tokens into one string.

    The accumulated text is exposed as ``recieved_message`` (spelling kept
    as-is because other code reads the attribute by that name).
    """

    def __init__(self):
        super().__init__()
        # Concatenation of every token received so far.
        self.recieved_message = ''

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append a newly generated token to the accumulated message."""
        self.recieved_message += token
class LangChainApi:
    """Chat session backed by a local LLM (GPT4All or LlamaCpp) via LangChain.

    The model is asked to answer with JSON matching ``Txt2ImgModel`` (a chat
    message plus an optional Stable Diffusion prompt). The conversation is
    kept in a ``ConversationBufferMemory`` and can be saved to / restored
    from a JSON log file.
    """

    # Path of the log file used by write_log() when no file name is given.
    log_file_name = None
    # True while send() is waiting for the model; overlapping sends are rejected.
    is_sending = False
    # Text that separates the chat message from the image prompt in log entries.
    _PROMPT_MARKER = '\n(Generated image by the following prompt: '

    def __init__(self, **kwargs):
        """Create the session; ``kwargs`` are the settings (see load_settings)."""
        self.backend = None
        self.memory = ConversationBufferMemory(
            human_prefix="Human",
            ai_prefix="AI",
            memory_key="history",
            return_messages=True,
        )
        self.callback = StreamingLLMCallbackHandler()
        self.load_settings(**kwargs)

    def init_model(self):
        """Build the LLM, prompt, chain and output parser for the current settings.

        Returns without initializing (``is_inited`` stays False) when no
        backend is selected or the selected backend has no model configured.

        Raises:
            ValueError: if ``backend`` names an unsupported backend.
        """
        if self.backend is None:
            return
        llm = self._create_llm()
        if llm is None:
            return
        self.llm = llm
        self.pydantic_parser = PydanticOutputParser(pydantic_object=Txt2ImgModel)
        self.prompt = self._create_prompt()
        self.llm_chain = LLMChain(prompt=self.prompt, llm=self.llm, memory=self.memory)

        def chat_predict(human_input):
            return self.llm_chain.invoke({
                'human_input': human_input,
            })['text']

        self.chat_predict = chat_predict
        self.parser = OutputFixingParser.from_llm(
            parser=self.pydantic_parser,
            llm=self.llm,
        )
        self.is_inited = True

    def _create_llm(self):
        """Return the LLM for ``self.backend``, or None if its model setting is missing."""
        if self.backend == 'GPT4All':
            model = self.settings.get('gpt4all_model')
            if model is None:
                return None
            return GPT4All(
                model=model,
                streaming=True,
                callback_manager=AsyncCallbackManager([self.callback]),
            )
        if self.backend == 'LlamaCpp':
            model_path = self.settings.get('llama_cpp_model')
            if model_path is None:
                return None
            # Defaults are written back into the settings, as before.
            self.settings.setdefault('llama_cpp_n_gpu_layers', 20)
            self.settings.setdefault('llama_cpp_n_batch', 128)
            return LlamaCpp(
                model_path=model_path,
                n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
                n_batch=self.settings['llama_cpp_n_batch'],
                n_ctx=2048,
                streaming=True,
                callback_manager=AsyncCallbackManager([self.callback]),
            )
        raise ValueError(f'Unsupported backend: {self.backend!r}')

    def _create_prompt(self):
        """Return the chat prompt: system instructions, history, then the user input."""
        template = """
You are a chatbot having a conversation with a human.
You also have the function to generate image with Stable Diffusion.
{format_instructions}
Below is an example of the final output:
```
{{
"prompt": "1girl, school uniform, red ribbon",
"message": "This is a school girl wearing a red ribbon."
}}
```
<|end_of_turn|>
If you understand, please reply to the following:<|end_of_turn|>
"""
        format_instructions = self.pydantic_parser.get_format_instructions()
        system_message_prompt = SystemMessagePromptTemplate.from_template(
            template,
            partial_variables={"format_instructions": format_instructions},
        )
        human_template = "{human_input}<|end_of_turn|>AI:"
        human_message_prompt = HumanMessagePromptTemplate.from_template(
            human_template,
        )
        return ChatPromptTemplate.from_messages([
            system_message_prompt,
            MessagesPlaceholder(variable_name="history"),
            human_message_prompt,
        ])

    def load_settings(self, **kwargs):
        """Store new settings and mark the model as needing (re)initialization.

        ``backend`` is optional; without it the backend is None and
        init_model() does nothing.
        """
        self.settings = kwargs
        self.backend = self.settings.get('backend')
        self.is_inited = False

    def set_log(self, log_string):
        """Replace the conversation history with the one encoded in ``log_string``.

        Two JSON layouts are understood: a list of ``{"role", "content"}``
        entries (the layout get_log() produces) and a dict with
        ``"log_version": 2`` whose ``"messages"`` were produced by
        ``messages_to_dict``. Any other layout leaves the history unchanged.
        """
        chatgpt_messages = json.loads(log_string)
        if isinstance(chatgpt_messages, list):
            history = ChatMessageHistory()
            for mes in chatgpt_messages:
                if mes['role'] == 'user':
                    history.add_user_message(mes['content'])
                elif mes['role'] == 'assistant':
                    history.add_ai_message(self._assistant_log_to_json(mes['content']))
            self.memory.chat_memory = history
        elif isinstance(chatgpt_messages, dict) and chatgpt_messages.get('log_version') == 2:
            # chat_memory must stay a history object (other methods use
            # .messages / .clear()), so wrap the restored message list.
            self.memory.chat_memory = ChatMessageHistory(
                messages=messages_from_dict(chatgpt_messages['messages']),
            )

    def _assistant_log_to_json(self, content):
        """Convert a logged assistant entry back into the JSON text kept in memory."""
        if self._PROMPT_MARKER not in content:
            return json.dumps({
                "message": content,
            })
        # get_log() appends the marker last, so split on its final occurrence.
        mes_content, mes_prompt = content.rsplit(self._PROMPT_MARKER, 1)
        # Drop the last ')' that get_log() added after the prompt.
        close_index = mes_prompt.rfind(')')
        if close_index != -1:
            mes_prompt = mes_prompt[:close_index] + mes_prompt[close_index + 1:]
        return json.dumps({
            "prompt": mes_prompt,
            "message": mes_content,
        })

    def load_log(self, log):
        """Remember ``log`` as the log file and load it if it exists.

        Returns:
            True if the file existed and was loaded, False otherwise
            (including when ``log`` is None or the file could not be parsed).
        """
        if log is None:
            return False
        try:
            self.log_file_name = log
            if os.path.isfile(log):
                with open(log, 'r', encoding='UTF-8') as f:
                    self.set_log(f.read())
                return True
        except Exception:
            # Best effort: an unreadable or corrupt log is treated as "no log".
            pass
        return False

    def get_log(self) -> str:
        """Return the conversation as a JSON list of ``{"role", "content"}`` entries."""
        if self.memory is None:
            return '[]'
        ret_messages = []
        for mes in self.memory.chat_memory.messages:
            if isinstance(mes, HumanMessage):
                ret_messages.append({"role": "user", "content": mes.content})
            elif isinstance(mes, AIMessage):
                ret_messages.append({
                    "role": "assistant",
                    "content": self._ai_content_to_log(mes.content),
                })
        return json.dumps(ret_messages)

    def _ai_content_to_log(self, content):
        """Turn stored AI output into log text; non-JSON output is logged verbatim."""
        try:
            mes_dict = json.loads(content)
        except (TypeError, ValueError):
            return content
        if not isinstance(mes_dict, dict) or not isinstance(mes_dict.get('message'), str):
            return content
        add_mes = mes_dict['message']
        prompt = mes_dict.get('prompt')
        if prompt is not None and prompt != "":
            add_mes = f"{add_mes}{self._PROMPT_MARKER}{prompt})"
        return add_mes

    def write_log(self, file_name=None) -> None:
        """Write the conversation log to ``file_name`` (default: ``log_file_name``).

        Does nothing when no file name is available. The log is written to a
        ``.tmp`` file first and then moved over the target.
        """
        if file_name is None:
            file_name = self.log_file_name
        if file_name is None:
            return
        # Serialize before touching the disk so a failure leaves no empty .tmp.
        log_text = self.get_log()
        tmp_name = file_name + '.tmp'
        with open(tmp_name, 'w', encoding='UTF-8') as f:
            f.write(log_text)
        # os.replace overwrites an existing target on every platform.
        os.replace(tmp_name, file_name)
        # Clean up a backup left behind by the older rename-based scheme.
        prev_name = file_name + '.prev'
        if os.path.isfile(prev_name):
            os.remove(prev_name)

    def send(self, content, write_log=False):
        """Send ``content`` to the model and return its reply.

        Args:
            content: the user's message.
            write_log: when true, write the log file after the reply.

        Returns:
            ``(message, prompt)``; ``prompt`` is None when the reply could not
            be parsed, in which case ``message`` is the raw model output.
            Returns None (not a tuple) if another send is already running.

        Raises:
            RuntimeError: if the model could not be initialized from the
                current settings.
        """
        if not self.is_inited:
            self.init_model()
        if not self.is_inited:
            raise RuntimeError(
                'LangChain model is not initialized; check the backend and model settings.'
            )
        if self.is_sending:
            return
        self.is_sending = True
        try:
            # Start a fresh stream buffer so get_stream() only sees this reply.
            if self.callback is not None:
                self.callback.recieved_message = ''
            result = self.chat_predict(human_input=content)
            try:
                parse_result = self.parser.parse(result)
                return_message = parse_result.message
                return_prompt = parse_result.prompt
            except Exception:
                # Unparseable output is returned to the caller as plain text.
                return_message = result
                return_prompt = None
            if write_log:
                self.write_log()
        finally:
            # Always release the flag, otherwise one failure blocks all later sends.
            self.is_sending = False
        return return_message, return_prompt

    def remove_last_conversation(self, result=None, write_log=False):
        """Drop the last question/answer pair from the history.

        When ``result`` is given, the pair is only removed if the last stored
        message's content equals it. With two or fewer messages the history
        is emptied.
        """
        messages = self.memory.chat_memory.messages
        if result is None or (messages and messages[-1].content == result):
            # Removes the last two entries, or everything if there are fewer.
            del messages[-2:]
            if write_log:
                self.write_log()

    def clear(self):
        """Forget the conversation and the associated log file name."""
        self.memory.chat_memory.clear()
        self.log_file_name = None

    def get_stream(self):
        """Return ``(message, prompt)`` parsed from the text streamed so far.

        The text between the first ``{`` and the last ``}`` (or the end of
        the text while the object is still open) is handed to
        ``force_parse_json``. Returns ``(None, None)`` when nothing usable
        is available yet.
        """
        if self.callback is None:
            return None, None
        received = self.callback.recieved_message
        start = received.find('{')
        if start == -1:
            return None, None
        end = received.rfind('}')
        # Only a closing brace after the opening one delimits the object.
        recieved_json = received[start:end + 1] if end > start else received[start:]
        recieved_dict = force_parse_json(recieved_json)
        if isinstance(recieved_dict, dict) and "message" in recieved_dict:
            return recieved_dict["message"], recieved_dict.get("prompt")
        return None, None