use magic

pull/17/head
unknown 2023-07-26 21:10:30 -05:00
parent 352183bbd6
commit e89fbf120f
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 29 additions and 22 deletions

View File

@ -110,13 +110,12 @@ class UI:
return gradio.Dropdown.update(choices=labels)
def populate_worker_config_from_selection(self, selection):
selected_worker: Worker = None
for worker in self.world.get_workers():
if worker.uuid == selection:
selected_worker = worker
avail_models = None
selected_worker = self.world[selection]
avail_models = selected_worker.available_models()
avail_models.append('None') # for disabling override
if avail_models is not None:
avail_models.append('None') # for disabling override
return [
gradio.Textbox.update(value=selected_worker.address),
@ -125,7 +124,6 @@ class UI:
gradio.Dropdown.update(choices=avail_models)
]
def reconnect_remotes(self):
for worker in self.world._workers:
if worker.master:
@ -141,7 +139,7 @@ class UI:
logger.info(f"worker '{worker.uuid}' is still unreachable")
def override_worker_model(self, model, worker_label):
worker = self.world.worker_from_label(worker_label)
worker = self.world[worker_label]
if model == "None":
worker.model_override = None

View File

@ -2,7 +2,7 @@ import io
import gradio
import requests
from typing import List
from typing import List, Union
import math
import copy
import time
@ -506,7 +506,8 @@ class Worker:
try:
response = requests.get(
self.full_url("memory"),
verify=self.verify_remotes
verify=self.verify_remotes,
timeout=3
)
if response.status_code == 200:
return True
@ -517,17 +518,25 @@ class Worker:
return False
def mark_unreachable(self):
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in future")
logger.error(f"Worker '{self.uuid}' at {self} was unreachable, will avoid in the future")
self.state = State.UNAVAILABLE
def available_models(self) -> List[str]:
response = requests.get(
url=self.full_url('sd-models'),
verify=self.verify_remotes
)
def available_models(self) -> Union[List[str], None]:
if self.state == State.UNAVAILABLE:
return None
titles = [model['title'] for model in response.json()]
return titles
try:
response = requests.get(
url=self.full_url('sd-models'),
verify=self.verify_remotes,
timeout=5
)
titles = [model['title'] for model in response.json()]
return titles
except requests.RequestException:
self.mark_unreachable()
return None
def load_options(self, model, vae=None):
model_name = re.sub(r'\s?\[[^\]]*\]$', '', model)

View File

@ -83,6 +83,11 @@ class World:
self.initial_payload = copy.copy(initial_payload)
self.thin_client_mode = False
def __getitem__(self, label: str) -> Worker:
for worker in self._workers:
if worker.uuid == label:
return worker
def update_world(self, total_batch_size):
"""
Updates the world with information vital to handling the local generation request after
@ -526,8 +531,3 @@ class World:
json.dump(config, config_file, indent=3)
logger.debug(f"config saved")
def worker_from_label(self, label: str) -> Worker:
for worker in self._workers:
if worker.uuid == label:
return worker