diff --git a/scripts/langchainapi.py b/scripts/langchainapi.py index e14826e..f169d42 100644 --- a/scripts/langchainapi.py +++ b/scripts/langchainapi.py @@ -69,6 +69,32 @@ class TemplateMessagesPrompt(StringPromptTemplate): return messages +class NewTemplateMessagesPrompt(StringPromptTemplate): + full_template: str = '' + human_template: str = '' + ai_template: str = '' + system_message: str = '' + history_name: str = 'history' + input_name: str = 'input' + + def format(self, **kwargs: Any) -> str: + full_template = self.full_template.replace("{system}", self.system_message) + human_template_before, human_template_after = self.human_template.split("{message}") + ai_template_before, ai_template_after = self.ai_template.split("{message}") + + input_mes_list = kwargs[self.history_name] + messages = '' + for mes in input_mes_list: + if isinstance(mes, HumanMessage): + messages += human_template_before + mes.content + human_template_after + elif isinstance(mes, AIMessage): + messages += ai_template_before + mes.content + ai_template_after + messages += human_template_before + kwargs[self.input_name] + human_template_after + ai_template_before + full_messages = full_template.replace("{messages}", messages) + #print(full_messages) + return full_messages + + class LangChainApi: log_file_name = None is_sending = False @@ -110,6 +136,14 @@ class LangChainApi: self.settings['llama_cpp_n_batch'] = 128 if not 'llama_cpp_n_ctx' in self.settings: self.settings['llama_cpp_n_ctx'] = 2048 + full_template_str = self.settings['llama_cpp_full_template'] + human_template_str = self.settings['llama_cpp_human_template'] + ai_template_str = self.settings['llama_cpp_ai_template'] + stop_word = ai_template_str.split("{message}")[-1] + if stop_word and not stop_word.isspace(): + stop_words = [stop_word, ] + else: + stop_words = [] self.llm = LlamaCpp( model_path=self.settings['llama_cpp_model'], n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'], @@ -117,10 +151,10 @@ 
n_ctx=self.settings['llama_cpp_n_ctx'], streaming=True, callback_manager=AsyncCallbackManager([self.callback]), + stop=stop_words, #verbose=True, ) is_chat = False - prompt_template_str = self.settings['llama_cpp_prompt_template'] if not is_chat: system_message = """You are a chatbot having a conversation with a human. @@ -144,11 +178,20 @@ There is no memory function, so please carry over the prompts from past conversa If you understand, please reply to the following:<|end_of_turn|> """ - self.prompt = TemplateMessagesPrompt( - system_message=system_message, - template=prompt_template_str, - input_variables=['history', 'input'], - ) + if self.backend == 'GPT4All': + self.prompt = TemplateMessagesPrompt( + system_message=system_message, + template=prompt_template_str, + input_variables=['history', 'input'], + ) + else: + self.prompt = NewTemplateMessagesPrompt( + system_message=system_message, + full_template=full_template_str, + human_template=human_template_str, + ai_template=ai_template_str, + input_variables=['history', 'input'], + ) self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True) diff --git a/scripts/main.py b/scripts/main.py index d5a0fdc..f9a8f47 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -477,19 +477,27 @@ def on_ui_tabs(): llama_cpp_n_ctx = gr.Number(label='n_ctx') with gr.Row(): with gr.Column(): - llama_cpp_prompt_template = gr.Textbox(label='Prompt Template') + llama_cpp_full_template = gr.Textbox(lines=3, label='Full Template') + with gr.Column(): + llama_cpp_human_template = gr.Textbox(lines=3, label='Human Template') + with gr.Column(): + llama_cpp_ai_template = gr.Textbox(lines=3, label='AI Template') + with gr.Row(): with gr.Column(): btn_llama_cpp_save = gr.Button(value='Save And Reflect', variant='primary') - def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, prompt_template: str): + def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: 
int, full_template: str, human_template: str, ai_template: str): chatgpt_settings['llama_cpp_model'] = path chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers chatgpt_settings['llama_cpp_n_batch'] = n_batch chatgpt_settings['llama_cpp_n_ctx'] = n_ctx - chatgpt_settings['llama_cpp_prompt_template'] = prompt_template + chatgpt_settings['llama_cpp_full_template'] = full_template + chatgpt_settings['llama_cpp_human_template'] = human_template + chatgpt_settings['llama_cpp_ai_template'] = ai_template with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f: json.dump(chatgpt_settings, f) chat_gpt_api.load_settings(**chatgpt_settings) - btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx, llama_cpp_prompt_template]) + btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx, + llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template]) with gr.TabItem('GPT4All', id='GPT4All') as gpt4all_tab_item: with gr.Row(): gpt4all_model_file = gr.Textbox(label='Model File Path (*.gguf)') @@ -542,7 +550,9 @@ def on_ui_tabs(): set_interactive_items = [text_input, btn_generate, btn_regenerate, btn_remove_last, btn_clear, btn_load, btn_save, txt_apikey, btn_apikey_save, txt_chatgpt_model, btn_chatgpt_model_save, - llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save, llama_cpp_prompt_template, llama_cpp_n_ctx, + llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save, + llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, + llama_cpp_n_ctx, gpt4all_model_file, btn_gpt4all_save, gpt4all_prompt_template, txt_json_settings, btn_settings_save, btn_settings_reflect] @@ -625,15 +635,19 @@ def on_ui_tabs(): chatgpt_settings['llama_cpp_n_batch'] = 128 if not 'llama_cpp_n_ctx' in chatgpt_settings: 
chatgpt_settings['llama_cpp_n_ctx'] = 2048 - if not 'llama_cpp_prompt_template' in chatgpt_settings: - chatgpt_settings['llama_cpp_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: ' + if not 'llama_cpp_full_template' in chatgpt_settings: + chatgpt_settings['llama_cpp_full_template'] = '{system}\n\n{messages}' + if not 'llama_cpp_human_template' in chatgpt_settings: + chatgpt_settings['llama_cpp_human_template'] = 'Human: {message}<|end_of_turn|>' + if not 'llama_cpp_ai_template' in chatgpt_settings: + chatgpt_settings['llama_cpp_ai_template'] = 'AI: {message}<|end_of_turn|>' if not 'gpt4all_prompt_template' in chatgpt_settings: chatgpt_settings['gpt4all_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: ' ret = [apikey, chatgpt_settings['model'], json_settings, setting_part_tabs_out, save_file_path, chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_batch'], chatgpt_settings['llama_cpp_n_ctx']] - for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_prompt_template', 'gpt4all_prompt_template']: + for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_full_template', 'llama_cpp_human_template', 'llama_cpp_ai_template', 'gpt4all_prompt_template']: if key in chatgpt_settings: ret.append(chatgpt_settings[key]) else: @@ -644,7 +658,7 @@ def on_ui_tabs(): runner_interface.load(on_load, outputs=[txt_apikey, txt_chatgpt_model, txt_json_settings, setting_part_tabs, txt_file_path, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx, llama_cpp_model_file, gpt4all_model_file, - llama_cpp_prompt_template, gpt4all_prompt_template]) + llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, gpt4all_prompt_template]) return [(runner_interface, 'sd-webui-chatgpt', 'chatgpt_interface')]