Reuse connection, improve generation progress report

main
natanjunges 2023-01-14 02:32:09 -03:00
parent 2ed95cac9c
commit 2c6b407e69
1 changed files with 12 additions and 6 deletions

View File

@ -185,6 +185,7 @@ class Main(scripts.Script):
def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L463-L490
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
@ -212,6 +213,7 @@ class Main(scripts.Script):
def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L493-L687
fake_model = FakeModel(model)
@ -418,24 +420,28 @@ class Main(scripts.Script):
return (None, None)
try:
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
session = requests.Session()
id = session.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
assert id.status_code == 202, "Status Code: {} (expected {})".format(id.status_code, 202)
id = id.json()
id = id["id"]
shared.state.sampling_steps = p.batch_size
shared.state.sampling_steps = 0
start = time.time()
while True:
if shared.state.skipped or shared.state.interrupted:
return self.cancel_process_batch_horde(id)
try:
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
status = session.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
status = status.json()
shared.state.sampling_step = status["finished"]
elapsed = int(time.time() - start)
shared.state.sampling_steps = elapsed + status["wait_time"]
shared.state.sampling_step = elapsed
if status["done"]:
shared.state.sampling_step = shared.state.sampling_steps
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
shared.state.sampling_steps = shared.state.sampling_step
images = session.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]