253 lines
10 KiB
Python
253 lines
10 KiB
Python
import io
|
|
import os
|
|
import subprocess
|
|
import threading
|
|
from pathlib import Path
|
|
import gradio
|
|
from scripts.spartan.shared import logger, log_level
|
|
from scripts.spartan.Worker import State, Worker
|
|
from modules.shared import state as webui_state
|
|
import json
|
|
from typing import List
|
|
from threading import Thread
|
|
|
|
worker_select_dropdown = None
|
|
|
|
class UI:
|
|
|
|
def __init__(self, script, world):
|
|
self.script = script
|
|
self.world = world
|
|
|
|
# handlers
|
|
@staticmethod
|
|
def user_script_btn():
|
|
user_scripts = Path(os.path.abspath(__file__)).parent.parent.joinpath('user')
|
|
|
|
for file in user_scripts.iterdir():
|
|
logger.debug(f"found possible script {file.name}")
|
|
if file.is_file() and file.name.startswith('sync'):
|
|
user_script = file
|
|
|
|
suffix = user_script.suffix[1:]
|
|
|
|
if suffix == 'ps1':
|
|
subprocess.call(['powershell', user_script])
|
|
return True
|
|
else:
|
|
f = open(user_script, 'r')
|
|
first_line = f.readline().strip()
|
|
if first_line.startswith('#!'):
|
|
shebang = first_line[2:]
|
|
subprocess.call([shebang, user_script])
|
|
return True
|
|
|
|
return False
|
|
|
|
def benchmark_btn(self):
|
|
logger.info("Redoing benchmarks...")
|
|
self.world.benchmark(rebenchmark=True)
|
|
|
|
def clear_queue_btn(self):
|
|
logger.debug(webui_state.__dict__)
|
|
webui_state.end()
|
|
|
|
|
|
def status_btn(self):
|
|
worker_status = ''
|
|
workers = self.world.get_workers()
|
|
|
|
for worker in workers:
|
|
if worker.master:
|
|
continue
|
|
|
|
worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
|
|
|
# TODO replace this with a single check to a state flag that we should make in the world class
|
|
for worker in workers:
|
|
if worker.state == State.WORKING:
|
|
return self.world.__str__(), worker_status
|
|
|
|
return 'No active jobs!', worker_status
|
|
|
|
def save_btn(self, thin_client_mode, job_timeout):
|
|
self.world.thin_client_mode = thin_client_mode
|
|
logger.debug(f"thin client mode is now {thin_client_mode}")
|
|
job_timeout = int(job_timeout)
|
|
self.world.job_timeout = job_timeout
|
|
logger.debug(f"job timeout is now {job_timeout} seconds")
|
|
|
|
def save_worker_btn(self, name, address, port, tls):
|
|
self.world.add_worker(name, address, port, tls)
|
|
self.world.save_config()
|
|
|
|
# visibly update which workers can be selected
|
|
labels = [x.uuid for x in self.selectable_remote_workers()]
|
|
return gradio.Dropdown.update(choices=labels)
|
|
|
|
def selectable_remote_workers(self) -> List[Worker]:
|
|
remote_workers = []
|
|
|
|
for worker in self.world.get_workers():
|
|
if worker.master:
|
|
continue
|
|
remote_workers.append(worker)
|
|
remote_workers = sorted(remote_workers, key=lambda x: x.uuid)
|
|
|
|
return remote_workers
|
|
|
|
def remove_worker_btn(self, worker_label):
|
|
# remove worker from memory
|
|
for worker in self.world._workers:
|
|
if worker.uuid == worker_label:
|
|
self.world._workers.remove(worker)
|
|
|
|
# remove worker from disk
|
|
self.world.save_config()
|
|
|
|
# visibly update which workers can be selected
|
|
labels = [x.uuid for x in self.selectable_remote_workers()]
|
|
return gradio.Dropdown.update(choices=labels)
|
|
|
|
def populate_worker_config_from_selection(self, selection):
|
|
avail_models = None
|
|
selected_worker = self.world[selection]
|
|
|
|
avail_models = selected_worker.available_models()
|
|
if avail_models is not None:
|
|
avail_models.append('None') # for disabling override
|
|
|
|
return [
|
|
gradio.Textbox.update(value=selected_worker.address),
|
|
gradio.Textbox.update(value=selected_worker.port),
|
|
gradio.Checkbox.update(value=selected_worker.tls),
|
|
gradio.Dropdown.update(choices=avail_models)
|
|
]
|
|
|
|
def reconnect_remotes(self):
|
|
for worker in self.world._workers:
|
|
if worker.master:
|
|
continue
|
|
|
|
logger.debug(f"checking if worker '{worker.uuid}' is now reachable...")
|
|
reachable = worker.reachable()
|
|
if worker.state == State.UNAVAILABLE:
|
|
if reachable:
|
|
logger.info(f"worker '{worker.uuid}' is now online, marking as available")
|
|
worker.state = State.IDLE
|
|
else:
|
|
logger.info(f"worker '{worker.uuid}' is still unreachable")
|
|
|
|
def override_worker_model(self, model, worker_label):
|
|
worker = self.world[worker_label]
|
|
|
|
if model == "None":
|
|
worker.model_override = None
|
|
else:
|
|
worker.model_override = model
|
|
|
|
# set model on remote early
|
|
Thread(target=worker.load_options, args=(model,)).start()
|
|
|
|
|
|
# end handlers
|
|
|
|
def create_ui(self):
|
|
with gradio.Box() as root:
|
|
with gradio.Accordion(label='Distributed', open=False):
|
|
with gradio.Tab('Status') as status_tab:
|
|
status = gradio.Textbox(elem_id='status', show_label=False)
|
|
status.placeholder = 'Refresh!'
|
|
jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
|
|
jobs.placeholder = 'Refresh!'
|
|
|
|
refresh_status_btn = gradio.Button(value='Refresh')
|
|
refresh_status_btn.style(size='sm')
|
|
refresh_status_btn.click(self.status_btn, inputs=[], outputs=[jobs, status])
|
|
|
|
status_tab.select(fn=self.status_btn, inputs=[], outputs=[jobs, status])
|
|
|
|
with gradio.Tab('Utils'):
|
|
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
|
refresh_checkpoints_btn.style(full_width=False)
|
|
refresh_checkpoints_btn.click(self.world.refresh_checkpoints)
|
|
|
|
run_usr_btn = gradio.Button(value='Run user script')
|
|
run_usr_btn.style(full_width=False)
|
|
run_usr_btn.click(self.user_script_btn)
|
|
|
|
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
|
|
interrupt_all_btn.style(full_width=False)
|
|
interrupt_all_btn.click(self.world.interrupt_remotes)
|
|
|
|
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks', variant='stop')
|
|
redo_benchmarks_btn.style(full_width=False)
|
|
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
|
|
|
reload_config_btn = gradio.Button(value='Reload config from file')
|
|
reload_config_btn.style(full_width=False)
|
|
reload_config_btn.click(self.world.load_config)
|
|
|
|
reconnect_lost_workers_btn = gradio.Button(value='Attempt reconnection with remotes')
|
|
reconnect_lost_workers_btn.style(full_width=False)
|
|
reconnect_lost_workers_btn.click(self.reconnect_remotes)
|
|
|
|
if log_level == 'DEBUG':
|
|
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
|
clear_queue_btn.style(full_width=False)
|
|
clear_queue_btn.click(self.clear_queue_btn)
|
|
|
|
with gradio.Tab('Worker Config'):
|
|
worker_select_dropdown = gradio.Dropdown(
|
|
[x.uuid for x in self.selectable_remote_workers()],
|
|
info='Select a pre-existing worker or enter a label for a new one',
|
|
label='Label',
|
|
allow_custom_value=True
|
|
)
|
|
worker_address_field = gradio.Textbox(label='Address', placeholder='localhost')
|
|
worker_port_field = gradio.Textbox(label='Port', placeholder='7860')
|
|
worker_tls_cbx = gradio.Checkbox(
|
|
label='connect using https'
|
|
)
|
|
|
|
with gradio.Accordion(label='Advanced'):
|
|
model_override_dropdown = gradio.Dropdown(label='Model override')
|
|
model_override_dropdown.select(self.override_worker_model, inputs=[model_override_dropdown, worker_select_dropdown])
|
|
|
|
with gradio.Row():
|
|
save_worker_btn = gradio.Button(value='Add/Update Worker')
|
|
save_worker_btn.click(self.save_worker_btn, inputs=[worker_select_dropdown, worker_address_field, worker_port_field, worker_tls_cbx], outputs=[worker_select_dropdown])
|
|
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
|
|
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
|
|
|
|
worker_select_dropdown.select(
|
|
self.populate_worker_config_from_selection,
|
|
inputs=worker_select_dropdown,
|
|
outputs=[worker_address_field, worker_port_field, worker_tls_cbx, model_override_dropdown]
|
|
)
|
|
|
|
with gradio.Tab('Settings'):
|
|
thin_client_cbx = gradio.Checkbox(
|
|
label='Thin-client mode (experimental)',
|
|
info="Only generate images using remote workers. There will be no previews when enabled.",
|
|
value=False
|
|
)
|
|
job_timeout = gradio.Number(
|
|
label='Job timeout', value=self.world.job_timeout,
|
|
info="Seconds until a worker is considered too slow to be assigned an"
|
|
" equal share of the total request. Longer than 2 seconds is recommended."
|
|
)
|
|
|
|
save_btn = gradio.Button(value='Update')
|
|
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout])
|
|
|
|
with gradio.Tab('Help'):
|
|
gradio.Markdown(
|
|
"""
|
|
- [Discord Server 🤝](https://discord.gg/Jpc8wnftd4)
|
|
- [Github Repository </>](https://github.com/papuSpartan/stable-diffusion-webui-distributed)
|
|
"""
|
|
)
|
|
|
|
return root, [thin_client_cbx, job_timeout]
|