diff --git a/scripts/main.py b/scripts/main.py index 4a9c726..e31e28b 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -32,6 +32,9 @@ settings_file = os.path.join(scripts.basedir(), "settings.json") class FakeModel: sd_model_hash="" +class StableHordeError(Exception): + pass + class Main(scripts.Script): TITLE = "Run on Stable Horde" SAMPLERS = { @@ -82,11 +85,10 @@ class Main(scripts.Script): try: models = requests.get("{}/v2/status/models".format(self.api_endpoint)) - assert models.status_code == 200, "Status Code: {} (expected {})".format(models.status_code, 200) models = models.json() models.sort(key=lambda m: (-m["count"], m["name"])) models = ["{} ({})".format(m["name"], m["count"]) for m in models] - except (requests.ConnectionError, AssertionError): + except requests.ConnectionError: models = [] models.insert(0, "Random") @@ -287,8 +289,6 @@ class Main(scripts.Script): x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0]) if x_samples_ddim is None: - del x_samples_ddim - devices.torch_gc() break if p.scripts is not None: @@ -410,9 +410,7 @@ class Main(scripts.Script): payload["params"]["post_processing"] = post_processing if shared.state.skipped or shared.state.interrupted: - return - - id = None + return (None, None) try: id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload) @@ -425,84 +423,45 @@ class Main(scripts.Script): if shared.state.skipped or shared.state.interrupted: return self.cancel_process_batch_horde(id) - status = None - try: status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1) - assert status.status_code == 200, "Status Code: {} (expected {})".format(status.status_code, 200) status = status.json() shared.state.sampling_step = status["finished"] if status["done"]: shared.state.sampling_step = shared.state.sampling_steps - images = None - - try: - images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id)) - assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200) - images = images.json() - images = images["generations"] - models = [image["model"] for image in images] - images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images] - images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images] - images = [torch.from_numpy(image) for image in images] - images = torch.stack(images).to(shared.device) - return (images, models) - except (requests.ConnectionError, AssertionError) as e: - print(e) - - if images is not None: - images = images.json() - print(images["message"]) - - break + images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id)) + images = images.json() + images = images["generations"] + models = [image["model"] for image in images] + images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images] + images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images] + images = [torch.from_numpy(image) for image in images] + images = torch.stack(images).to(shared.device) + return (images, models) elif status["faulted"]: - print("faulted") - break + raise StableHordeError("This request caused an internal server error and could not be completed.") elif not status["is_possible"]: - print("not is_possible") - break + raise StableHordeError("This request will not be able to be completed with the pool of workers currently available.") else: time.sleep(1) except requests.Timeout: time.sleep(1) - except (requests.ConnectionError, AssertionError) as e: - print(e) - - if status is not None: - status = status.json() - print(status["message"]) - - return self.cancel_process_batch_horde(id) - except (requests.ConnectionError, AssertionError) as e: - print(payload) - print(e) - - if id is not None: - id = id.json() - print(id["message"]) + except AssertionError as e: + id = id.json() + raise StableHordeError(id["message"]) def cancel_process_batch_horde(self, id): - images = None + images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60) + images = images.json() + images = images["generations"] + models = [image["model"] for image in images] + images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images] + images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images] + images = [torch.from_numpy(image) for image in images] - try: - images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60) - assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200) - images = images.json() - images = images["generations"] - models = [image["model"] for image in images] - images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images] - images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images] - images = [torch.from_numpy(image) for image in images] - - if len(images) > 0: - images = torch.stack(images).to(shared.device) - return (images, models) - except requests.Timeout: - return - except (requests.ConnectionError, AssertionError) as e: - print(e) - - if images is not None: - images = images.json() - print(images["message"]) + if len(images) > 0: + images = torch.stack(images).to(shared.device) + return (images, models) + else: + return (None, None)