inform in the case of worker timeout

pull/2/head
unknown 2023-03-24 13:25:02 -05:00
parent 46929e949a
commit 162b541d4e
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 56 additions and 57 deletions

View File

@ -6,9 +6,7 @@ import base64
import io
import json
import re
import gradio
from modules import scripts, script_callbacks
from modules import processing
from threading import Thread
@ -23,9 +21,6 @@ from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker
from modules.shared import opts
path_root = scripts.basedir()
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# TODO see if the current api has some sort of UUID generation functionality.

View File

@ -1,3 +1,4 @@
import gradio
import requests
from typing import List
import math
@ -248,65 +249,68 @@ class Worker:
"""
eta = None
# TODO handle no connection exception and remove worker (for this request) in that case
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
# query memory available on worker and store for future reference
if self.queried is False:
self.queried = True
memory_response = requests.get(
self.full_url("memory"),
verify=self.verify_remotes
)
memory_response = memory_response.json()['cuda']['system'] # all in bytes
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = bytes(memory_response['free'])
if sync_options is True:
options_response = requests.post(
self.full_url("options"),
json=option_payload,
verify=self.verify_remotes
)
self.response = options_response
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
# second GET to see if everything loaded...
if self.benchmarked:
eta = self.batch_eta(payload=payload)
print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
f"s) at a speed of {self.avg_ipm} ipm\n")
try:
start = time.time()
response = requests.post(
self.full_url("txt2img"),
json=payload,
verify=self.verify_remotes
)
self.response = response.json()
# query memory available on worker and store for future reference
if self.queried is False:
self.queried = True
memory_response = requests.get(
self.full_url("memory"),
verify=self.verify_remotes
)
memory_response = memory_response.json()['cuda']['system'] # all in bytes
# update list of ETA accuracy
if self.benchmarked and not self.interrupted:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = bytes(memory_response['free'])
if cmd_opts.distributed_debug:
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
if sync_options is True:
options_response = requests.post(
self.full_url("options"),
json=option_payload,
verify=self.verify_remotes
)
self.response = options_response
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
# second GET to see if everything loaded...
if self.eta_percent_error == 0:
self.eta_percent_error[0] = variance
if self.benchmarked:
eta = self.batch_eta(payload=payload)
print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
f"s) at a speed of {self.avg_ipm} ipm\n")
try:
start = time.time()
response = requests.post(
self.full_url("txt2img"),
json=payload,
verify=self.verify_remotes
)
self.response = response.json()
# update list of ETA accuracy
if self.benchmarked and not self.interrupted:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
if cmd_opts.distributed_debug:
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
if self.eta_percent_error == 0:
self.eta_percent_error[0] = variance
else:
self.eta_percent_error.append(variance)
except Exception as e:
if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images")
else:
self.eta_percent_error.append(variance)
raise InvalidWorkerResponse(e)
except Exception as e:
if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images")
else:
raise InvalidWorkerResponse(e)
except requests.exceptions.ConnectTimeout:
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
return