From e89fbf120ffef36f9d258994dddb572fc3ff730e Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 26 Jul 2023 21:10:30 -0500 Subject: [PATCH] use magic --- scripts/spartan/UI.py | 12 +++++------- scripts/spartan/Worker.py | 29 +++++++++++++++++++---------- scripts/spartan/World.py | 10 +++++----- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index 55f4767..cfc8b25 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -110,13 +110,12 @@ class UI: return gradio.Dropdown.update(choices=labels) def populate_worker_config_from_selection(self, selection): - selected_worker: Worker = None - for worker in self.world.get_workers(): - if worker.uuid == selection: - selected_worker = worker + avail_models = None + selected_worker = self.world[selection] avail_models = selected_worker.available_models() - avail_models.append('None') # for disabling override + if avail_models is not None: + avail_models.append('None') # for disabling override return [ gradio.Textbox.update(value=selected_worker.address), @@ -125,7 +124,6 @@ class UI: gradio.Dropdown.update(choices=avail_models) ] - def reconnect_remotes(self): for worker in self.world._workers: if worker.master: @@ -141,7 +139,7 @@ class UI: logger.info(f"worker '{worker.uuid}' is still unreachable") def override_worker_model(self, model, worker_label): - worker = self.world.worker_from_label(worker_label) + worker = self.world[worker_label] if model == "None": worker.model_override = None diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 0dccd84..7d83402 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -2,7 +2,7 @@ import io import gradio import requests -from typing import List +from typing import List, Union import math import copy import time @@ -506,7 +506,8 @@ class Worker: try: response = requests.get( self.full_url("memory"), - verify=self.verify_remotes + verify=self.verify_remotes, + timeout=3 ) if response.status_code == 200: return True @@ -517,17 +518,25 @@ class Worker: return False def mark_unreachable(self): - logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future") + logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in the future") self.state = State.UNAVAILABLE - def available_models(self) -> List[str]: - response = requests.get( - url=self.full_url('sd-models'), - verify=self.verify_remotes - ) + def available_models(self) -> Union[List[str], None]: + if self.state == State.UNAVAILABLE: + return None - titles = [model['title'] for model in response.json()] - return titles + try: + response = requests.get( + url=self.full_url('sd-models'), + verify=self.verify_remotes, + timeout=5 + ) + + titles = [model['title'] for model in response.json()] + return titles + except requests.RequestException: + self.mark_unreachable() + return None def load_options(self, model, vae=None): model_name = re.sub(r'\s?\[[^\]]*\]$', '', model) diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 2a17a41..71bf939 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -83,6 +83,11 @@ class World: self.initial_payload = copy.copy(initial_payload) self.thin_client_mode = False + def __getitem__(self, label: str) -> Worker: + for worker in self._workers: + if worker.uuid == label: + return worker + def update_world(self, total_batch_size): """ Updates the world with information vital to handling the local generation request after @@ -526,8 +531,3 @@ class World: json.dump(config, config_file, indent=3) logger.debug(f"config saved") - - def worker_from_label(self, label: str) -> Worker: - for worker in self._workers: - if worker.uuid == label: - return worker