diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py index 2be0103..9c15dd8 100644 --- a/scripts/langchainapi.py +++ b/scripts/langchainapi.py @@ -285,7 +285,7 @@ If you understand, please reply to the following:<|end_of_turn|> if result is None: return_message, return_prompt = None, None else: - return_message, return_prompt = self.parse_message(result) + return_message, return_prompt = self.parse_message(result, True) self.is_sending = False self.callback.recieved_message = '' @@ -313,16 +313,19 @@ If you understand, please reply to the following:<|end_of_turn|> return_message = None return return_message, return_prompt - def parse_message(self, full_message): + def parse_message(self, full_message, is_finished=False): if not '![' in full_message: return full_message, None prompt_tags = re.findall('\!\[.*?\]\(.*?\)', full_message) if len(prompt_tags) <= 0: if '![sd-prompt: ' in full_message: + parted_prompt = full_message.split('![sd-prompt: ')[1] full_message = full_message.replace('![sd-prompt: ', '_').split(']')[0] + '_' end_index = full_message.rfind('![') if end_index >= 0: return full_message[:end_index], None + if is_finished: + return full_message, parted_prompt return full_message, None ret_message = full_message prompt = None