Improve error messages, cleanup useless error checks, properly propagate errors
parent
cfae14f03b
commit
99e4fb87c2
103
scripts/main.py
103
scripts/main.py
|
|
@ -32,6 +32,9 @@ settings_file = os.path.join(scripts.basedir(), "settings.json")
|
|||
class FakeModel:
|
||||
sd_model_hash=""
|
||||
|
||||
class StableHordeError(Exception):
|
||||
pass
|
||||
|
||||
class Main(scripts.Script):
|
||||
TITLE = "Run on Stable Horde"
|
||||
SAMPLERS = {
|
||||
|
|
@ -82,11 +85,10 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
|
||||
assert models.status_code == 200, "Status Code: {} (expected {})".format(models.status_code, 200)
|
||||
models = models.json()
|
||||
models.sort(key=lambda m: (-m["count"], m["name"]))
|
||||
models = ["{} ({})".format(m["name"], m["count"]) for m in models]
|
||||
except (requests.ConnectionError, AssertionError):
|
||||
except requests.ConnectionError:
|
||||
models = []
|
||||
|
||||
models.insert(0, "Random")
|
||||
|
|
@ -287,8 +289,6 @@ class Main(scripts.Script):
|
|||
x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
|
||||
|
||||
if x_samples_ddim is None:
|
||||
del x_samples_ddim
|
||||
devices.torch_gc()
|
||||
break
|
||||
|
||||
if p.scripts is not None:
|
||||
|
|
@ -410,9 +410,7 @@ class Main(scripts.Script):
|
|||
payload["params"]["post_processing"] = post_processing
|
||||
|
||||
if shared.state.skipped or shared.state.interrupted:
|
||||
return
|
||||
|
||||
id = None
|
||||
return (None, None)
|
||||
|
||||
try:
|
||||
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
|
||||
|
|
@ -425,84 +423,45 @@ class Main(scripts.Script):
|
|||
if shared.state.skipped or shared.state.interrupted:
|
||||
return self.cancel_process_batch_horde(id)
|
||||
|
||||
status = None
|
||||
|
||||
try:
|
||||
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
|
||||
assert status.status_code == 200, "Status Code: {} (expected {})".format(status.status_code, 200)
|
||||
status = status.json()
|
||||
shared.state.sampling_step = status["finished"]
|
||||
|
||||
if status["done"]:
|
||||
shared.state.sampling_step = shared.state.sampling_steps
|
||||
images = None
|
||||
|
||||
try:
|
||||
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
|
||||
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
models = [image["model"] for image in images]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
|
||||
images = [torch.from_numpy(image) for image in images]
|
||||
images = torch.stack(images).to(shared.device)
|
||||
return (images, models)
|
||||
except (requests.ConnectionError, AssertionError) as e:
|
||||
print(e)
|
||||
|
||||
if images is not None:
|
||||
images = images.json()
|
||||
print(images["message"])
|
||||
|
||||
break
|
||||
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
models = [image["model"] for image in images]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
|
||||
images = [torch.from_numpy(image) for image in images]
|
||||
images = torch.stack(images).to(shared.device)
|
||||
return (images, models)
|
||||
elif status["faulted"]:
|
||||
print("faulted")
|
||||
break
|
||||
raise StableHordeError("This request caused an internal server error and could not be completed.")
|
||||
elif not status["is_possible"]:
|
||||
print("not is_possible")
|
||||
break
|
||||
raise StableHordeError("This request will not be able to be completed with the pool of workers currently available.")
|
||||
else:
|
||||
time.sleep(1)
|
||||
except requests.Timeout:
|
||||
time.sleep(1)
|
||||
except (requests.ConnectionError, AssertionError) as e:
|
||||
print(e)
|
||||
|
||||
if status is not None:
|
||||
status = status.json()
|
||||
print(status["message"])
|
||||
|
||||
return self.cancel_process_batch_horde(id)
|
||||
except (requests.ConnectionError, AssertionError) as e:
|
||||
print(payload)
|
||||
print(e)
|
||||
|
||||
if id is not None:
|
||||
id = id.json()
|
||||
print(id["message"])
|
||||
except AssertionError as e:
|
||||
id = id.json()
|
||||
raise StableHordeError(id["message"])
|
||||
|
||||
def cancel_process_batch_horde(self, id):
|
||||
images = None
|
||||
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
models = [image["model"] for image in images]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
|
||||
images = [torch.from_numpy(image) for image in images]
|
||||
|
||||
try:
|
||||
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
|
||||
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
models = [image["model"] for image in images]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
|
||||
images = [torch.from_numpy(image) for image in images]
|
||||
|
||||
if len(images) > 0:
|
||||
images = torch.stack(images).to(shared.device)
|
||||
return (images, models)
|
||||
except requests.Timeout:
|
||||
return
|
||||
except (requests.ConnectionError, AssertionError) as e:
|
||||
print(e)
|
||||
|
||||
if images is not None:
|
||||
images = images.json()
|
||||
print(images["message"])
|
||||
if len(images) > 0:
|
||||
images = torch.stack(images).to(shared.device)
|
||||
return (images, models)
|
||||
else:
|
||||
return (None, None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue