diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index 870cf1b..3c53ecb 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -1,6 +1,7 @@ import io import os import subprocess +import threading from pathlib import Path import gradio from scripts.spartan.shared import logger, log_level @@ -8,6 +9,7 @@ from scripts.spartan.Worker import State, Worker from modules.shared import state as webui_state import json from typing import List +from threading import Thread worker_select_dropdown = None @@ -113,7 +115,15 @@ class UI: if worker.uuid == selection: selected_worker = worker - return [gradio.Textbox.update(value=selected_worker.address), gradio.Textbox.update(value=selected_worker.port), gradio.Checkbox.update(value=selected_worker.tls)] + avail_models = selected_worker.available_models() + avail_models.append('None') # for disabling override + + return [ + gradio.Textbox.update(value=selected_worker.address), + gradio.Textbox.update(value=selected_worker.port), + gradio.Checkbox.update(value=selected_worker.tls), + gradio.Dropdown.update(choices=avail_models) + ] def reconnect_remotes(self): @@ -127,6 +137,17 @@ class UI: else: logger.info(f"worker '{worker.uuid}' is still unreachable") + def override_worker_model(self, model, worker_label): + worker = self.world.worker_from_label(worker_label) + + if model == "None": + worker.model_override = None + else: + worker.model_override = model + + # set model on remote early + Thread(target=worker.load_options, args=(model,)).start() + # end handlers @@ -176,8 +197,6 @@ class UI: clear_queue_btn.click(self.clear_queue_btn) with gradio.Tab('Worker Config'): - worker_select_dropdown = None - worker_select_dropdown = gradio.Dropdown( [x.uuid for x in self.selectable_remote_workers()], info='Select a pre-existing worker or enter a label for a new one', @@ -189,11 +208,10 @@ class UI: worker_tls_cbx = gradio.Checkbox( label='connect using https' ) - worker_select_dropdown.select( - self.populate_worker_config_from_selection, - inputs=worker_select_dropdown, - outputs=[worker_address_field, worker_port_field, worker_tls_cbx] - ) + + with gradio.Accordion(label='Advanced'): + model_override_dropdown = gradio.Dropdown(label='Model override') + model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown]) with gradio.Row(): save_worker_btn = gradio.Button(value='Add/Update Worker') @@ -201,6 +219,12 @@ class UI: remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop') remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown]) + worker_select_dropdown.select( + self.populate_worker_config_from_selection, + inputs=worker_select_dropdown, + outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown] + ) + with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( label='Thin-client mode (experimental)', diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 253a98b..0dccd84 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -18,6 +18,7 @@ import queue from modules.shared import state as master_state from modules.api.api import encode_pil_to_base64 import scripts.spartan.shared as sh +import re class InvalidWorkerResponse(Exception): @@ -114,6 +115,7 @@ class Worker: self.loaded_vae = '' self.state = State.IDLE self.tls = tls + self.model_override: str = None if uuid is not None: self.uuid = uuid @@ -278,11 +280,11 @@ class Worker: logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}") if sync_options is True: - options_response = requests.post( - self.full_url("options"), - json=option_payload, - verify=self.verify_remotes - ) + model = option_payload['sd_model_checkpoint'] + if self.model_override is not None: + model = self.model_override + + self.load_options(model=model, vae=option_payload['sd_vae']) # TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a # second GET to see if everything loaded... @@ -517,3 +519,34 @@ class Worker: def mark_unreachable(self): logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future") self.state = State.UNAVAILABLE + + def available_models(self) -> List[str]: + response = requests.get( + url=self.full_url('sd-models'), + verify=self.verify_remotes + ) + + titles = [model['title'] for model in response.json()] + return titles + + def load_options(self, model, vae=None): + model_name = re.sub(r'\s?\[[^\]]*\]$', '', model) + payload = { + "sd_model_checkpoint": model_name + } + if vae is not None: + payload['sd_vae'] = vae + + response = requests.post( + self.full_url("options"), + json=payload, + verify=self.verify_remotes + ) + + if response.status_code != 200: + logger.debug(f"failed to load options for worker '{self.uuid}'") + else: + self.loaded_model = model_name + if vae is not None: + self.loaded_vae = vae + diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 8299131..2a17a41 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -526,3 +526,8 @@ class World: json.dump(config, config_file, indent=3) logger.debug(f"config saved") + + def worker_from_label(self, label: str) -> Worker: + for worker in self._workers: + if worker.uuid == label: + return worker