feature: per-worker model override

pull/17/head
unknown 2023-07-15 05:45:55 -05:00
parent 9d48e2732d
commit bd5e6ac7b7
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 75 additions and 13 deletions

View File

@ -1,6 +1,7 @@
import io import io
import os import os
import subprocess import subprocess
import threading
from pathlib import Path from pathlib import Path
import gradio import gradio
from scripts.spartan.shared import logger, log_level from scripts.spartan.shared import logger, log_level
@ -8,6 +9,7 @@ from scripts.spartan.Worker import State, Worker
from modules.shared import state as webui_state from modules.shared import state as webui_state
import json import json
from typing import List from typing import List
from threading import Thread
worker_select_dropdown = None worker_select_dropdown = None
@ -113,7 +115,15 @@ class UI:
if worker.uuid == selection: if worker.uuid == selection:
selected_worker = worker selected_worker = worker
return [gradio.Textbox.update(value=selected_worker.address), gradio.Textbox.update(value=selected_worker.port), gradio.Checkbox.update(value=selected_worker.tls)] avail_models = selected_worker.available_models()
avail_models.append('None') # for disabling override
return [
gradio.Textbox.update(value=selected_worker.address),
gradio.Textbox.update(value=selected_worker.port),
gradio.Checkbox.update(value=selected_worker.tls),
gradio.Dropdown.update(choices=avail_models)
]
def reconnect_remotes(self): def reconnect_remotes(self):
@ -127,6 +137,17 @@ class UI:
else: else:
logger.info(f"worker '{worker.uuid}' is still unreachable") logger.info(f"worker '{worker.uuid}' is still unreachable")
def override_worker_model(self, model, worker_label):
    """Pin a worker to a specific checkpoint, or clear an existing pin.

    Args:
        model: Value picked in the 'Model override' dropdown. The literal
            string "None" is the extra choice offered for disabling the
            override; a missing value is treated the same way.
        worker_label: Label of the worker to configure, matched against
            each worker's ``uuid`` by ``World.worker_from_label``.
    """
    worker = self.world.worker_from_label(worker_label)
    if worker is None:
        # worker_from_label() yields None when no worker carries this label,
        # e.g. a label typed in for a worker that has not been added yet.
        logger.error(f"cannot override model: no worker with label '{worker_label}'")
        return

    if model is None or model == "None":
        # Clearing the override must not be forwarded to the remote: "None"
        # is a UI sentinel, not a checkpoint title.
        worker.model_override = None
        return

    worker.model_override = model
    # set model on remote early, off the UI thread so the handler returns immediately
    Thread(target=worker.load_options, args=(model,)).start()
# end handlers # end handlers
@ -176,8 +197,6 @@ class UI:
clear_queue_btn.click(self.clear_queue_btn) clear_queue_btn.click(self.clear_queue_btn)
with gradio.Tab('Worker Config'): with gradio.Tab('Worker Config'):
worker_select_dropdown = None
worker_select_dropdown = gradio.Dropdown( worker_select_dropdown = gradio.Dropdown(
[x.uuid for x in self.selectable_remote_workers()], [x.uuid for x in self.selectable_remote_workers()],
info='Select a pre-existing worker or enter a label for a new one', info='Select a pre-existing worker or enter a label for a new one',
@ -189,11 +208,10 @@ class UI:
worker_tls_cbx = gradio.Checkbox( worker_tls_cbx = gradio.Checkbox(
label='connect using https' label='connect using https'
) )
worker_select_dropdown.select(
self.populate_worker_config_from_selection, with gradio.Accordion(label='Advanced'):
inputs=worker_select_dropdown, model_override_dropdown = gradio.Dropdown(label='Model override')
outputs=[worker_address_field, worker_port_field, worker_tls_cbx] model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown])
)
with gradio.Row(): with gradio.Row():
save_worker_btn = gradio.Button(value='Add/Update Worker') save_worker_btn = gradio.Button(value='Add/Update Worker')
@ -201,6 +219,12 @@ class UI:
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop') remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown]) remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
worker_select_dropdown.select(
self.populate_worker_config_from_selection,
inputs=worker_select_dropdown,
outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown]
)
with gradio.Tab('Settings'): with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox( thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)', label='Thin-client mode (experimental)',

View File

@ -18,6 +18,7 @@ import queue
from modules.shared import state as master_state from modules.shared import state as master_state
from modules.api.api import encode_pil_to_base64 from modules.api.api import encode_pil_to_base64
import scripts.spartan.shared as sh import scripts.spartan.shared as sh
import re
class InvalidWorkerResponse(Exception): class InvalidWorkerResponse(Exception):
@ -114,6 +115,7 @@ class Worker:
self.loaded_vae = '' self.loaded_vae = ''
self.state = State.IDLE self.state = State.IDLE
self.tls = tls self.tls = tls
self.model_override: str = None
if uuid is not None: if uuid is not None:
self.uuid = uuid self.uuid = uuid
@ -278,11 +280,11 @@ class Worker:
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}") logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
if sync_options is True: if sync_options is True:
options_response = requests.post( model = option_payload['sd_model_checkpoint']
self.full_url("options"), if self.model_override is not None:
json=option_payload, model = self.model_override
verify=self.verify_remotes
) self.load_options(model=model, vae=option_payload['sd_vae'])
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a # TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
# second GET to see if everything loaded... # second GET to see if everything loaded...
@ -517,3 +519,34 @@ class Worker:
def mark_unreachable(self): def mark_unreachable(self):
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future") logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
self.state = State.UNAVAILABLE self.state = State.UNAVAILABLE
def available_models(self) -> List[str]:
    """Return the titles of the models this worker reports.

    Sends a GET to the worker's 'sd-models' endpoint and collects the
    'title' field of every entry in the JSON reply.

    Returns:
        A fresh list of titles. The list is empty when the request fails,
        the worker answers with an error status, or the reply does not have
        the expected shape, so UI handlers that call this keep working for
        workers that are currently offline.
    """
    try:
        response = requests.get(
            url=self.full_url('sd-models'),
            verify=self.verify_remotes
        )
        # An error status carries an error body rather than a model list.
        response.raise_for_status()
        return [model['title'] for model in response.json()]
    except (requests.RequestException, ValueError, KeyError, TypeError) as error:
        # RequestException: connection/HTTP failures; ValueError: body is not
        # JSON; KeyError/TypeError: JSON without the expected 'title' entries.
        logger.error(f"failed to fetch available models for worker '{self.uuid}': {error}")
        return []
def load_options(self, model, vae=None):
    """Push a checkpoint (and optionally a VAE) selection to this worker.

    Args:
        model: Model title; a trailing bracketed suffix such as " [abc123]"
            is removed before the value is sent.
        vae: VAE name to send along, or None to leave the VAE untouched.

    On a 200 reply the worker's ``loaded_model`` (and ``loaded_vae`` when a
    VAE was given) are updated; any other status is only logged.
    """
    # Strip the optional whitespace + "[...]" tail from the title.
    checkpoint = re.sub(r'\s?\[[^\]]*\]$', '', model)
    has_vae = vae is not None

    options = {"sd_model_checkpoint": checkpoint}
    if has_vae:
        options['sd_vae'] = vae

    result = requests.post(
        self.full_url("options"),
        json=options,
        verify=self.verify_remotes
    )

    if result.status_code != 200:
        logger.debug(f"failed to load options for worker '{self.uuid}'")
        return

    # Only record the new state once the remote accepted the request.
    self.loaded_model = checkpoint
    if has_vae:
        self.loaded_vae = vae

View File

@ -526,3 +526,8 @@ class World:
json.dump(config, config_file, indent=3) json.dump(config, config_file, indent=3)
logger.debug(f"config saved") logger.debug(f"config saved")
def worker_from_label(self, label: str) -> Worker:
    """Find the worker whose ``uuid`` equals ``label``.

    Returns the first matching entry of ``self._workers``, or None when no
    worker matches.
    """
    matches = (candidate for candidate in self._workers if candidate.uuid == label)
    return next(matches, None)