feature: per-worker model override

pull/17/head
unknown 2023-07-15 05:45:55 -05:00
parent 9d48e2732d
commit bd5e6ac7b7
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 75 additions and 13 deletions

View File

@ -1,6 +1,7 @@
import io
import os
import subprocess
import threading
from pathlib import Path
import gradio
from scripts.spartan.shared import logger, log_level
@ -8,6 +9,7 @@ from scripts.spartan.Worker import State, Worker
from modules.shared import state as webui_state
import json
from typing import List
from threading import Thread
worker_select_dropdown = None
@ -113,7 +115,15 @@ class UI:
if worker.uuid == selection:
selected_worker = worker
return [gradio.Textbox.update(value=selected_worker.address), gradio.Textbox.update(value=selected_worker.port), gradio.Checkbox.update(value=selected_worker.tls)]
avail_models = selected_worker.available_models()
avail_models.append('None') # for disabling override
return [
gradio.Textbox.update(value=selected_worker.address),
gradio.Textbox.update(value=selected_worker.port),
gradio.Checkbox.update(value=selected_worker.tls),
gradio.Dropdown.update(choices=avail_models)
]
def reconnect_remotes(self):
@ -127,6 +137,17 @@ class UI:
else:
logger.info(f"worker '{worker.uuid}' is still unreachable")
def override_worker_model(self, model, worker_label):
worker = self.world.worker_from_label(worker_label)
if model == "None":
worker.model_override = None
else:
worker.model_override = model
# set model on remote early
Thread(target=worker.load_options, args=(model,)).start()
# end handlers
@ -176,8 +197,6 @@ class UI:
clear_queue_btn.click(self.clear_queue_btn)
with gradio.Tab('Worker Config'):
worker_select_dropdown = None
worker_select_dropdown = gradio.Dropdown(
[x.uuid for x in self.selectable_remote_workers()],
info='Select a pre-existing worker or enter a label for a new one',
@ -189,11 +208,10 @@ class UI:
worker_tls_cbx = gradio.Checkbox(
label='connect using https'
)
worker_select_dropdown.select(
self.populate_worker_config_from_selection,
inputs=worker_select_dropdown,
outputs=[worker_address_field, worker_port_field, worker_tls_cbx]
)
with gradio.Accordion(label='Advanced'):
model_override_dropdown = gradio.Dropdown(label='Model override')
model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown])
with gradio.Row():
save_worker_btn = gradio.Button(value='Add/Update Worker')
@ -201,6 +219,12 @@ class UI:
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
worker_select_dropdown.select(
self.populate_worker_config_from_selection,
inputs=worker_select_dropdown,
outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown]
)
with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)',

View File

@ -18,6 +18,7 @@ import queue
from modules.shared import state as master_state
from modules.api.api import encode_pil_to_base64
import scripts.spartan.shared as sh
import re
class InvalidWorkerResponse(Exception):
@ -114,6 +115,7 @@ class Worker:
self.loaded_vae = ''
self.state = State.IDLE
self.tls = tls
self.model_override: str = None
if uuid is not None:
self.uuid = uuid
@ -278,11 +280,11 @@ class Worker:
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
if sync_options is True:
options_response = requests.post(
self.full_url("options"),
json=option_payload,
verify=self.verify_remotes
)
model = option_payload['sd_model_checkpoint']
if self.model_override is not None:
model = self.model_override
self.load_options(model=model, vae=option_payload['sd_vae'])
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
# second GET to see if everything loaded...
@ -517,3 +519,34 @@ class Worker:
def mark_unreachable(self):
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
self.state = State.UNAVAILABLE
def available_models(self) -> List[str]:
response = requests.get(
url=self.full_url('sd-models'),
verify=self.verify_remotes
)
titles = [model['title'] for model in response.json()]
return titles
def load_options(self, model, vae=None):
model_name = re.sub(r'\s?\[[^\]]*\]$', '', model)
payload = {
"sd_model_checkpoint": model_name
}
if vae is not None:
payload['sd_vae'] = vae
response = requests.post(
self.full_url("options"),
json=payload,
verify=self.verify_remotes
)
if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.uuid}'")
else:
self.loaded_model = model_name
if vae is not None:
self.loaded_vae = vae

View File

@ -526,3 +526,8 @@ class World:
json.dump(config, config_file, indent=3)
logger.debug(f"config saved")
def worker_from_label(self, label: str) -> Worker:
for worker in self._workers:
if worker.uuid == label:
return worker