feature: per-worker model override
parent
9d48e2732d
commit
bd5e6ac7b7
|
|
@ -1,6 +1,7 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import gradio
|
import gradio
|
||||||
from scripts.spartan.shared import logger, log_level
|
from scripts.spartan.shared import logger, log_level
|
||||||
|
|
@ -8,6 +9,7 @@ from scripts.spartan.Worker import State, Worker
|
||||||
from modules.shared import state as webui_state
|
from modules.shared import state as webui_state
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
worker_select_dropdown = None
|
worker_select_dropdown = None
|
||||||
|
|
||||||
|
|
@ -113,7 +115,15 @@ class UI:
|
||||||
if worker.uuid == selection:
|
if worker.uuid == selection:
|
||||||
selected_worker = worker
|
selected_worker = worker
|
||||||
|
|
||||||
return [gradio.Textbox.update(value=selected_worker.address), gradio.Textbox.update(value=selected_worker.port), gradio.Checkbox.update(value=selected_worker.tls)]
|
avail_models = selected_worker.available_models()
|
||||||
|
avail_models.append('None') # for disabling override
|
||||||
|
|
||||||
|
return [
|
||||||
|
gradio.Textbox.update(value=selected_worker.address),
|
||||||
|
gradio.Textbox.update(value=selected_worker.port),
|
||||||
|
gradio.Checkbox.update(value=selected_worker.tls),
|
||||||
|
gradio.Dropdown.update(choices=avail_models)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def reconnect_remotes(self):
|
def reconnect_remotes(self):
|
||||||
|
|
@ -127,6 +137,17 @@ class UI:
|
||||||
else:
|
else:
|
||||||
logger.info(f"worker '{worker.uuid}' is still unreachable")
|
logger.info(f"worker '{worker.uuid}' is still unreachable")
|
||||||
|
|
||||||
|
def override_worker_model(self, model, worker_label):
|
||||||
|
worker = self.world.worker_from_label(worker_label)
|
||||||
|
|
||||||
|
if model == "None":
|
||||||
|
worker.model_override = None
|
||||||
|
else:
|
||||||
|
worker.model_override = model
|
||||||
|
|
||||||
|
# set model on remote early
|
||||||
|
Thread(target=worker.load_options, args=(model,)).start()
|
||||||
|
|
||||||
|
|
||||||
# end handlers
|
# end handlers
|
||||||
|
|
||||||
|
|
@ -176,8 +197,6 @@ class UI:
|
||||||
clear_queue_btn.click(self.clear_queue_btn)
|
clear_queue_btn.click(self.clear_queue_btn)
|
||||||
|
|
||||||
with gradio.Tab('Worker Config'):
|
with gradio.Tab('Worker Config'):
|
||||||
worker_select_dropdown = None
|
|
||||||
|
|
||||||
worker_select_dropdown = gradio.Dropdown(
|
worker_select_dropdown = gradio.Dropdown(
|
||||||
[x.uuid for x in self.selectable_remote_workers()],
|
[x.uuid for x in self.selectable_remote_workers()],
|
||||||
info='Select a pre-existing worker or enter a label for a new one',
|
info='Select a pre-existing worker or enter a label for a new one',
|
||||||
|
|
@ -189,11 +208,10 @@ class UI:
|
||||||
worker_tls_cbx = gradio.Checkbox(
|
worker_tls_cbx = gradio.Checkbox(
|
||||||
label='connect using https'
|
label='connect using https'
|
||||||
)
|
)
|
||||||
worker_select_dropdown.select(
|
|
||||||
self.populate_worker_config_from_selection,
|
with gradio.Accordion(label='Advanced'):
|
||||||
inputs=worker_select_dropdown,
|
model_override_dropdown = gradio.Dropdown(label='Model override')
|
||||||
outputs=[worker_address_field, worker_port_field, worker_tls_cbx]
|
model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown])
|
||||||
)
|
|
||||||
|
|
||||||
with gradio.Row():
|
with gradio.Row():
|
||||||
save_worker_btn = gradio.Button(value='Add/Update Worker')
|
save_worker_btn = gradio.Button(value='Add/Update Worker')
|
||||||
|
|
@ -201,6 +219,12 @@ class UI:
|
||||||
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
|
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
|
||||||
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
|
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
|
||||||
|
|
||||||
|
worker_select_dropdown.select(
|
||||||
|
self.populate_worker_config_from_selection,
|
||||||
|
inputs=worker_select_dropdown,
|
||||||
|
outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown]
|
||||||
|
)
|
||||||
|
|
||||||
with gradio.Tab('Settings'):
|
with gradio.Tab('Settings'):
|
||||||
thin_client_cbx = gradio.Checkbox(
|
thin_client_cbx = gradio.Checkbox(
|
||||||
label='Thin-client mode (experimental)',
|
label='Thin-client mode (experimental)',
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import queue
|
||||||
from modules.shared import state as master_state
|
from modules.shared import state as master_state
|
||||||
from modules.api.api import encode_pil_to_base64
|
from modules.api.api import encode_pil_to_base64
|
||||||
import scripts.spartan.shared as sh
|
import scripts.spartan.shared as sh
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
class InvalidWorkerResponse(Exception):
|
class InvalidWorkerResponse(Exception):
|
||||||
|
|
@ -114,6 +115,7 @@ class Worker:
|
||||||
self.loaded_vae = ''
|
self.loaded_vae = ''
|
||||||
self.state = State.IDLE
|
self.state = State.IDLE
|
||||||
self.tls = tls
|
self.tls = tls
|
||||||
|
self.model_override: str = None
|
||||||
|
|
||||||
if uuid is not None:
|
if uuid is not None:
|
||||||
self.uuid = uuid
|
self.uuid = uuid
|
||||||
|
|
@ -278,11 +280,11 @@ class Worker:
|
||||||
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
|
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
|
||||||
|
|
||||||
if sync_options is True:
|
if sync_options is True:
|
||||||
options_response = requests.post(
|
model = option_payload['sd_model_checkpoint']
|
||||||
self.full_url("options"),
|
if self.model_override is not None:
|
||||||
json=option_payload,
|
model = self.model_override
|
||||||
verify=self.verify_remotes
|
|
||||||
)
|
self.load_options(model=model, vae=option_payload['sd_vae'])
|
||||||
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
|
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
|
||||||
# second GET to see if everything loaded...
|
# second GET to see if everything loaded...
|
||||||
|
|
||||||
|
|
@ -517,3 +519,34 @@ class Worker:
|
||||||
def mark_unreachable(self):
|
def mark_unreachable(self):
|
||||||
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
|
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
|
||||||
self.state = State.UNAVAILABLE
|
self.state = State.UNAVAILABLE
|
||||||
|
|
||||||
|
def available_models(self) -> List[str]:
|
||||||
|
response = requests.get(
|
||||||
|
url=self.full_url('sd-models'),
|
||||||
|
verify=self.verify_remotes
|
||||||
|
)
|
||||||
|
|
||||||
|
titles = [model['title'] for model in response.json()]
|
||||||
|
return titles
|
||||||
|
|
||||||
|
def load_options(self, model, vae=None):
|
||||||
|
model_name = re.sub(r'\s?\[[^\]]*\]$', '', model)
|
||||||
|
payload = {
|
||||||
|
"sd_model_checkpoint": model_name
|
||||||
|
}
|
||||||
|
if vae is not None:
|
||||||
|
payload['sd_vae'] = vae
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.full_url("options"),
|
||||||
|
json=payload,
|
||||||
|
verify=self.verify_remotes
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.debug(f"failed to load options for worker '{self.uuid}'")
|
||||||
|
else:
|
||||||
|
self.loaded_model = model_name
|
||||||
|
if vae is not None:
|
||||||
|
self.loaded_vae = vae
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -526,3 +526,8 @@ class World:
|
||||||
json.dump(config, config_file, indent=3)
|
json.dump(config, config_file, indent=3)
|
||||||
logger.debug(f"config saved")
|
logger.debug(f"config saved")
|
||||||
|
|
||||||
|
|
||||||
|
def worker_from_label(self, label: str) -> Worker:
|
||||||
|
for worker in self._workers:
|
||||||
|
if worker.uuid == label:
|
||||||
|
return worker
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue