会話ログの統一

langchain
NON906 2024-01-10 13:40:29 +09:00
parent 741b837616
commit 9db4b1477b
2 changed files with 40 additions and 12 deletions

View File

@ -146,7 +146,18 @@ If you understand, please reply to the following:<|end_of_turn|>
if mes['role'] == 'user':
history.add_user_message(mes['content'])
elif mes['role'] == 'assistant':
history.add_ai_message(mes['content'])
if '\n(Generated image by the following prompt: ' in mes['content']:
mes_content, mes_prompt = mes['content'].split('\n(Generated image by the following prompt: ')
mes_prompt = mes_prompt[::-1].replace(')', '', 1)[::-1]
mes_json = json.dumps({
"prompt": mes_prompt,
"message": mes_content,
})
else:
mes_json = json.dumps({
"message": mes['content'],
})
history.add_ai_message(mes_json)
self.memory.chat_memory = history
elif chatgpt_messages['log_version'] == 2:
self.memory.chat_memory = messages_from_dict(chatgpt_messages['messages'])
@ -165,12 +176,23 @@ If you understand, please reply to the following:<|end_of_turn|>
return False
def get_log(self):
dicts = {'log_version': 2}
if self.memory is None:
dicts['messages'] = {}
else:
dicts['messages'] = messages_to_dict(self.memory.chat_memory)
return json.dumps(dicts)
ret_messages = []
for mes in self.memory.chat_memory:
if type(mes) is HumanMessage:
ret_messages.append({"role": "user", "content": mes.content})
elif type(mes) is AIMessage:
mes_dict = json.loads(mes.content)
add_mes = mes_dict['message']
if 'prompt' in mes_dict and mes_dict['prompt'] is not None and mes_dict['prompt'] != "":
add_mes += "\n(Generated image by the following prompt: " + mes_dict['prompt'] + ")"
ret_messages.append({"role": "assistant", "content": add_mes})
return json.dumps(ret_messages)
#dicts = {'log_version': 2}
#if self.memory is None:
# dicts['messages'] = {}
#else:
# dicts['messages'] = messages_to_dict(self.memory.chat_memory)
#return json.dumps(dicts)
def write_log(self, file_name=None):
if file_name is None:
@ -211,7 +233,10 @@ If you understand, please reply to the following:<|end_of_turn|>
def remove_last_conversation(self, result=None, write_log=False):
if result is None or self.memory.chat_memory.messages[-1].content == result:
self.memory.chat_memory.messages = self.memory.chat_memory.messages[:-2]
if len(self.memory.chat_memory.messages) > 2:
self.memory.chat_memory.messages = self.memory.chat_memory.messages[:-2]
else:
self.memory.chat_memory.messages.clear()
if write_log:
self.write_log()

View File

@ -102,6 +102,10 @@ def init_txt2img_params():
def init_or_change_backend(apikey, chatgpt_settings):
global chat_gpt_api
log = None
if chat_gpt_api is not None:
log = chat_gpt_api.get_log()
if chatgpt_settings['backend'] == 'OpenAI API':
if type(chat_gpt_api) is chatgptapi.ChatGptApi:
chat_gpt_api.change_apikey(apikey)
@ -114,13 +118,12 @@ def init_or_change_backend(apikey, chatgpt_settings):
else:
if type(chat_gpt_api) is langchainapi.LangChainApi:
chat_gpt_api.load_settings(**chatgpt_settings)
elif chat_gpt_api is not None:
log = chat_gpt_api.get_log()
chat_gpt_api = langchainapi.LangChainApi(**chatgpt_settings)
chat_gpt_api.set_log(log)
else:
chat_gpt_api = langchainapi.LangChainApi(**chatgpt_settings)
if log is not None:
chat_gpt_api.set_log(log)
def on_ui_tabs():
global txt2img_params_base, public_ui, public_ui_value, chat_gpt_api