Fix memory leak; override the benchmark payload with the one recorded in workers.json; add an extra warning when running without any worker info passed.

pull/17/head
unknown 2023-06-28 07:13:58 -05:00
parent 5fc3e4c2e8
commit f962b8dece
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 32 additions and 8 deletions

View File

@ -14,11 +14,11 @@ from typing import List
import urllib3
import copy
from modules.images import save_image
from modules.shared import cmd_opts
from modules.shared import opts, cmd_opts
from modules.shared import state as webui_state
import time
from scripts.spartan.World import World, WorldAlreadyInitialized
from scripts.spartan.UI import UI
from modules.shared import opts
from scripts.spartan.shared import logger
from scripts.spartan.control_net import pack_control_net
from modules.processing import fix_seed, Processed
@ -47,8 +47,12 @@ class Script(scripts.Script):
# build world
world = World(initial_payload=None, verify_remotes=verify_remotes)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
# make sure arguments aren't missing
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
for worker in cmd_opts.distributed_remotes:
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
else:
logger.fatal(f"Found no worker info passed as arguments. Did you populate --distributed-remotes ?")
def title(self):
    """Return the display name of this extension script."""
    script_name = "Distribute"
    return script_name
@ -64,6 +68,7 @@ class Script(scripts.Script):
@staticmethod
def add_to_gallery(processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
webui_state.textinfo = "Distributed - injecting images"
def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None):
image_params: json = response["parameters"]

View File

@ -2,8 +2,9 @@ import os
import subprocess
from pathlib import Path
import gradio
from scripts.spartan.shared import logger
from scripts.spartan.shared import logger, log_level
from scripts.spartan.Worker import State
from modules.shared import state as webui_state
class UI:
@ -46,6 +47,11 @@ class UI:
def refresh_ckpts_btn(self):
    """Button handler: ask the attached world to refresh its checkpoint list."""
    world = self.world
    world.refresh_checkpoints()
def clear_queue_btn(self):
    """Dump the local webui state to the debug log, then end the current queue.

    NOTE(review): only wired up in the UI when log_level == 'DEBUG'; presumably
    webui_state.end() interrupts the local generation queue -- confirm against
    modules.shared.state semantics.
    """
    logger.debug(webui_state.__dict__)
    webui_state.end()
def status_btn(self):
worker_status = ''
workers = self.world.get_workers()
@ -104,6 +110,11 @@ class UI:
redo_benchmarks_btn.style(full_width=False)
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
if log_level == 'DEBUG':
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
clear_queue_btn.style(full_width=False)
clear_queue_btn.click(self.clear_queue_btn)
with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)',

View File

@ -281,7 +281,7 @@ class Worker:
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = bytes(memory_response['free'])
self.free_vram = memory_response['free']
except KeyError:
error = memory_response['cuda']['error']
logger.debug(f"CUDA doesn't seem to be available for worker '{self.uuid}'\nError: {error}")

View File

@ -170,7 +170,7 @@ class World:
"""
Attempts to benchmark all workers that are part of the world.
"""
global benchmark_payload
from scripts.spartan.shared import benchmark_payload
workers_info: dict = {}
saved: bool = os.path.exists(self.worker_info_path)
@ -178,11 +178,18 @@ class World:
benchmark_threads = []
def benchmark_wrapped(worker):
    # Run one worker's benchmark and record the result on the worker itself.
    # Closure: reads benchmark_payload, self, and logger from the enclosing scope.
    # NOTE(review): logging at CRITICAL level here looks like leftover debugging;
    # consider downgrading to debug once the payload override is verified.
    logger.critical(f"benchmark payload is: {benchmark_payload}")
    # Master node benchmarks through self.benchmark_master; remotes through worker.benchmark.
    bench_func = worker.benchmark if not worker.master else self.benchmark_master
    worker.avg_ipm = bench_func()  # average images-per-minute measured by the benchmark
    worker.benchmarked = True
if rebenchmark:
if saved:
with open(self.worker_info_path, 'r') as worker_info_file:
workers_info = json.load(worker_info_file)
benchmark_payload = workers_info['benchmark_payload']
logger.info(f"Using saved benchmark config:\n{benchmark_payload}")
saved = False
workers = self.get_workers()
@ -194,6 +201,8 @@ class World:
with open(self.worker_info_path, 'r') as worker_info_file:
try:
workers_info = json.load(worker_info_file)
benchmark_payload = workers_info['benchmark_payload']
logger.info(f"Using saved benchmark config:\n{benchmark_payload}")
except json.JSONDecodeError:
logger.error(f"workers.json is not valid JSON, regenerating")
rebenchmark = True
@ -237,7 +246,6 @@ class World:
logger.info(self.speed_summary())
def get_current_output_size(self) -> int:
"""
returns how many images would be returned from all jobs