Fix memory leak. Override the benchmark payload with what is written in workers.json. Add an extra warning when running without any worker info passed.
parent
5fc3e4c2e8
commit
f962b8dece
|
|
@ -14,11 +14,11 @@ from typing import List
|
|||
import urllib3
|
||||
import copy
|
||||
from modules.images import save_image
|
||||
from modules.shared import cmd_opts
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.shared import state as webui_state
|
||||
import time
|
||||
from scripts.spartan.World import World, WorldAlreadyInitialized
|
||||
from scripts.spartan.UI import UI
|
||||
from modules.shared import opts
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
from modules.processing import fix_seed, Processed
|
||||
|
|
@ -47,8 +47,12 @@ class Script(scripts.Script):
|
|||
# build world
|
||||
world = World(initial_payload=None, verify_remotes=verify_remotes)
|
||||
# add workers to the world
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
# make sure arguments aren't missing
|
||||
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
else:
|
||||
logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?")
|
||||
|
||||
def title(self):
|
||||
return "Distribute"
|
||||
|
|
@ -64,6 +68,7 @@ class Script(scripts.Script):
|
|||
@staticmethod
|
||||
def add_to_gallery(processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None):
|
||||
image_params: json = response["parameters"]
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ import os
|
|||
import subprocess
|
||||
from pathlib import Path
|
||||
import gradio
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.shared import logger, log_level
|
||||
from scripts.spartan.Worker import State
|
||||
from modules.shared import state as webui_state
|
||||
|
||||
|
||||
class UI:
|
||||
|
|
@ -46,6 +47,11 @@ class UI:
|
|||
def refresh_ckpts_btn(self):
|
||||
self.world.refresh_checkpoints()
|
||||
|
||||
def clear_queue_btn(self):
|
||||
logger.debug(webui_state.__dict__)
|
||||
webui_state.end()
|
||||
|
||||
|
||||
def status_btn(self):
|
||||
worker_status = ''
|
||||
workers = self.world.get_workers()
|
||||
|
|
@ -104,6 +110,11 @@ class UI:
|
|||
redo_benchmarks_btn.style(full_width=False)
|
||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||
|
||||
if log_level == 'DEBUG':
|
||||
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
||||
clear_queue_btn.style(full_width=False)
|
||||
clear_queue_btn.click(self.clear_queue_btn)
|
||||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class Worker:
|
|||
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
|
||||
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
|
||||
logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
|
||||
self.free_vram = bytes(memory_response['free'])
|
||||
self.free_vram = memory_response['free']
|
||||
except KeyError:
|
||||
error = memory_response['cuda']['error']
|
||||
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class World:
|
|||
"""
|
||||
Attempts to benchmark all workers a part of the world.
|
||||
"""
|
||||
global benchmark_payload
|
||||
from scripts.spartan.shared import benchmark_payload
|
||||
|
||||
workers_info: dict = {}
|
||||
saved: bool = os.path.exists(self.worker_info_path)
|
||||
|
|
@ -178,11 +178,18 @@ class World:
|
|||
benchmark_threads = []
|
||||
|
||||
def benchmark_wrapped(worker):
|
||||
logger.critical(f"benchmark payload is: {benchmark_payload}")
|
||||
bench_func = worker.benchmark if not worker.master else self.benchmark_master
|
||||
worker.avg_ipm = bench_func()
|
||||
worker.benchmarked = True
|
||||
|
||||
if rebenchmark:
|
||||
if saved:
|
||||
with open(self.worker_info_path, 'r') as worker_info_file:
|
||||
workers_info = json.load(worker_info_file)
|
||||
benchmark_payload = workers_info['benchmark_payload']
|
||||
logger.info(f"Using saved benchmark config:\n{benchmark_payload}")
|
||||
|
||||
saved = False
|
||||
workers = self.get_workers()
|
||||
|
||||
|
|
@ -194,6 +201,8 @@ class World:
|
|||
with open(self.worker_info_path, 'r') as worker_info_file:
|
||||
try:
|
||||
workers_info = json.load(worker_info_file)
|
||||
benchmark_payload = workers_info['benchmark_payload']
|
||||
logger.info(f"Using saved benchmark config:\n{benchmark_payload}")
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"workers.json is not valid JSON, regenerating")
|
||||
rebenchmark = True
|
||||
|
|
@ -237,7 +246,6 @@ class World:
|
|||
|
||||
logger.info(self.speed_summary())
|
||||
|
||||
|
||||
def get_current_output_size(self) -> int:
|
||||
"""
|
||||
returns how many images would be returned from all jobs
|
||||
|
|
|
|||
Loading…
Reference in New Issue