once again inject infotext from remote results into gallery

master^2
papuSpartan 2024-09-27 20:21:11 -05:00
parent 32bc086ec6
commit 8a810b7714
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 24 additions and 47 deletions

View File

@ -75,9 +75,9 @@ class DistributedScript(scripts.Script):
def add_to_gallery(self, pp, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
image_params: json = response['parameters']
image_info_post: json = json.loads(response["info"]) # image info known after processing
def processed_inject_image(image, info_index, save_path_override=None, grid=False, job=None):
image_params: json = job.worker.response['parameters']
image_info_post: json = json.loads(job.worker.response["info"]) # image info known after processing
num_response_images = image_params["batch_size"] * image_params["n_iter"]
seed = None
@ -99,7 +99,7 @@ class DistributedScript(scripts.Script):
except IndexError:
# like with controlnet masks, there isn't always full post-gen info, so we use the first images'
logger.debug(f"Image at index {i} for '{job.worker.label}' was missing some post-generation data")
processed_inject_image(image=image, info_index=0, response=response)
processed_inject_image(image=image, info_index=0, job=job)
return
p.seeds.append(seed)
@ -110,12 +110,10 @@ class DistributedScript(scripts.Script):
transform = ToTensor()
pp.images.append(transform(image)) # actual received image
# generate info-text string
# modules.ui_common -> update_generation_info renders to html below gallery
images_per_batch = p.n_iter * p.batch_size
# zero-indexed position of image in total batch (so including master results)
true_image_pos = len(pp.images) - 1
# zero-indexed position of image in gallery (so including master/local results)
true_image_pos = (len(pp.images) - 1) + (not p.do_not_save_grid)
num_remote_images = images_per_batch * p.batch_size
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
@ -126,26 +124,8 @@ class DistributedScript(scripts.Script):
if self.world.thin_client_mode:
p.all_negative_prompts = pp.all_negative_prompts
# try:
# info_text = image_info_post['infotexts'][i]
# except IndexError:
# if not grid:
# logger.warning(f"image {true_image_pos + 1} was missing info-text")
# info_text = processed.infotexts[0]
# info_text += f", Worker Label: {job.worker.label}"
# processed.infotexts.append(info_text)
# automatically save received image to local disk if desired
#if cmd_opts.distributed_remotes_autosave:
# save_image(
# image=image,
# path=p.outpath_samples if save_path_override is None else save_path_override,
# basename="",
# seed=seed,
# prompt=pos_prompt,
# info=info_text,
# extension=opts.samples_format
# )
# saves final position of image in gallery so that we can later modify the correct infotext
job.gallery_map.append(true_image_pos)
# get master ipm by estimating based on worker speed
master_elapsed = time.time() - self.master_start
@ -195,15 +175,12 @@ class DistributedScript(scripts.Script):
image = Image.open(io.BytesIO(image_bytes))
# inject image
processed_inject_image(image=image, info_index=i, response=job.worker.response)
processed_inject_image(image=image, info_index=i, job=job)
if donor_worker is None:
logger.critical("couldn't collect any responses, the extension will have no effect")
return
# cleanup after we're doing using all the responses
for worker in self.world.get_workers():
worker.response = None
p.batch_size = len(pp.images)
webui_state.textinfo = ""
@ -357,21 +334,6 @@ class DistributedScript(scripts.Script):
self.runs_since_init += 1
return
# def postprocess(self, p, processed, *args):
# is_img2img = getattr(p, 'init_images', False)
# if is_img2img and self.world.enabled_i2i is False:
# return
# elif not is_img2img and self.world.enabled is False:
# return
# if self.master_start is not None:
# self.add_to_gallery(p=p, processed=processed)
# # restore process_images_inner if it was monkey-patched
# processing.process_images_inner = self.original_process_images_inner
# # save any dangling state to prevent load_config in next iteration overwriting it
# self.world.save_config()
def postprocess_batch_list(self, p, pp, *args, **kwargs):
is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:
@ -382,6 +344,20 @@ class DistributedScript(scripts.Script):
if self.master_start is not None:
self.add_to_gallery(p=p, pp=pp)
def postprocess(self, p, processed, *args):
for job in self.world.jobs:
if job.worker.response is not None:
for i, v in enumerate(job.gallery_map):
infotext = json.loads(job.worker.response['info'])['infotexts'][i]
logger.debug(f"replacing image {v}'s infotext with\n"
f"> '{infotext}'")
infotext += f", Worker Label: {job.worker.label}"
processed.infotexts[v] = infotext
# cleanup
for worker in self.world.get_workers():
worker.response = None
# restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it

View File

@ -46,6 +46,7 @@ class Job:
self.complementary: bool = False
self.step_override = None
self.thread = None
self.gallery_map: List[int] = []
def __str__(self):
prefix = ''