Improve error messages, cleanup useless error checks, properly propagate errors

main
natanjunges 2023-01-11 21:20:20 -03:00
parent cfae14f03b
commit 99e4fb87c2
1 changed files with 31 additions and 72 deletions

View File

@ -32,6 +32,9 @@ settings_file = os.path.join(scripts.basedir(), "settings.json")
class FakeModel:
sd_model_hash=""
class StableHordeError(Exception):
pass
class Main(scripts.Script):
TITLE = "Run on Stable Horde"
SAMPLERS = {
@ -82,11 +85,10 @@ class Main(scripts.Script):
try:
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
assert models.status_code == 200, "Status Code: {} (expected {})".format(models.status_code, 200)
models = models.json()
models.sort(key=lambda m: (-m["count"], m["name"]))
models = ["{} ({})".format(m["name"], m["count"]) for m in models]
except (requests.ConnectionError, AssertionError):
except requests.ConnectionError:
models = []
models.insert(0, "Random")
@ -287,8 +289,6 @@ class Main(scripts.Script):
x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
if x_samples_ddim is None:
del x_samples_ddim
devices.torch_gc()
break
if p.scripts is not None:
@ -410,9 +410,7 @@ class Main(scripts.Script):
payload["params"]["post_processing"] = post_processing
if shared.state.skipped or shared.state.interrupted:
return
id = None
return (None, None)
try:
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
@ -425,84 +423,45 @@ class Main(scripts.Script):
if shared.state.skipped or shared.state.interrupted:
return self.cancel_process_batch_horde(id)
status = None
try:
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
assert status.status_code == 200, "Status Code: {} (expected {})".format(status.status_code, 200)
status = status.json()
shared.state.sampling_step = status["finished"]
if status["done"]:
shared.state.sampling_step = shared.state.sampling_steps
images = None
try:
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
images = torch.stack(images).to(shared.device)
return (images, models)
except (requests.ConnectionError, AssertionError) as e:
print(e)
if images is not None:
images = images.json()
print(images["message"])
break
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
images = torch.stack(images).to(shared.device)
return (images, models)
elif status["faulted"]:
print("faulted")
break
raise StableHordeError("This request caused an internal server error and could not be completed.")
elif not status["is_possible"]:
print("not is_possible")
break
raise StableHordeError("This request will not be able to be completed with the pool of workers currently available.")
else:
time.sleep(1)
except requests.Timeout:
time.sleep(1)
except (requests.ConnectionError, AssertionError) as e:
print(e)
if status is not None:
status = status.json()
print(status["message"])
return self.cancel_process_batch_horde(id)
except (requests.ConnectionError, AssertionError) as e:
print(payload)
print(e)
if id is not None:
id = id.json()
print(id["message"])
except AssertionError as e:
id = id.json()
raise StableHordeError(id["message"])
def cancel_process_batch_horde(self, id):
images = None
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
try:
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
if len(images) > 0:
images = torch.stack(images).to(shared.device)
return (images, models)
except requests.Timeout:
return
except (requests.ConnectionError, AssertionError) as e:
print(e)
if images is not None:
images = images.json()
print(images["message"])
if len(images) > 0:
images = torch.stack(images).to(shared.device)
return (images, models)
else:
return (None, None)