simplify worker.request by extraction

dev
papuSpartan 2024-11-22 10:01:55 -06:00
parent f819f55743
commit 61aa5e5e06
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
1 changed files with 143 additions and 152 deletions

View File

@ -300,6 +300,90 @@ class Worker:
return eta
def matching_scripts(self, payload: dict, mode) -> dict:
    """Filter the payload's 'alwayson_scripts' down to those this worker supports.

    Args:
        payload: Request payload; the 'alwayson_scripts' key may be absent
            (benchmarking being one example).
        mode: Which remote script table to match against ('txt2img'/'img2img').

    Returns:
        dict: The locally requested scripts that the worker also advertises,
        matched case-insensitively by name. Empty when nothing was requested,
        when nothing matches, or when the worker advertised no scripts at all.
        (The original returned an unbound local — NameError — in the first
        two of those cases; returning {} is the safe, serializable fix.)
    """
    alwayson_scripts = payload.get('alwayson_scripts', None)  # key may not always exist, benchmarking being one example
    if alwayson_scripts is None:
        return {}
    if len(self.supported_scripts) <= 0:
        # worker advertised no scripts, so nothing can match
        return {}

    remote_scripts = self.supported_scripts[mode]
    # case-insensitive lookup set, built once instead of an O(n*m) inner loop
    remote_lower = {str.lower(remote_script) for remote_script in remote_scripts}

    matching_scripts = {}
    missing_scripts = []
    for local_script in alwayson_scripts:
        if str.lower(local_script) in remote_lower:
            matching_scripts[local_script] = alwayson_scripts[local_script]
        elif str.lower(local_script) != 'distribute':  # this extension itself is never forwarded
            missing_scripts.append(local_script)

    if missing_scripts:  # warn about node to node script/extension mismatching
        # '\\[' keeps the literal backslash the original emitted via the
        # (now-deprecated) invalid escape sequence "\["
        names = ', '.join(f"\\[{script}]" for script in missing_scripts)
        message = f"local script(s): {names} seem to be unsupported by worker '{self.label}'\n"
        if LOG_LEVEL == 'DEBUG':  # only warn once per session unless at debug log level
            logger.debug(message)
        elif self.jobs_requested < 1:
            logger.warning(message)

    return matching_scripts
def query_memory(self):
    """Query the worker's free CUDA VRAM, in bytes, caching after the first call.

    The caller assigns the result straight to ``self.free_vram`` on every
    request; the original fell through and returned None once ``self.queried``
    was set, clobbering the cached value on the second request. Now the cached
    value is returned instead.

    Returns:
        Free VRAM in bytes on success, otherwise None (CUDA unavailable or an
        unexpected /memory response shape).
    """
    if self.queried:
        # already queried this session; hand back whatever the caller cached
        return getattr(self, 'free_vram', None)

    self.queried = True
    response = self.session.get(self.full_url("memory")).json()
    try:
        stats = response['cuda']['system']  # all in bytes
        free_vram = int(stats['free']) / (1024 * 1024 * 1024)
        total_vram = int(stats['total']) / (1024 * 1024 * 1024)
        logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
        return stats['free']
    except KeyError:
        try:
            error = response['cuda']['error']
            msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}"
            logger.warning(msg)
        except KeyError:
            logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n"
                         f"{response}")
        return None
def prepare_payload(self, payload: dict):
    """Mutate *payload* in place so it can be serialized and sent to a worker.

    - Caps ``s_tmax`` (can be ``float('inf')``, which is not valid JSON).
    - Strips known-unserializable cached conditioning keys.
    - Base64-encodes init images and the image mask for img2img requests.
    - Filters ``alwayson_scripts`` down to those the worker supports.

    Args:
        payload: txt2img/img2img request payload; modified in place.
    """
    # clean anything that is unserializable
    s_tmax = payload.get('s_tmax', 0.0)
    if s_tmax is not None and s_tmax > 1e308:  # s_tmax can be float('inf')
        payload['s_tmax'] = 1e308

    key_blacklist = ['cached_uc', 'cached_c', 'uc', 'c', 'cached_hr_c', 'cached_hr_uc']
    for k in key_blacklist:
        payload.pop(k, None)

    # if img2img then we need to b64 encode the init images
    mode = 'txt2img'
    init_images = payload.get('init_images', None)
    if init_images is not None:
        mode = 'img2img'  # for use in checking script compat
        payload['init_images'] = [pil_to_64(image) for image in init_images]

    # if an image mask is present
    image_mask = payload.get('image_mask', None)
    if image_mask is not None:
        payload['mask'] = pil_to_64(image_mask)
        del payload['image_mask']

    # only rewrite alwayson_scripts when the caller actually provided it;
    # unconditionally assigning injected the key into payloads that never
    # had it (e.g. benchmarking) and tripped matching_scripts' None path
    if payload.get('alwayson_scripts', None) is not None:
        payload['alwayson_scripts'] = self.matching_scripts(payload, mode)

    # see if there is anything else wrong with serializing to payload
    try:
        json.dumps(payload)
    except Exception:
        if payload.get('init_images', None):
            payload['init_images'] = 'TRUNCATED'
        logger.error(f"Failed to serialize payload: \n{payload}")
def request(self, payload: dict, option_payload: dict, sync_options: bool):
"""
Sends an arbitrary amount of requests to a sdwui api depending on the context.
@ -310,6 +394,7 @@ class Worker:
sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals'
"""
eta = None
response_queue = queue.Queue()
try:
if self.jobs_requested != 0: # prevent potential hang at startup
@ -327,32 +412,8 @@ class Worker:
logger.debug(f"waited {waited}s for worker '{self.label}' to IDLE before consecutive request")
if waited >= (0.85 * max_wait):
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
self.set_state(State.WORKING)
# query memory available on worker and store for future reference
if self.queried is False:
self.queried = True
memory_response = self.session.get(
self.full_url("memory")
)
memory_response = memory_response.json()
try:
memory_response = memory_response['cuda']['system'] # all in bytes
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = memory_response['free']
except KeyError:
try:
error = memory_response['cuda']['error']
msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}"
logger.warning(msg)
# gradio.Warning("Distributed: "+msg)
except KeyError:
logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n"
f"{memory_response}")
self.free_vram = self.query_memory()
if sync_options is True:
self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])
@ -362,146 +423,76 @@ class Worker:
f"{payload['batch_size'] * payload['n_iter']} image(s) "
f"at a speed of {self.avg_ipm:.2f} ipm\n")
self.prepare_payload(payload)
# remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0)
if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308
# remove unserializable caches
payload.pop('cached_uc', None)
payload.pop('cached_c', None)
payload.pop('uc', None)
payload.pop('c', None)
payload.pop('cached_hr_c', None)
payload.pop('cached_hr_uc', None)
def interruptible_request(response_queue):
# TODO shouldn't be this way
sampler_index = payload.get('sampler_index', None)
sampler_name = payload.get('sampler_name', None)
if sampler_index is None:
if sampler_name is not None:
payload['sampler_index'] = sampler_name
# if img2img then we need to b64 encode the init images
init_images = payload.get('init_images', None)
mode = 'txt2img'
if init_images is not None:
mode = 'img2img' # for use in checking script compat
images = []
for image in init_images:
images.append(pil_to_64(image))
payload['init_images'] = images
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
if alwayson_scripts is not None:
if len(self.supported_scripts) <= 0:
payload['alwayson_scripts'] = {}
else:
matching_scripts = {}
missing_scripts = []
remote_scripts = self.supported_scripts[mode]
for local_script in alwayson_scripts:
match = False
for remote_script in remote_scripts:
if str.lower(local_script) == str.lower(remote_script):
matching_scripts[local_script] = alwayson_scripts[local_script]
match = True
if not match and str.lower(local_script) != 'distribute':
missing_scripts.append(local_script)
if len(missing_scripts) > 0: # warn about node to node script/extension mismatching
message = "local script(s): "
for script in range(0, len(missing_scripts)):
message += f"\[{missing_scripts[script]}]"
if script < len(missing_scripts) - 1:
message += ', '
message += f" seem to be unsupported by worker '{self.label}'\n"
if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level
logger.debug(message)
elif self.jobs_requested < 1:
logger.warning(message)
payload['alwayson_scripts'] = matching_scripts
# if an image mask is present
image_mask = payload.get('image_mask', None)
if image_mask is not None:
payload['mask'] = pil_to_64(image_mask)
del payload['image_mask']
# see if there is anything else wrong with serializing to payload
try:
json.dumps(payload)
except Exception:
if payload.get('init_images', None):
payload['init_images'] = 'TRUNCATED'
logger.error(f"Failed to serialize payload: \n{payload}")
# gradio.Info("Distributed: failed to serialize payload")
# the main api requests sent to either the txt2img or img2img route
response_queue = queue.Queue()
def preemptible_request(response_queue):
# TODO shouldn't be this way
sampler_index = payload.get('sampler_index', None)
sampler_name = payload.get('sampler_name', None)
if sampler_index is None:
if sampler_name is not None:
payload['sampler_index'] = sampler_name
try:
response = self.session.post(
self.full_url("txt2img") if init_images is None else self.full_url("img2img"),
response_queue.put(
self.session.post(
self.full_url("txt2img") if payload.get('init_images', None) is None else self.full_url("img2img"),
json=payload
)
response_queue.put(response)
except Exception as e:
response_queue.put(e) # forwarding thrown exceptions to parent thread
)
except Exception as e:
response_queue.put(e) # forwarding thrown exceptions to parent thread
request_thread = Thread(target=preemptible_request, args=(response_queue,))
interrupting = False
start = time.time()
request_thread.start()
while request_thread.is_alive():
if interrupting is False and master_state.interrupted is True:
self.interrupt()
interrupting = True
time.sleep(0.5)
request_thread = Thread(target=interruptible_request, args=(response_queue,))
interrupting = False
start = time.time()
request_thread.start()
while request_thread.is_alive():
if interrupting is False and master_state.interrupted is True:
self.interrupt()
interrupting = True
time.sleep(0.5)
result = response_queue.get()
if isinstance(result, Exception):
raise result
response = result
result = response_queue.get()
if isinstance(result, Exception):
raise result
response = result
self.response = response.json()
if response.status_code != 200:
# try again when remote doesn't support the selected sampler by falling back to Euler a
if response.status_code == 404 and self.response['detail'] == "Sampler not found":
logger.warning(f"falling back to Euler A sampler for worker {self.label}\n"
f"this may mean you should update this worker")
payload['sampler_index'] = 'Euler a'
payload['sampler_name'] = 'Euler a'
self.response = response.json()
if response.status_code != 200:
# try again when remote doesn't support the selected sampler by falling back to Euler a
if response.status_code == 404 and self.response['detail'] == "Sampler not found":
logger.warning(f"falling back to Euler A sampler for worker {self.label}\n"
f"this may mean you should update this worker")
payload['sampler_index'] = 'Euler a'
payload['sampler_name'] = 'Euler a'
second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,))
second_attempt.start()
second_attempt.join()
return
second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,))
second_attempt.start()
second_attempt.join()
return
self.response = None
raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self)
self.response = None
raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self)
# update list of ETA accuracy if state is valid
if self.benchmarked and not self.state == State.INTERRUPTED:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
# update list of ETA accuracy if state is valid
if self.benchmarked and not self.state == State.INTERRUPTED:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
# if the variance is greater than 500% then we ignore it to prevent variation inflation
if abs(variance) < 500:
# check if there are already 5 samples and if so, remove the oldest
# this should help adjust to the user changing tasks
if len(self.eta_percent_error) > 4:
self.eta_percent_error.pop(0)
else: # normal case
self.eta_percent_error.append(variance)
else:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
# if the variance is greater than 500% then we ignore it to prevent variation inflation
if abs(variance) < 500:
# check if there are already 5 samples and if so, remove the oldest
# this should help adjust to the user changing tasks
if len(self.eta_percent_error) > 4:
self.eta_percent_error.pop(0)
else: # normal case
self.eta_percent_error.append(variance)
else:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e:
raise WorkerException('', worker=self, exception=e)