simplify worker.request by extraction
parent
f819f55743
commit
61aa5e5e06
|
|
@ -300,6 +300,90 @@ class Worker:
|
|||
|
||||
return eta
|
||||
|
||||
def matching_scripts(self, payload: dict, mode) -> dict:
|
||||
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
|
||||
if alwayson_scripts is not None:
|
||||
if len(self.supported_scripts) <= 0:
|
||||
payload['alwayson_scripts'] = {}
|
||||
else:
|
||||
matching_scripts = {}
|
||||
missing_scripts = []
|
||||
remote_scripts = self.supported_scripts[mode]
|
||||
for local_script in alwayson_scripts:
|
||||
match = False
|
||||
for remote_script in remote_scripts:
|
||||
if str.lower(local_script) == str.lower(remote_script):
|
||||
matching_scripts[local_script] = alwayson_scripts[local_script]
|
||||
match = True
|
||||
if not match and str.lower(local_script) != 'distribute':
|
||||
missing_scripts.append(local_script)
|
||||
|
||||
if len(missing_scripts) > 0: # warn about node to node script/extension mismatching
|
||||
message = "local script(s): "
|
||||
for script in range(0, len(missing_scripts)):
|
||||
message += f"\[{missing_scripts[script]}]"
|
||||
if script < len(missing_scripts) - 1:
|
||||
message += ', '
|
||||
message += f" seem to be unsupported by worker '{self.label}'\n"
|
||||
if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level
|
||||
logger.debug(message)
|
||||
elif self.jobs_requested < 1:
|
||||
logger.warning(message)
|
||||
|
||||
return matching_scripts
|
||||
|
||||
def query_memory(self):
    """Query the worker's /memory endpoint and return its free CUDA VRAM.

    Only runs once per session: the first call sets ``self.queried`` and
    performs the request; subsequent calls are no-ops.

    Returns:
        The worker's free VRAM in bytes (as reported under
        ``cuda.system.free``), or None when the worker was already queried
        or the response carried no usable CUDA statistics.
    """
    if self.queried is False:
        self.queried = True
        response = self.session.get(self.full_url("memory")).json()
        try:
            # bind the inner dict to its own name so the except-branch below
            # can still inspect the full response (the original rebound
            # `response`, which broke the cuda.error fallback when only
            # 'free'/'total' were missing)
            system = response['cuda']['system']  # all values in bytes
            free_vram = int(system['free']) / (1024 * 1024 * 1024)
            total_vram = int(system['total']) / (1024 * 1024 * 1024)
            logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
            return system['free']
        except KeyError:
            try:
                error = response['cuda']['error']
                msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}"
                logger.warning(msg)
            except KeyError:
                logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n"
                             f"{response}")
|
||||
|
||||
def prepare_payload(self, payload: dict):
    """Mutate *payload* in place so it can be serialized and sent to a worker.

    Steps:
      - Clamp 's_tmax' (which can be ``float('inf')``) to the largest
        JSON-serializable float.
      - Drop known-unserializable cached conditioning keys.
      - For img2img requests, base64-encode the init images and any image mask.
      - Replace 'alwayson_scripts' with only the scripts this worker supports.
      - Log (without raising) if the payload still fails to serialize.

    Args:
        payload: The sdwui api request payload; modified in place.
    """
    # clean anything that is unserializable
    s_tmax = payload.get('s_tmax', 0.0)
    if s_tmax is not None and s_tmax > 1e308:  # s_tmax can be float('inf')
        payload['s_tmax'] = 1e308
    # cached conditioning objects cannot be json-serialized
    key_blacklist = ['cached_uc', 'cached_c', 'uc', 'c', 'cached_hr_c', 'cached_hr_uc']
    for k in key_blacklist:
        payload.pop(k, None)

    # if img2img then we need to b64 encode the init images
    init_images = payload.get('init_images', None)
    mode = 'txt2img'
    if init_images is not None:
        mode = 'img2img'  # for use in checking script compat
        payload['init_images'] = [pil_to_64(image) for image in init_images]
        # if an image mask is present, encode it under 'mask' as the api expects
        image_mask = payload.get('image_mask', None)
        if image_mask is not None:
            payload['mask'] = pil_to_64(image_mask)
            del payload['image_mask']

    payload['alwayson_scripts'] = self.matching_scripts(payload, mode)
    # see if there is anything else wrong with serializing to payload
    try:
        json.dumps(payload)
    except Exception:
        # truncate the (potentially huge) image data before logging
        if payload.get('init_images', None):
            payload['init_images'] = 'TRUNCATED'
        logger.error(f"Failed to serialize payload: \n{payload}")
|
||||
|
||||
def request(self, payload: dict, option_payload: dict, sync_options: bool):
|
||||
"""
|
||||
Sends an arbitrary amount of requests to a sdwui api depending on the context.
|
||||
|
|
@ -310,6 +394,7 @@ class Worker:
|
|||
sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals'
|
||||
"""
|
||||
eta = None
|
||||
response_queue = queue.Queue()
|
||||
|
||||
try:
|
||||
if self.jobs_requested != 0: # prevent potential hang at startup
|
||||
|
|
@ -327,32 +412,8 @@ class Worker:
|
|||
logger.debug(f"waited {waited}s for worker '{self.label}' to IDLE before consecutive request")
|
||||
if waited >= (0.85 * max_wait):
|
||||
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
||||
|
||||
self.set_state(State.WORKING)
|
||||
|
||||
# query memory available on worker and store for future reference
|
||||
if self.queried is False:
|
||||
self.queried = True
|
||||
memory_response = self.session.get(
|
||||
self.full_url("memory")
|
||||
)
|
||||
memory_response = memory_response.json()
|
||||
try:
|
||||
memory_response = memory_response['cuda']['system'] # all in bytes
|
||||
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
|
||||
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
|
||||
logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
|
||||
self.free_vram = memory_response['free']
|
||||
except KeyError:
|
||||
try:
|
||||
error = memory_response['cuda']['error']
|
||||
msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}"
|
||||
logger.warning(msg)
|
||||
# gradio.Warning("Distributed: "+msg)
|
||||
except KeyError:
|
||||
logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n"
|
||||
f"{memory_response}")
|
||||
|
||||
self.free_vram = self.query_memory()
|
||||
if sync_options is True:
|
||||
self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])
|
||||
|
||||
|
|
@ -362,146 +423,76 @@ class Worker:
|
|||
f"{payload['batch_size'] * payload['n_iter']} image(s) "
|
||||
f"at a speed of {self.avg_ipm:.2f} ipm\n")
|
||||
|
||||
self.prepare_payload(payload)
|
||||
|
||||
# remove anything that is not serializable
|
||||
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
||||
s_tmax = payload.get('s_tmax', 0.0)
|
||||
if s_tmax is not None and s_tmax > 1e308:
|
||||
payload['s_tmax'] = 1e308
|
||||
# remove unserializable caches
|
||||
payload.pop('cached_uc', None)
|
||||
payload.pop('cached_c', None)
|
||||
payload.pop('uc', None)
|
||||
payload.pop('c', None)
|
||||
payload.pop('cached_hr_c', None)
|
||||
payload.pop('cached_hr_uc', None)
|
||||
def interruptible_request(response_queue):
|
||||
# TODO shouldn't be this way
|
||||
sampler_index = payload.get('sampler_index', None)
|
||||
sampler_name = payload.get('sampler_name', None)
|
||||
if sampler_index is None:
|
||||
if sampler_name is not None:
|
||||
payload['sampler_index'] = sampler_name
|
||||
|
||||
# if img2img then we need to b64 encode the init images
|
||||
init_images = payload.get('init_images', None)
|
||||
mode = 'txt2img'
|
||||
if init_images is not None:
|
||||
mode = 'img2img' # for use in checking script compat
|
||||
images = []
|
||||
for image in init_images:
|
||||
images.append(pil_to_64(image))
|
||||
payload['init_images'] = images
|
||||
|
||||
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
|
||||
if alwayson_scripts is not None:
|
||||
if len(self.supported_scripts) <= 0:
|
||||
payload['alwayson_scripts'] = {}
|
||||
else:
|
||||
matching_scripts = {}
|
||||
missing_scripts = []
|
||||
remote_scripts = self.supported_scripts[mode]
|
||||
for local_script in alwayson_scripts:
|
||||
match = False
|
||||
for remote_script in remote_scripts:
|
||||
if str.lower(local_script) == str.lower(remote_script):
|
||||
matching_scripts[local_script] = alwayson_scripts[local_script]
|
||||
match = True
|
||||
if not match and str.lower(local_script) != 'distribute':
|
||||
missing_scripts.append(local_script)
|
||||
|
||||
if len(missing_scripts) > 0: # warn about node to node script/extension mismatching
|
||||
message = "local script(s): "
|
||||
for script in range(0, len(missing_scripts)):
|
||||
message += f"\[{missing_scripts[script]}]"
|
||||
if script < len(missing_scripts) - 1:
|
||||
message += ', '
|
||||
message += f" seem to be unsupported by worker '{self.label}'\n"
|
||||
if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level
|
||||
logger.debug(message)
|
||||
elif self.jobs_requested < 1:
|
||||
logger.warning(message)
|
||||
|
||||
payload['alwayson_scripts'] = matching_scripts
|
||||
|
||||
# if an image mask is present
|
||||
image_mask = payload.get('image_mask', None)
|
||||
if image_mask is not None:
|
||||
payload['mask'] = pil_to_64(image_mask)
|
||||
del payload['image_mask']
|
||||
|
||||
# see if there is anything else wrong with serializing to payload
|
||||
try:
|
||||
json.dumps(payload)
|
||||
except Exception:
|
||||
if payload.get('init_images', None):
|
||||
payload['init_images'] = 'TRUNCATED'
|
||||
logger.error(f"Failed to serialize payload: \n{payload}")
|
||||
# gradio.Info("Distributed: failed to serialize payload")
|
||||
|
||||
# the main api requests sent to either the txt2img or img2img route
|
||||
response_queue = queue.Queue()
|
||||
|
||||
def preemptible_request(response_queue):
|
||||
# TODO shouldn't be this way
|
||||
sampler_index = payload.get('sampler_index', None)
|
||||
sampler_name = payload.get('sampler_name', None)
|
||||
if sampler_index is None:
|
||||
if sampler_name is not None:
|
||||
payload['sampler_index'] = sampler_name
|
||||
|
||||
try:
|
||||
response = self.session.post(
|
||||
self.full_url("txt2img") if init_images is None else self.full_url("img2img"),
|
||||
response_queue.put(
|
||||
self.session.post(
|
||||
self.full_url("txt2img") if payload.get('init_images', None) is None else self.full_url("img2img"),
|
||||
json=payload
|
||||
)
|
||||
response_queue.put(response)
|
||||
except Exception as e:
|
||||
response_queue.put(e) # forwarding thrown exceptions to parent thread
|
||||
)
|
||||
except Exception as e:
|
||||
response_queue.put(e) # forwarding thrown exceptions to parent thread
|
||||
|
||||
request_thread = Thread(target=preemptible_request, args=(response_queue,))
|
||||
interrupting = False
|
||||
start = time.time()
|
||||
request_thread.start()
|
||||
while request_thread.is_alive():
|
||||
if interrupting is False and master_state.interrupted is True:
|
||||
self.interrupt()
|
||||
interrupting = True
|
||||
time.sleep(0.5)
|
||||
request_thread = Thread(target=interruptible_request, args=(response_queue,))
|
||||
interrupting = False
|
||||
start = time.time()
|
||||
request_thread.start()
|
||||
while request_thread.is_alive():
|
||||
if interrupting is False and master_state.interrupted is True:
|
||||
self.interrupt()
|
||||
interrupting = True
|
||||
time.sleep(0.5)
|
||||
|
||||
result = response_queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
response = result
|
||||
result = response_queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
response = result
|
||||
|
||||
self.response = response.json()
|
||||
if response.status_code != 200:
|
||||
# try again when remote doesn't support the selected sampler by falling back to Euler a
|
||||
if response.status_code == 404 and self.response['detail'] == "Sampler not found":
|
||||
logger.warning(f"falling back to Euler A sampler for worker {self.label}\n"
|
||||
f"this may mean you should update this worker")
|
||||
payload['sampler_index'] = 'Euler a'
|
||||
payload['sampler_name'] = 'Euler a'
|
||||
self.response = response.json()
|
||||
if response.status_code != 200:
|
||||
# try again when remote doesn't support the selected sampler by falling back to Euler a
|
||||
if response.status_code == 404 and self.response['detail'] == "Sampler not found":
|
||||
logger.warning(f"falling back to Euler A sampler for worker {self.label}\n"
|
||||
f"this may mean you should update this worker")
|
||||
payload['sampler_index'] = 'Euler a'
|
||||
payload['sampler_name'] = 'Euler a'
|
||||
|
||||
second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,))
|
||||
second_attempt.start()
|
||||
second_attempt.join()
|
||||
return
|
||||
second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,))
|
||||
second_attempt.start()
|
||||
second_attempt.join()
|
||||
return
|
||||
|
||||
self.response = None
|
||||
raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self)
|
||||
self.response = None
|
||||
raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self)
|
||||
|
||||
# update list of ETA accuracy if state is valid
|
||||
if self.benchmarked and not self.state == State.INTERRUPTED:
|
||||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
# update list of ETA accuracy if state is valid
|
||||
if self.benchmarked and not self.state == State.INTERRUPTED:
|
||||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
|
||||
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
|
||||
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
|
||||
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
|
||||
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||
if abs(variance) < 500:
|
||||
# check if there are already 5 samples and if so, remove the oldest
|
||||
# this should help adjust to the user changing tasks
|
||||
if len(self.eta_percent_error) > 4:
|
||||
self.eta_percent_error.pop(0)
|
||||
else: # normal case
|
||||
self.eta_percent_error.append(variance)
|
||||
else:
|
||||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||
if abs(variance) < 500:
|
||||
# check if there are already 5 samples and if so, remove the oldest
|
||||
# this should help adjust to the user changing tasks
|
||||
if len(self.eta_percent_error) > 4:
|
||||
self.eta_percent_error.pop(0)
|
||||
else: # normal case
|
||||
self.eta_percent_error.append(variance)
|
||||
else:
|
||||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
|
||||
except Exception as e:
|
||||
raise WorkerException('', worker=self, exception=e)
|
||||
|
|
|
|||
Loading…
Reference in New Issue