Add Ollama. Fix an issue where some parameters were not reflected.
parent fd788f5c62
commit e9f73621ce
@@ -34,12 +34,20 @@ from langchain.callbacks.manager import AsyncCallbackManager
 #from langchain_community.llms import OpenAI
 #os.environ['OPENAI_API_KEY'] = 'foo'
 
+from langchain_community.chat_models.ollama import ChatOllama
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 
 class StreamingLLMCallbackHandler(AsyncCallbackHandler):
     def __init__(self):
         self.recieved_message = ''
         self.is_cancel = False
 
+    def on_chat_model_start(
+        self, serialized, messages, **kwargs
+    ):
+        return
+
     async def on_llm_new_token(self, token: str, **kwargs) -> None:
         self.recieved_message += token
         if self.is_cancel:
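The stub added above matters because LangChain's base callback handlers raise NotImplementedError for on_chat_model_start, and chat models such as ChatOllama fire that hook instead of on_llm_start. A minimal sketch of a compatible streaming handler (the handler name and print-based output are illustrative, not the repo's code):

from langchain.callbacks.base import AsyncCallbackHandler

class PrintingHandler(AsyncCallbackHandler):
    def on_chat_model_start(self, serialized, messages, **kwargs):
        # Chat models call this hook instead of on_llm_start; the base
        # class raises NotImplementedError, so even an empty override
        # is required when reusing an LLM handler with a chat model.
        return

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called once per streamed token.
        print(token, end='', flush=True)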
@@ -156,10 +164,25 @@ class LangChainApi:
                 #verbose=True,
             )
+        is_chat = False
+        if self.backend == 'Ollama':
+            if (not 'ollama_model' in self.settings) or (self.settings['ollama_model'] is None):
+                return
+            if not 'llama_cpp_n_gpu_layers' in self.settings:
+                self.settings['llama_cpp_n_gpu_layers'] = 20
+            if not 'llama_cpp_n_ctx' in self.settings:
+                self.settings['llama_cpp_n_ctx'] = 2048
+            self.llm = ChatOllama(
+                model=self.settings['ollama_model'],
+                num_gpu=int(self.settings['llama_cpp_n_gpu_layers']),
+                num_ctx=int(self.settings['llama_cpp_n_ctx']),
+                streaming=True,
+                callback_manager=AsyncCallbackManager([self.callback]),
+                #verbose=True,
+            )
+            is_chat = True
 
-        if 'llama_cpp_system_message_language' in self.settings and self.settings['llama_cpp_system_message_language'] == 'Japanese':
-            system_message = """あなたは人間と会話するチャットボットです。
+        if not is_chat:
+            if 'llama_cpp_system_message_language' in self.settings and self.settings['llama_cpp_system_message_language'] == 'Japanese':
+                system_message = """あなたは人間と会話するチャットボットです。
 
 また、あなたはStable Diffusionで画像を生成する機能があります。
 その機能を実行する場合は、以下をあなたの返信内容に加えてください。
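For context, a self-contained sketch of the ChatOllama parameters used above, assuming a local Ollama server on its default port and an already-pulled model (the model name is illustrative). num_gpu and num_ctx pass through to Ollama's num_gpu (layers offloaded to the GPU) and num_ctx (context window) options, which is why the existing llama_cpp_* settings can be reused:

from langchain_community.chat_models.ollama import ChatOllama

llm = ChatOllama(
    model='llama2',   # any model already pulled into Ollama
    num_gpu=20,       # layers to offload to the GPU
    num_ctx=2048,     # context window size in tokens
)
print(llm.invoke('Say hello in five words.').content)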
@@ -176,11 +199,9 @@ PROMPTは画像生成に使用するプロンプトに置き換えてくださ
 この画像は返信メッセージの後に表示されます。
 このプロンプトが複数存在する場合は、最初のプロンプトの画像のみが生成されます。
 この画像生成機能に、記憶する機能は無いので、過去の会話内容も反映させてください。
-<|end_of_turn|>
-この内容を理解したら、以下の内容に返事をしてください:<|end_of_turn|>
 """
-        else:
-            system_message = """You are a chatbot having a conversation with a human.
+            else:
+                system_message = """You are a chatbot having a conversation with a human.
 
 You also have the function to generate image with Stable Diffusion.
 If you want to use this function, please add the following to your message.
@@ -197,10 +218,14 @@ For example, if you want to output "a school girl wearing a red ribbon", it woul
 The image is always output at the end, not at the location where it is added.
 If there are multiple entries, only the first one will be reflected.
 There is no memory function, so please carry over the prompts from past conversations.
-<|end_of_turn|>
-If you understand, please reply to the following:<|end_of_turn|>
 """
 
+        if not is_chat:
+            if 'llama_cpp_system_message_language' in self.settings and self.settings['llama_cpp_system_message_language'] == 'Japanese':
+                system_message += '<|end_of_turn|>\nこの内容を理解したら、以下の内容に返事をしてください:<|end_of_turn|>'
+            else:
+                system_message += '<|end_of_turn|>\nIf you understand, please reply to the following:<|end_of_turn|>'
+
         if self.backend == 'GPT4All':
             self.prompt = TemplateMessagesPrompt(
                 system_message=system_message,
@@ -215,23 +240,29 @@ If you understand, please reply to the following:<|end_of_turn|>
                 ai_template=ai_template_str,
                 input_variables=['history', 'input'],
             )
+        else:
+            self.prompt = ChatPromptTemplate.from_messages([
+                ("system", system_message),
+                MessagesPlaceholder(variable_name="history"),
+                ("human", "{input}"),
+            ])
 
         self.llm_chain = ConversationChain(prompt=self.prompt, llm=self.llm, memory=self.memory)#, verbose=True)
 
         def chat_predict(human_input):
             if self.callback.is_cancel:
                 self.callback.is_cancel = False
                 return None
             try:
                 ret = self.llm_chain.invoke({
                     'input': human_input,
                 })
             except asyncio.CancelledError:
                 return None
             #print(ret)
             return ret['response']
 
         self.chat_predict = chat_predict
 
         self.is_inited = True
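The new else: branch swaps the template-string prompt for a proper chat prompt. A standalone sketch of the same wiring (the model name and memory choice are illustrative; MessagesPlaceholder only works when the memory returns message objects, i.e. is built with return_messages=True):

from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models.ollama import ChatOllama

prompt = ChatPromptTemplate.from_messages([
    ('system', 'You are a chatbot having a conversation with a human.'),
    MessagesPlaceholder(variable_name='history'),
    ('human', '{input}'),
])
memory = ConversationBufferMemory(return_messages=True)
chain = ConversationChain(prompt=prompt, llm=ChatOllama(model='llama2'), memory=memory)
print(chain.invoke({'input': 'Hello!'})['response'])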
scripts/main.py
@@ -65,6 +65,35 @@ txt2img_json_default = '''{
 }
 '''
 
+def txt2img_with_params(steps: int, sampler_name: str, id_task: str, request: gr.Request, *args):
+    from modules.txt2img import txt2img_create_processing
+    from contextlib import closing
+    from modules import processing
+    import modules.shared as shared
+    from modules.shared import opts
+    from modules.ui import plaintext_to_html
+
+    p = txt2img_create_processing(id_task, request, *args)
+    p.steps = steps
+    p.sampler_name = sampler_name
+
+    with closing(p):
+        processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
+
+        if processed is None:
+            processed = processing.process_images(p)
+
+    shared.total_tqdm.clear()
+
+    generation_info_js = processed.js()
+    if opts.samples_log_stdout:
+        print(generation_info_js)
+
+    if opts.do_not_show_images:
+        processed.images = []
+
+    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
+
 def get_path_settings_file(file_name: str, new_file=True):
     ret = os.path.join(os.path.dirname(__file__), '..', 'settings', file_name)
     if os.path.isfile(ret):
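txt2img_with_params mirrors the WebUI's own txt2img flow but forces steps and sampler_name onto the processing object before running it. The with closing(p) idiom guarantees p.close() runs even if processing raises; a minimal illustration of that pattern (the Job class is hypothetical):

from contextlib import closing

class Job:
    def run(self):
        return 'done'
    def close(self):
        print('resources released')

with closing(Job()) as job:
    print(job.run())  # close() still runs if run() raises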
@@ -238,8 +267,10 @@ def on_ui_tabs():
         else:
             txt2img_args.append(None)
 
-        images, info_js, info_html, comments_html = txt2img(
-            *txt2img_args)
+        if not 'steps' in txt2img_args_names:
+            images, info_js, info_html, comments_html = txt2img_with_params(txt2img_params['steps'], txt2img_params['sampler_name'], *txt2img_args)
+        else:
+            images, info_js, info_html, comments_html = txt2img(*txt2img_args)
         last_prompt = txt2img_params['prompt']
         image_info = json.loads(info_js)
         last_seed = image_info['seed']
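The 'steps' in txt2img_args_names check lets the extension handle WebUI versions whose txt2img signature no longer takes steps directly, falling back to the wrapper above. A sketch of the general signature-inspection technique (the helper and target function here are hypothetical, not the repo's exact code):

import inspect

def arg_names(fn):
    # parameter names as declared by the target function
    return list(inspect.signature(fn).parameters)

def render(width, height):  # stand-in for a versioned API
    return (width, height)

if 'steps' in arg_names(render):
    print('pass steps as an argument')
else:
    print('set steps on the processing object instead')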
@@ -465,6 +496,19 @@ def on_ui_tabs():
                 json.dump(chatgpt_settings, f)
             chat_gpt_api.change_model(setting_model)
         btn_chatgpt_model_save.click(fn=chatgpt_model_save, inputs=txt_chatgpt_model)
+        with gr.TabItem('Ollama', id='Ollama') as ollama_tab_item:
+            with gr.Row():
+                ollama_model = gr.Textbox(label='Model Name')
+            with gr.Row():
+                with gr.Column():
+                    ollama_n_gpu_layers = gr.Number(label='n_gpu_layers')
+                with gr.Column():
+                    ollama_n_ctx = gr.Number(label='n_ctx')
+            with gr.Row():
+                with gr.Column():
+                    ollama_system_message_language = gr.Dropdown(value='English', allow_custom_value=False, label='System Message Language', choices=['English', 'Japanese'])
+                with gr.Column():
+                    btn_ollama_save = gr.Button(value='Save And Reflect', variant='primary')
         with gr.TabItem('LlamaCpp', id='LlamaCpp') as llama_cpp_tab_item:
             with gr.Row():
                 llama_cpp_model_file = gr.Textbox(label='Model File Path (*.gguf)')
@@ -487,20 +531,6 @@ def on_ui_tabs():
                     llama_cpp_system_message_language = gr.Dropdown(value='English', allow_custom_value=False, label='System Message Language', choices=['English', 'Japanese'])
                 with gr.Column():
                     btn_llama_cpp_save = gr.Button(value='Save And Reflect', variant='primary')
-        def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, full_template: str, human_template: str, ai_template: str, system_message_language: str):
-            chatgpt_settings['llama_cpp_model'] = path
-            chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers
-            chatgpt_settings['llama_cpp_n_batch'] = n_batch
-            chatgpt_settings['llama_cpp_n_ctx'] = n_ctx
-            chatgpt_settings['llama_cpp_full_template'] = full_template
-            chatgpt_settings['llama_cpp_human_template'] = human_template
-            chatgpt_settings['llama_cpp_ai_template'] = ai_template
-            chatgpt_settings['llama_cpp_system_message_language'] = system_message_language
-            with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
-                json.dump(chatgpt_settings, f)
-            chat_gpt_api.load_settings(**chatgpt_settings)
-        btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
-                                                            llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, llama_cpp_system_message_language])
         with gr.TabItem('GPT4All', id='GPT4All') as gpt4all_tab_item:
             with gr.Row():
                 gpt4all_model_file = gr.Textbox(label='Model File Path (*.gguf)')
@@ -515,12 +545,48 @@ def on_ui_tabs():
                 json.dump(chatgpt_settings, f)
             chat_gpt_api.load_settings(**chatgpt_settings)
         btn_gpt4all_save.click(fn=gpt4all_model_save, inputs=[gpt4all_model_file, gpt4all_prompt_template])
+
+        def ollama_save(name: str, n_gpu_layers: int, n_ctx: int, system_message_language: str):
+            chatgpt_settings['ollama_model'] = name
+            chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers
+            chatgpt_settings['llama_cpp_n_ctx'] = n_ctx
+            chatgpt_settings['llama_cpp_system_message_language'] = system_message_language
+            with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
+                json.dump(chatgpt_settings, f)
+            chat_gpt_api.load_settings(**chatgpt_settings)
+            return n_gpu_layers, n_ctx, system_message_language
+        btn_ollama_save.click(fn=ollama_save, inputs=[ollama_model, ollama_n_gpu_layers, ollama_n_ctx, ollama_system_message_language],
+                              outputs=[llama_cpp_n_gpu_layers, llama_cpp_n_ctx, llama_cpp_system_message_language])
+
+        def llama_cpp_save(path: str, n_gpu_layers: int, n_batch: int, n_ctx: int, full_template: str, human_template: str, ai_template: str, system_message_language: str):
+            chatgpt_settings['llama_cpp_model'] = path
+            chatgpt_settings['llama_cpp_n_gpu_layers'] = n_gpu_layers
+            chatgpt_settings['llama_cpp_n_batch'] = n_batch
+            chatgpt_settings['llama_cpp_n_ctx'] = n_ctx
+            chatgpt_settings['llama_cpp_full_template'] = full_template
+            chatgpt_settings['llama_cpp_human_template'] = human_template
+            chatgpt_settings['llama_cpp_ai_template'] = ai_template
+            chatgpt_settings['llama_cpp_system_message_language'] = system_message_language
+            with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
+                json.dump(chatgpt_settings, f)
+            chat_gpt_api.load_settings(**chatgpt_settings)
+            return n_gpu_layers, n_ctx, system_message_language
+        btn_llama_cpp_save.click(fn=llama_cpp_save, inputs=[llama_cpp_model_file, llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
+                                                            llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, llama_cpp_system_message_language],
+                                 outputs=[ollama_n_gpu_layers, ollama_n_ctx, ollama_system_message_language])
 
     def setting_openai_api_tab_item_select():
         chatgpt_settings['backend'] = 'OpenAI API'
         with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
             json.dump(chatgpt_settings, f)
         init_or_change_backend(apikey, chatgpt_settings)
     openai_api_tab_item.select(fn=setting_openai_api_tab_item_select)
+    def setting_ollama_tab_item_select():
+        chatgpt_settings['backend'] = 'Ollama'
+        with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
+            json.dump(chatgpt_settings, f)
+        init_or_change_backend(apikey, chatgpt_settings)
+    ollama_tab_item.select(fn=setting_ollama_tab_item_select)
     def setting_llama_cpp_tab_item_select():
         chatgpt_settings['backend'] = 'LlamaCpp'
         with open(get_path_settings_file('chatgpt_settings.json'), 'w') as f:
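Note the symmetry above: ollama_save returns the shared values with the LlamaCpp widgets as outputs, and llama_cpp_save does the reverse, so the llama_cpp_* settings shown in both tabs never drift apart. A standalone sketch of that Gradio pattern (widget names are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Tab('Ollama'):
        a_gpu = gr.Number(label='n_gpu_layers')
    with gr.Tab('LlamaCpp'):
        b_gpu = gr.Number(label='n_gpu_layers')
    btn = gr.Button('Save And Reflect')
    # returning the saved value and listing the twin widget as an
    # output keeps the duplicated field in sync across tabs
    btn.click(fn=lambda n: n, inputs=a_gpu, outputs=b_gpu)

demo.launch()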
@@ -557,6 +623,7 @@ def on_ui_tabs():
                        llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template,
                        llama_cpp_n_ctx,
                        gpt4all_model_file, btn_gpt4all_save, gpt4all_prompt_template,
+                       ollama_model, ollama_n_gpu_layers, ollama_n_ctx, ollama_system_message_language,
                        txt_json_settings, btn_settings_save, btn_settings_reflect]
 
     btn_generate.click(
@@ -648,12 +715,12 @@ def on_ui_tabs():
         chatgpt_settings['llama_cpp_system_message_language'] = 'English'
     if not 'gpt4all_prompt_template' in chatgpt_settings:
         chatgpt_settings['gpt4all_prompt_template'] = 'Human: {prompt}<|end_of_turn|>AI: '
 
 
     ret = [apikey, chatgpt_settings['model'], json_settings, setting_part_tabs_out, save_file_path,
-           chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_batch'], chatgpt_settings['llama_cpp_n_ctx']]
+           chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_batch'], chatgpt_settings['llama_cpp_n_ctx'],
+           chatgpt_settings['llama_cpp_n_gpu_layers'], chatgpt_settings['llama_cpp_n_ctx']]
 
-    for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_full_template', 'llama_cpp_human_template', 'llama_cpp_ai_template', 'llama_cpp_system_message_language', 'gpt4all_prompt_template']:
+    for key in ['llama_cpp_model', 'gpt4all_model', 'llama_cpp_full_template', 'llama_cpp_human_template', 'llama_cpp_ai_template', 'llama_cpp_system_message_language', 'gpt4all_prompt_template', 'llama_cpp_system_message_language', 'ollama_model']:
         if key in chatgpt_settings:
             ret.append(chatgpt_settings[key])
         else:
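This hunk is the "parameters not reflected" half of the commit: on_load's return list had fewer entries than the components registered in the load outputs below, so the trailing widgets were never populated. Gradio fills outputs strictly by position, one return value per component; a minimal sketch of that invariant (widget names are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    n_gpu = gr.Number(label='n_gpu_layers')
    n_ctx = gr.Number(label='n_ctx')
    # two components in outputs, therefore exactly two returned values
    demo.load(fn=lambda: (20, 2048), outputs=[n_gpu, n_ctx])

demo.launch()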
@@ -663,8 +730,10 @@ def on_ui_tabs():
 
     runner_interface.load(on_load, outputs=[txt_apikey, txt_chatgpt_model, txt_json_settings, setting_part_tabs, txt_file_path,
                                             llama_cpp_n_gpu_layers, llama_cpp_n_batch, llama_cpp_n_ctx,
+                                            ollama_n_gpu_layers, ollama_n_ctx,
                                             llama_cpp_model_file, gpt4all_model_file,
-                                            llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, llama_cpp_system_message_language, gpt4all_prompt_template])
+                                            llama_cpp_full_template, llama_cpp_human_template, llama_cpp_ai_template, llama_cpp_system_message_language, gpt4all_prompt_template,
+                                            ollama_system_message_language, ollama_model])
 
     return [(runner_interface, 'sd-webui-chatgpt', 'chatgpt_interface')]