diff --git a/scripts/extension.py b/scripts/extension.py index cbcae5c..f860e2b 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -6,9 +6,7 @@ import base64 import io import json import re - import gradio - from modules import scripts, script_callbacks from modules import processing from threading import Thread @@ -23,9 +21,6 @@ from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.Worker import Worker from modules.shared import opts -path_root = scripts.basedir() - - # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? # TODO see if the current api has some sort of UUID generation functionality. diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 2979ab3..948e3ba 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -1,3 +1,4 @@ +import gradio import requests from typing import List import math @@ -248,65 +249,68 @@ class Worker: """ eta = None - # TODO handle no connection exception and remove worker (for this request) in that case # TODO detect remote out of memory exception and restart or garbage collect instance using api? - # query memory available on worker and store for future reference - if self.queried is False: - self.queried = True - memory_response = requests.get( - self.full_url("memory"), - verify=self.verify_remotes - ) - memory_response = memory_response.json()['cuda']['system'] # all in bytes - - free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) - total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) - print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") - self.free_vram = bytes(memory_response['free']) - - if sync_options is True: - options_response = requests.post( - self.full_url("options"), - json=option_payload, - verify=self.verify_remotes - ) - self.response = options_response - # TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a - # second GET to see if everything loaded... - - if self.benchmarked: - eta = self.batch_eta(payload=payload) - print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image(" - f"s) at a speed of {self.avg_ipm} ipm\n") - try: - start = time.time() - response = requests.post( - self.full_url("txt2img"), - json=payload, - verify=self.verify_remotes - ) - self.response = response.json() + # query memory available on worker and store for future reference + if self.queried is False: + self.queried = True + memory_response = requests.get( + self.full_url("memory"), + verify=self.verify_remotes + ) + memory_response = memory_response.json()['cuda']['system'] # all in bytes - # update list of ETA accuracy - if self.benchmarked and not self.interrupted: - self.response_time = time.time() - start - variance = ((eta - self.response_time) / self.response_time) * 100 + free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) + total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) + print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") + self.free_vram = bytes(memory_response['free']) - if cmd_opts.distributed_debug: - print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") - print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") + if sync_options is True: + options_response = requests.post( + self.full_url("options"), + json=option_payload, + verify=self.verify_remotes + ) + self.response = options_response + # TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a + # second GET to see if everything loaded... - if self.eta_percent_error == 0: - self.eta_percent_error[0] = variance + if self.benchmarked: + eta = self.batch_eta(payload=payload) + print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image(" + f"s) at a speed of {self.avg_ipm} ipm\n") + + try: + start = time.time() + response = requests.post( + self.full_url("txt2img"), + json=payload, + verify=self.verify_remotes + ) + self.response = response.json() + + # update list of ETA accuracy + if self.benchmarked and not self.interrupted: + self.response_time = time.time() - start + variance = ((eta - self.response_time) / self.response_time) * 100 + + if cmd_opts.distributed_debug: + print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") + print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") + + if self.eta_percent_error == 0: + self.eta_percent_error[0] = variance + else: + self.eta_percent_error.append(variance) + + except Exception as e: + if payload['batch_size'] == 0: + raise InvalidWorkerResponse("Tried to request a null amount of images") else: - self.eta_percent_error.append(variance) + raise InvalidWorkerResponse(e) - except Exception as e: - if payload['batch_size'] == 0: - raise InvalidWorkerResponse("Tried to request a null amount of images") - else: - raise InvalidWorkerResponse(e) + except requests.exceptions.ConnectTimeout: + print(f"\nTimed out waiting for worker '{self.uuid}' at {self}") return