fix certain worker disconnects caused by making requests while the same worker is loading weights

master
unknown 2023-12-29 15:06:57 -06:00
parent a047ab832d
commit c228315b1b
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
1 changed file with 21 additions and 3 deletions

View File

@@ -277,6 +277,21 @@ class Worker:
eta = None
try:
# if state is already WORKING then weights may be loading on the worker
# prevents an issue where a model override loads a large model and consecutive requests time out
max_wait = 30
waited = 0
while self.state == State.WORKING:
if waited >= max_wait:
break
time.sleep(1)
waited += 1
if waited != 0:
logger.debug(f"waited {waited}s for worker '{self.label}' to IDLE before consecutive request")
if waited >= (0.85 * max_wait):
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
self.state = State.WORKING
# query memory available on worker and store for future reference
@@ -304,8 +319,6 @@ class Worker:
if sync_options is True:
self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])
# TODO api returns 200 even if it fails to successfully set the checkpoint so we will have to make a
# second GET to see if everything loaded...
if self.benchmarked:
eta = self.batch_eta(payload=payload) * payload['n_iter']
@@ -535,7 +548,7 @@ class Worker:
ipm_sum += ipm_result
avg_ipm_result = ipm_sum / samples
logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result}")
logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}")
self.avg_ipm = avg_ipm_result
self.response = None
self.benchmarked = True
@@ -632,14 +645,19 @@ class Worker:
if vae is not None:
payload['sd_vae'] = vae
self.state = State.WORKING
start = time.time()
response = self.session.post(
self.full_url("options"),
json=payload
)
elapsed = time.time() - start
self.state = State.IDLE
if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'")
else:
logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
self.loaded_model = model_name
if vae is not None:
self.loaded_vae = vae