diff --git a/scripts/main.py b/scripts/main.py index 4fd46cf..f480574 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -185,6 +185,7 @@ class Main(scripts.Script): def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing): # Copyright (C) 2022 AUTOMATIC1111 + # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L463-L490 stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()} @@ -212,6 +213,7 @@ class Main(scripts.Script): def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing): # Copyright (C) 2022 AUTOMATIC1111 + # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L493-L687 fake_model = FakeModel(model) @@ -418,24 +420,28 @@ class Main(scripts.Script): return (None, None) try: - id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload) + session = requests.Session() + id = session.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload) assert id.status_code == 202, "Status Code: {} (expected {})".format(id.status_code, 202) id = id.json() id = id["id"] - shared.state.sampling_steps = p.batch_size + shared.state.sampling_steps = 0 + start = time.time() while True: if shared.state.skipped or shared.state.interrupted: return self.cancel_process_batch_horde(id) try: - status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1) + status = session.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1) status = status.json() - shared.state.sampling_step = status["finished"] + elapsed = int(time.time() - start) + shared.state.sampling_steps = elapsed + status["wait_time"] + shared.state.sampling_step = elapsed if status["done"]: - shared.state.sampling_step = shared.state.sampling_steps - images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id)) + shared.state.sampling_steps = shared.state.sampling_step + images = session.get("{}/v2/generate/status/{}".format(self.api_endpoint, id)) images = images.json() images = images["generations"] models = [image["model"] for image in images]