llama.cppのフォーマット指定方法を変更
parent
8a5e9bdd0e
commit
72da35acde
|
|
@ -69,6 +69,32 @@ class TemplateMessagesPrompt(StringPromptTemplate):
|
|||
return messages
|
||||
|
||||
|
||||
class NewTemplateMessagesPrompt(StringPromptTemplate):
|
||||
full_template: str = ''
|
||||
human_template: str = ''
|
||||
ai_template: str = ''
|
||||
system_message: str = ''
|
||||
history_name: str = 'history'
|
||||
input_name: str = 'input'
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
full_template = self.full_template.replace("{system}", self.system_message)
|
||||
human_template_before, human_template_after = self.human_template.split("{message}")
|
||||
ai_template_before, ai_template_after = self.ai_template.split("{message}")
|
||||
|
||||
input_mes_list = kwargs[self.history_name]
|
||||
messages = ''
|
||||
for mes in input_mes_list:
|
||||
if type(mes) is HumanMessage:
|
||||
messages += human_template_before + mes.content + human_template_after
|
||||
elif type(mes) is AIMessage:
|
||||
messages += ai_template_before + mes.content + ai_template_after
|
||||
messages += human_template_before + kwargs[self.input_name] + human_template_after + ai_template_before
|
||||
full_messages = full_template.replace("{messages}", messages)
|
||||
#print(full_messages)
|
||||
return full_messages
|
||||
|
||||
|
||||
class LangChainApi:
|
||||
log_file_name = None
|
||||
is_sending = False
|
||||
|
|
@ -110,6 +136,14 @@ class LangChainApi:
|
|||
self.settings['llama_cpp_n_batch'] = 128
|
||||
if not 'llama_cpp_n_ctx' in self.settings:
|
||||
self.settings['llama_cpp_n_ctx'] = 2048
|
||||
full_template_str = self.settings['llama_cpp_full_template']
|
||||
human_template_str = self.settings['llama_cpp_human_template']
|
||||
ai_template_str = self.settings['llama_cpp_ai_template']
|
||||
stop_word = ai_template_str.split("{message}")[-1]
|
||||
if not stop_word.isspace():
|
||||
stop_words = [stop_word, ]
|
||||
else:
|
||||
stop_words = []
|
||||
self.llm = LlamaCpp(
|
||||
model_path=self.settings['llama_cpp_model'],
|
||||
n_gpu_layers=self.settings['llama_cpp_n_gpu_layers'],
|
||||
|
|
@ -117,10 +151,10 @@ class LangChainApi:
|
|||
n_ctx=self.settings['llama_cpp_n_ctx'],
|
||||
streaming=True,
|
||||
callback_manager=AsyncCallbackManager([self.callback]),
|
||||
stop=stop_words,
|
||||
#verbose=True,
|
||||
)
|
||||
is_chat = False
|
||||
prompt_template_str = self.settings['llama_cpp_prompt_template']
|
||||
|
||||
if not is_chat:
|
||||
system_message = """You are a chatbot having a conversation with a human.
|
||||
|
|
@ -144,11 +178,20 @@ There is no memory function, so please carry over the prompts from past conversa
|
|||
If you understand, please reply to the following:<|end_of_turn|>
|
||||
"""
|
||||
|
||||
self.prompt = TemplateMessagesPrompt(
|
||||
system_message=system_message,
|
||||
template=prompt_template_str,
|
||||
input_variables=['history', 'input'],
|
||||
)
|
||||
if self.backend == 'GPT4All':
|
||||
self.prompt = TemplateMessagesPrompt(
|
||||
system_message=system_message,
|
||||
template=prompt_template_str,
|
||||
input_variables=['history', 'input'],
|
||||
)
|
||||
else:
|
||||
self.prompt = NewTemplateMessagesPrompt(
|
||||
system_message=system_message,
|
||||
full_template=full_template_str,
|
||||
human_template=human_template_str,
|
||||
ai_template=ai_template_str,
|
||||
input_variables=['history', 'input'],
|
||||
)
|
||||
|
||||
self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -477,19 +477,27 @@ def on_ui_tabs():
|
|||
llama_cpp_n_ctx = gr.Number(label='n_ctx')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
llama_cpp_prompt_template = gr.Textbox(label='Prompt Template')
|
||||
llama_cpp_full_template = gr.Textbox(lines=3, label='Full Template')
|
||||
with gr.Column():
|
||||
llama_cpp_human_template = gr.Textbox(lines=3, label='Human Template')
|
||||
with gr.Column():
|
||||
llama_cpp_ai_template = gr.Textbox(lines=3, label='AI Template')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
btn_llama_cpp_save = gr.Button(value='Save And Reflect', variant='primary')
|
||||
def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, prompt_template: str):
|
||||
def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, full_template: str, human_template: str, ai_template: str):
|
||||
chatgpt_settings['llama_cpp_model'] = path
|
||||
chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers
|
||||
chatgpt_settings['llama_cpp_n_batch'] = n_batch
|
||||
chatgpt_settings['llama_cpp_n_ctx'] = n_ctx
|
||||
chatgpt_settings['llama_cpp_prompt_template'] = prompt_template
|
||||
chatgpt_settings['llama_cpp_full_template'] = full_template
|
||||
chatgpt_settings['llama_cpp_human_template'] = human_template
|
||||
chatgpt_settings['llama_cpp_ai_template'] = ai_template
|
||||
with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
|
||||
json.dump(chatgpt_settings, f)
|
||||
chat_gpt_api.load_settings(**chatgpt_settings)
|
||||
btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx, llama_cpp_prompt_template])
|
||||
btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
|
||||
llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template])
|
||||
with gr.TabItem('GPT4All', id='GPT4All') as gpt4all_tab_item:
|
||||
with gr.Row():
|
||||
gpt4all_model_file = gr.Textbox(label='Model File Path (*.gguf)')
|
||||
|
|
@ -542,7 +550,9 @@ def on_ui_tabs():
|
|||
set_interactive_items = [text_input, btn_generate, btn_regenerate,
|
||||
btn_remove_last, btn_clear, btn_load, btn_save,
|
||||
txt_apikey, btn_apikey_save, txt_chatgpt_model, btn_chatgpt_model_save,
|
||||
llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save, llama_cpp_prompt_template, llama_cpp_n_ctx,
|
||||
llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, btn_llama_cpp_save,
|
||||
llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template,
|
||||
llama_cpp_n_ctx,
|
||||
gpt4all_model_file, btn_gpt4all_save, gpt4all_prompt_template,
|
||||
txt_json_settings, btn_settings_save, btn_settings_reflect]
|
||||
|
||||
|
|
@ -625,15 +635,19 @@ def on_ui_tabs():
|
|||
chatgpt_settings['llama_cpp_n_batch'] = 128
|
||||
if not 'llama_cpp_n_ctx' in chatgpt_settings:
|
||||
chatgpt_settings['llama_cpp_n_ctx'] = 2048
|
||||
if not 'llama_cpp_prompt_template' in chatgpt_settings:
|
||||
chatgpt_settings['llama_cpp_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: '
|
||||
if not 'llama_cpp_full_template' in chatgpt_settings:
|
||||
chatgpt_settings['llama_cpp_full_template'] = '{system}\n\n{messages}'
|
||||
if not 'llama_cpp_human_template' in chatgpt_settings:
|
||||
chatgpt_settings['llama_cpp_human_template'] = 'Human: {message}<|end_of_turn|>'
|
||||
if not 'llama_cpp_ai_template' in chatgpt_settings:
|
||||
chatgpt_settings['llama_cpp_ai_template'] = 'AI: {message}<|end_of_turn|>'
|
||||
if not 'gpt4all_prompt_template' in chatgpt_settings:
|
||||
chatgpt_settings['gpt4all_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: '
|
||||
|
||||
ret = [apikey, chatgpt_settings['model'], json_settings, setting_part_tabs_out, save_file_path,
|
||||
chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_batch'], chatgpt_settings['llama_cpp_n_ctx']]
|
||||
|
||||
for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_prompt_template', 'gpt4all_prompt_template']:
|
||||
for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_full_template', 'llama_cpp_human_template', 'llama_cpp_ai_template', 'gpt4all_prompt_template']:
|
||||
if key in chatgpt_settings:
|
||||
ret.append(chatgpt_settings[key])
|
||||
else:
|
||||
|
|
@ -644,7 +658,7 @@ def on_ui_tabs():
|
|||
runner_interface.load(on_load, outputs=[txt_apikey, txt_chatgpt_model, txt_json_settings, setting_part_tabs, txt_file_path,
|
||||
llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
|
||||
llama_cpp_model_file, gpt4all_model_file,
|
||||
llama_cpp_prompt_template, gpt4all_prompt_template])
|
||||
llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, gpt4all_prompt_template])
|
||||
|
||||
return [(runner_interface, 'sd-webui-chatgpt', 'chatgpt_interface')]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue