feature: per-worker model override
parent
9d48e2732d
commit
bd5e6ac7b7
|
|
@ -1,6 +1,7 @@
|
|||
import io
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import gradio
|
||||
from scripts.spartan.shared import logger, log_level
|
||||
|
|
@ -8,6 +9,7 @@ from scripts.spartan.Worker import State, Worker
|
|||
from modules.shared import state as webui_state
|
||||
import json
|
||||
from typing import List
|
||||
from threading import Thread
|
||||
|
||||
worker_select_dropdown = None
|
||||
|
||||
|
|
@ -113,7 +115,15 @@ class UI:
|
|||
if worker.uuid == selection:
|
||||
selected_worker = worker
|
||||
|
||||
return [gradio.Textbox.update(value=selected_worker.address), gradio.Textbox.update(value=selected_worker.port), gradio.Checkbox.update(value=selected_worker.tls)]
|
||||
avail_models = selected_worker.available_models()
|
||||
avail_models.append('None') # for disabling override
|
||||
|
||||
return [
|
||||
gradio.Textbox.update(value=selected_worker.address),
|
||||
gradio.Textbox.update(value=selected_worker.port),
|
||||
gradio.Checkbox.update(value=selected_worker.tls),
|
||||
gradio.Dropdown.update(choices=avail_models)
|
||||
]
|
||||
|
||||
|
||||
def reconnect_remotes(self):
|
||||
|
|
@ -127,6 +137,17 @@ class UI:
|
|||
else:
|
||||
logger.info(f"worker '{worker.uuid}' is still unreachable")
|
||||
|
||||
def override_worker_model(self, model, worker_label):
    """Pin the selected worker to a specific model, or clear the pin.

    Args:
        model: Model title picked in the override dropdown. The literal
            string "None" is the sentinel choice that disables the override.
        worker_label: Label of the worker to configure; resolved through
            ``self.world.worker_from_label``.
    """
    worker = self.world.worker_from_label(worker_label)
    if worker is None:
        # The lookup yields None when no worker carries this label; bail out
        # instead of failing on the attribute assignment below.
        logger.warning(f"cannot override model: no worker labelled '{worker_label}'")
        return

    if model == "None":
        # Clearing the override: "None" is only a dropdown sentinel, not a
        # checkpoint, so there is nothing to push to the remote.
        worker.model_override = None
        return

    worker.model_override = model

    # set model on remote early
    Thread(target=worker.load_options, args=(model,)).start()
|
||||
|
||||
|
||||
# end handlers
|
||||
|
||||
|
|
@ -176,8 +197,6 @@ class UI:
|
|||
clear_queue_btn.click(self.clear_queue_btn)
|
||||
|
||||
with gradio.Tab('Worker Config'):
|
||||
worker_select_dropdown = None
|
||||
|
||||
worker_select_dropdown = gradio.Dropdown(
|
||||
[x.uuid for x in self.selectable_remote_workers()],
|
||||
info='Select a pre-existing worker or enter a label for a new one',
|
||||
|
|
@ -189,11 +208,10 @@ class UI:
|
|||
worker_tls_cbx = gradio.Checkbox(
|
||||
label='connect using https'
|
||||
)
|
||||
worker_select_dropdown.select(
|
||||
self.populate_worker_config_from_selection,
|
||||
inputs=worker_select_dropdown,
|
||||
outputs=[worker_address_field, worker_port_field, worker_tls_cbx]
|
||||
)
|
||||
|
||||
with gradio.Accordion(label='Advanced'):
|
||||
model_override_dropdown = gradio.Dropdown(label='Model override')
|
||||
model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown])
|
||||
|
||||
with gradio.Row():
|
||||
save_worker_btn = gradio.Button(value='Add/Update Worker')
|
||||
|
|
@ -201,6 +219,12 @@ class UI:
|
|||
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
|
||||
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
|
||||
|
||||
worker_select_dropdown.select(
|
||||
self.populate_worker_config_from_selection,
|
||||
inputs=worker_select_dropdown,
|
||||
outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown]
|
||||
)
|
||||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import queue
|
|||
from modules.shared import state as master_state
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
import scripts.spartan.shared as sh
|
||||
import re
|
||||
|
||||
|
||||
class InvalidWorkerResponse(Exception):
|
||||
|
|
@ -114,6 +115,7 @@ class Worker:
|
|||
self.loaded_vae = ''
|
||||
self.state = State.IDLE
|
||||
self.tls = tls
|
||||
self.model_override: str = None
|
||||
|
||||
if uuid is not None:
|
||||
self.uuid = uuid
|
||||
|
|
@ -278,11 +280,11 @@ class Worker:
|
|||
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
|
||||
|
||||
if sync_options is True:
|
||||
options_response = requests.post(
|
||||
self.full_url("options"),
|
||||
json=option_payload,
|
||||
verify=self.verify_remotes
|
||||
)
|
||||
model = option_payload['sd_model_checkpoint']
|
||||
if self.model_override is not None:
|
||||
model = self.model_override
|
||||
|
||||
self.load_options(model=model, vae=option_payload['sd_vae'])
|
||||
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
|
||||
# second GET to see if everything loaded...
|
||||
|
||||
|
|
@ -517,3 +519,34 @@ class Worker:
|
|||
def mark_unreachable(self):
|
||||
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
|
||||
self.state = State.UNAVAILABLE
|
||||
|
||||
def available_models(self) -> List[str]:
    """Ask the worker which models it offers.

    Issues a GET against the worker's 'sd-models' endpoint and collects the
    'title' field of every entry in the JSON reply.

    Returns:
        The model titles, in the order they appear in the reply.
    """
    model_listing = requests.get(
        url=self.full_url('sd-models'),
        verify=self.verify_remotes
    ).json()

    return [entry['title'] for entry in model_listing]
|
||||
|
||||
def load_options(self, model, vae=None):
    """Push a checkpoint (and optionally a VAE) selection to the worker.

    A trailing bracketed suffix on ``model`` (presumably the checkpoint hash
    shown in model titles - confirm) is stripped, along with one optional
    whitespace character before it, and the result is POSTed to the worker's
    "options" endpoint. On an HTTP 200 reply the worker's ``loaded_model`` /
    ``loaded_vae`` bookkeeping is updated; any other status only produces a
    debug log entry.

    Args:
        model: Checkpoint title to load.
        vae: VAE to load alongside it; left out of the request when None.
    """
    checkpoint = re.sub(r'\s?\[[^\]]*\]$', '', model)

    options = {"sd_model_checkpoint": checkpoint}
    if vae is not None:
        options['sd_vae'] = vae

    reply = requests.post(
        self.full_url("options"),
        json=options,
        verify=self.verify_remotes
    )

    # Guard clause: on failure leave the loaded_* bookkeeping untouched.
    if reply.status_code != 200:
        logger.debug(f"failed to load options for worker '{self.uuid}'")
        return

    self.loaded_model = checkpoint
    if vae is not None:
        self.loaded_vae = vae
|
||||
|
||||
|
|
|
|||
|
|
@ -526,3 +526,8 @@ class World:
|
|||
json.dump(config, config_file, indent=3)
|
||||
logger.debug(f"config saved")
|
||||
|
||||
|
||||
def worker_from_label(self, label: str) -> Worker:
    """Look up a worker by its label.

    Scans ``self._workers`` in order and gives back the first entry whose
    ``uuid`` equals ``label``; evaluates to None when nothing matches.
    """
    matches = (candidate for candidate in self._workers if candidate.uuid == label)
    return next(matches, None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue