Change how the prompt format for llama.cpp is specified

main
NON906 2024-05-04 16:13:55 +09:00
parent 8a5e9bdd0e
commit 72da35acde
2 changed files with 72 additions and 15 deletions

View File

@@ -69,6 +69,32 @@ class TemplateMessagesPrompt(StringPromptTemplate):
        return messages

class NewTemplateMessagesPrompt(StringPromptTemplate):
    full_template: str = ''
    human_template: str = ''
    ai_template: str = ''
    system_message: str = ''
    history_name: str = 'history'
    input_name: str = 'input'

    def format(self, **kwargs: Any) -> str:
        full_template = self.full_template.replace("{system}", self.system_message)
        human_template_before, human_template_after = self.human_template.split("{message}")
        ai_template_before, ai_template_after = self.ai_template.split("{message}")
        input_mes_list = kwargs[self.history_name]
        messages = ''
        for mes in input_mes_list:
            if type(mes) is HumanMessage:
                messages += human_template_before + mes.content + human_template_after
            elif type(mes) is AIMessage:
                messages += ai_template_before + mes.content + ai_template_after
        messages += human_template_before + kwargs[self.input_name] + human_template_after + ai_template_before
        full_messages = full_template.replace("{messages}", messages)
        #print(full_messages)
        return full_messages

class LangChainApi:
    log_file_name = None
    is_sending = False
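For context on the class added above: format() substitutes the system message into {system} in the full template, wraps every history turn with the human/AI templates split around {message}, appends the current input plus the AI-side prefix so the model completes the assistant turn, and finally substitutes the accumulated string into {messages}. A minimal usage sketch with the default templates introduced later in this commit (the history and input values are illustrative, and the langchain.schema import path assumes the LangChain version this extension already uses):

    from langchain.schema import AIMessage, HumanMessage

    prompt = NewTemplateMessagesPrompt(
        system_message='You are a chatbot having a conversation with a human.',
        full_template='{system}\n\n{messages}',
        human_template='Human: {message}<|end_of_turn|>',
        ai_template='AI: {message}<|end_of_turn|>',
        input_variables=['history', 'input'],
    )
    print(prompt.format(
        history=[HumanMessage(content='Hello'), AIMessage(content='Hi, how can I help?')],
        input='Who are you?',
    ))
    # You are a chatbot having a conversation with a human.
    #
    # Human: Hello<|end_of_turn|>AI: Hi, how can I help?<|end_of_turn|>Human: Who are you?<|end_of_turn|>AI: 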
@@ -110,6 +136,14 @@ class LangChainApi:
            self.settings['llama_cpp_n_batch'] = 128
        if not 'llama_cpp_n_ctx' in self.settings:
            self.settings['llama_cpp_n_ctx'] = 2048
        full_template_str = self.settings['llama_cpp_full_template']
        human_template_str = self.settings['llama_cpp_human_template']
        ai_template_str = self.settings['llama_cpp_ai_template']
        stop_word = ai_template_str.split("{message}")[-1]
        if not stop_word.isspace():
            stop_words = [stop_word, ]
        else:
            stop_words = []
        self.llm = LlamaCpp(
            model_path=self.settings['llama_cpp_model'],
            n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
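The stop sequence for llama.cpp is now derived from the AI template instead of being a separate setting: whatever follows {message} becomes the stop word, so generation halts at the end of the assistant turn. With this commit's default template the logic above works out to the following (a restatement with concrete values; note that a template ending exactly at {message} yields an empty stop word, because ''.isspace() is False):

    ai_template = 'AI: {message}<|end_of_turn|>'
    stop_word = ai_template.split("{message}")[-1]  # '<|end_of_turn|>'
    stop_words = [stop_word] if not stop_word.isspace() else []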
@@ -117,10 +151,10 @@
            n_ctx=self.settings['llama_cpp_n_ctx'],
            streaming=True,
            callback_manager=AsyncCallbackManager([self.callback]),
            stop=stop_words,
            #verbose=True,
        )
        is_chat = False
        prompt_template_str = self.settings['llama_cpp_prompt_template']
        if not is_chat:
            system_message = """You are a chatbot having a conversation with a human.
@@ -144,11 +178,20 @@ There is no memory function, so please carry over the prompts from past conversations.
If you understand, please reply to the following:<|end_of_turn|>
"""
            self.prompt = TemplateMessagesPrompt(
                system_message=system_message,
                template=prompt_template_str,
                input_variables=['history', 'input'],
            )
            if self.backend == 'GPT4All':
                self.prompt = TemplateMessagesPrompt(
                    system_message=system_message,
                    template=prompt_template_str,
                    input_variables=['history', 'input'],
                )
            else:
                self.prompt = NewTemplateMessagesPrompt(
                    system_message=system_message,
                    full_template=full_template_str,
                    human_template=human_template_str,
                    ai_template=ai_template_str,
                    input_variables=['history', 'input'],
                )
        self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)
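The ConversationChain wiring itself is unchanged: LangChain fills the history variable from self.memory and passes the user text as input, which is exactly what NewTemplateMessagesPrompt.format() consumes. A hypothetical call (the text is illustrative; predict is LangChain's standard chain entry point):

    # Each turn is appended to self.memory, so the next format() sees it in `history`.
    reply = self.llm_chain.predict(input='Hello')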

View File

@@ -477,19 +477,27 @@ def on_ui_tabs():
                        llama_cpp_n_ctx = gr.Number(label='n_ctx')
                with gr.Row():
                    with gr.Column():
                        llama_cpp_prompt_template = gr.Textbox(label='Prompt Template')
                        llama_cpp_full_template = gr.Textbox(lines=3, label='Full Template')
                    with gr.Column():
                        llama_cpp_human_template = gr.Textbox(lines=3, label='Human Template')
                    with gr.Column():
                        llama_cpp_ai_template = gr.Textbox(lines=3, label='AI Template')
                with gr.Row():
                    with gr.Column():
                        btn_llama_cpp_save = gr.Button(value='Save And Reflect', variant='primary')
                def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, prompt_template: str):
                def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, full_template: str, human_template: str, ai_template: str):
                    chatgpt_settings['llama_cpp_model'] = path
                    chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers
                    chatgpt_settings['llama_cpp_n_batch'] = n_batch
                    chatgpt_settings['llama_cpp_n_ctx'] = n_ctx
                    chatgpt_settings['llama_cpp_prompt_template'] = prompt_template
                    chatgpt_settings['llama_cpp_full_template'] = full_template
                    chatgpt_settings['llama_cpp_human_template'] = human_template
                    chatgpt_settings['llama_cpp_ai_template'] = ai_template
                    with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
                        json.dump(chatgpt_settings, f)
                    chat_gpt_api.load_settings(**chatgpt_settings)
                btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx, llama_cpp_prompt_template])
                btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
                                                                    llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template])
            with gr.TabItem('GPT4All', id='GPT4All') as gpt4all_tab_item:
                with gr.Row():
                    gpt4all_model_file = gr.Textbox(label='Model File Path (*.gguf)')
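After pressing 'Save And Reflect', chatgpt_settings.json now carries three template keys for llama.cpp instead of one. A sketch of the relevant part of the saved settings, written as a Python dict (the model path and n_gpu_layers value are illustrative; the template values are this commit's defaults):

    {
        'llama_cpp_model': '/path/to/model.gguf',  # illustrative
        'llama_cpp_n_gpu_layers': 20,              # illustrative
        'llama_cpp_n_batch': 128,
        'llama_cpp_n_ctx': 2048,
        'llama_cpp_full_template': '{system}\n\n{messages}',
        'llama_cpp_human_template': 'Human: {message}<|end_of_turn|>',
        'llama_cpp_ai_template': 'AI: {message}<|end_of_turn|>',
    }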
@@ -542,7 +550,9 @@ def on_ui_tabs():
        set_interactive_items = [text_input, btn_generate, btn_regenerate,
                                 btn_remove_last, btn_clear, btn_load, btn_save,
                                 txt_apikey, btn_apikey_save, txt_chatgpt_model, btn_chatgpt_model_save,
                                 llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save, llama_cpp_prompt_template, llama_cpp_n_ctx,
                                 llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save,
                                 llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template,
                                 llama_cpp_n_ctx,
                                 gpt4all_model_file, btn_gpt4all_save, gpt4all_prompt_template,
                                 txt_json_settings, btn_settings_save, btn_settings_reflect]
@@ -625,15 +635,19 @@ def on_ui_tabs():
            chatgpt_settings['llama_cpp_n_batch'] = 128
        if not 'llama_cpp_n_ctx' in chatgpt_settings:
            chatgpt_settings['llama_cpp_n_ctx'] = 2048
        if not 'llama_cpp_prompt_template' in chatgpt_settings:
            chatgpt_settings['llama_cpp_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: '
        if not 'llama_cpp_full_template' in chatgpt_settings:
            chatgpt_settings['llama_cpp_full_template'] = '{system}\n\n{messages}'
        if not 'llama_cpp_human_template' in chatgpt_settings:
            chatgpt_settings['llama_cpp_human_template'] = 'Human: {message}<|end_of_turn|>'
        if not 'llama_cpp_ai_template' in chatgpt_settings:
            chatgpt_settings['llama_cpp_ai_template'] = 'AI: {message}<|end_of_turn|>'
        if not 'gpt4all_prompt_template' in chatgpt_settings:
            chatgpt_settings['gpt4all_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: '
        ret = [apikey, chatgpt_settings['model'], json_settings, setting_part_tabs_out, save_file_path,
               chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_batch'], chatgpt_settings['llama_cpp_n_ctx']]
        for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_prompt_template', 'gpt4all_prompt_template']:
        for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_full_template', 'llama_cpp_human_template', 'llama_cpp_ai_template', 'gpt4all_prompt_template']:
            if key in chatgpt_settings:
                ret.append(chatgpt_settings[key])
            else:
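For existing users, the old single default 'Human: {prompt}<|end_of_turn|>AI: ' decomposes cleanly into the new human/AI pair: per turn the two formats produce identical text, as this quick check (not part of the commit) confirms:

    human_before, human_after = 'Human: {message}<|end_of_turn|>'.split('{message}')
    ai_before, _ = 'AI: {message}<|end_of_turn|>'.split('{message}')

    old_turn = 'Human: {prompt}<|end_of_turn|>AI: '.replace('{prompt}', 'Hello')
    new_turn = human_before + 'Hello' + human_after + ai_before
    assert old_turn == new_turn  # both: 'Human: Hello<|end_of_turn|>AI: '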
@@ -644,7 +658,7 @@ def on_ui_tabs():
    runner_interface.load(on_load, outputs=[txt_apikey, txt_chatgpt_model, txt_json_settings, setting_part_tabs, txt_file_path,
                                            llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
                                            llama_cpp_model_file, gpt4all_model_file,
                                            llama_cpp_prompt_template, gpt4all_prompt_template])
                                            llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, gpt4all_prompt_template])
    return [(runner_interface, 'sd-webui-chatgpt', 'chatgpt_interface')]