beginning of refactoring processed_inject_image

master^2
papuSpartan 2024-09-29 23:08:52 -05:00
parent 8a810b7714
commit 7349ffd506
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
1 changed files with 51 additions and 39 deletions

View File

@ -72,54 +72,62 @@ class DistributedScript(scripts.Script):
# return some components that should be exposed to the api
return components
def add_to_gallery(self, pp, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
def api_to_internal(self, job):
# takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
def processed_inject_image(image, info_index, save_path_override=None, grid=False, job=None):
image_params: json = job.worker.response['parameters']
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
num_response_images = image_params["batch_size"] * image_params["n_iter"]
seed = None
subseed = None
negative_prompt = None
pos_prompt = None
image_params: json = job.worker.response['parameters']
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = [], [], [], [], []
for i in range(len(job.worker.response["images"])):
try:
if num_response_images > 1:
seed = image_info_post['all_seeds'][info_index]
subseed = image_info_post['all_subseeds'][info_index]
negative_prompt = image_info_post['all_negative_prompts'][info_index]
pos_prompt = image_info_post['all_prompts'][info_index]
else:
seed = image_info_post['seed']
subseed = image_info_post['subseed']
negative_prompt = image_info_post['negative_prompt']
pos_prompt = image_info_post['prompt']
if image_params["batch_size"] * image_params["n_iter"] > 1:
all_seeds.append(image_info_post['all_seeds'][i])
all_subseeds.append(image_info_post['all_subseeds'][i])
all_negative_prompts.append(image_info_post['all_negative_prompts'][i])
all_prompts.append(image_info_post['all_prompts'][i])
else: # only a single image received
all_seeds.append(image_info_post['seed'])
all_subseeds.append(image_info_post['subseed'])
all_negative_prompts.append(image_info_post['negative_prompt'])
all_prompts.append(image_info_post['prompt'])
except IndexError:
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
processed_inject_image(image=image, info_index=0, job=job)
return
p.seeds.append(seed)
p.subseeds.append(subseed)
p.negative_prompts.append(negative_prompt)
p.prompts.append(pos_prompt)
# # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
# logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
# self.processed_inject_image(image=image, info_index=0, job=job, p=p)
# return
logger.critical(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
continue
# parse image
image_bytes = base64.b64decode(job.worker.response["images"][i])
image = Image.open(io.BytesIO(image_bytes))
transform = ToTensor()
pp.images.append(transform(image)) # actual received image
images.append(transform(image))
return all_seeds, all_subseeds, all_negative_prompts, all_prompts, images
def processed_inject_image(self, job, p, pp):
all_seeds, all_subseeds, all_negative_prompts, all_prompts, images = self.api_to_internal(job)
p.seeds.extend(all_seeds)
p.subseeds.extend(all_subseeds)
p.negative_prompts.extend(all_negative_prompts)
p.prompts.extend(all_prompts)
for i, image in enumerate(images):
pp.images.append(image) # add one image to the gallery
# modules.ui_common -> update_generation_info renders to html below gallery
images_per_batch = p.n_iter * p.batch_size
# zero-indexed position of image in gallery (so including master/local results)
true_image_pos = (len(pp.images) - 1) + (not p.do_not_save_grid)
num_remote_images = images_per_batch * p.batch_size
num_remote_images = images_per_batch * p.batch_size # the **expected** amount of remote images
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
f"info-index: {info_index}")
logger.debug(f"image {true_image_pos + 1}/{(self.world.p.batch_size * p.n_iter) + (not p.do_not_save_grid) + 1}, "
f"info-index: fix me")
if self.world.thin_client_mode:
p.all_negative_prompts = pp.all_negative_prompts
@ -127,6 +135,9 @@ class DistributedScript(scripts.Script):
# saves final position of image in gallery so that we can later modify the correct infotext
job.gallery_map.append(true_image_pos)
def add_to_gallery(self, pp, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
# get master ipm by estimating based on worker speed
master_elapsed = time.time() - self.master_start
logger.debug(f"Took master {master_elapsed:.2f}s")
@ -170,12 +181,13 @@ class DistributedScript(scripts.Script):
continue
# visibly add work from workers to the image gallery
for i in range(0, len(images)):
image_bytes = base64.b64decode(images[i])
image = Image.open(io.BytesIO(image_bytes))
# for i in range(0, len(images)):
# image_bytes = base64.b64decode(images[i])
# image = Image.open(io.BytesIO(image_bytes))
# inject image
processed_inject_image(image=image, info_index=i, job=job)
# # inject image
# self.processed_inject_image(image=image, info_index=i, job=job, p=p, pp=pp)
self.processed_inject_image(job, p, pp)
if donor_worker is None:
logger.critical("couldn't collect any responses, the extension will have no effect")