update thin-client mode

master^2
papuSpartan 2024-10-25 21:55:49 -05:00
parent 622827aab5
commit 64bc137d65
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 33 additions and 17 deletions

View File

@@ -116,19 +116,14 @@ class DistributedScript(scripts.Script):
p.negative_prompts.extend(all_negative_prompts)
p.prompts.extend(all_prompts)
num_local = self.world.p.n_iter * self.world.p.batch_size + opts.return_grid
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
num_injected = len(pp.images) - self.world.p.batch_size
for i, image in enumerate(images):
# modules.ui_common -> update_generation_info renders to html below gallery
# TODO probably shouldn't be here
if self.world.thin_client_mode:
p.all_negative_prompts = pp.all_negative_prompts
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
job.gallery_map.append(gallery_index) # so we know where to edit infotext
pp.images.append(image)
logger.debug(f"image {gallery_index + 1}/{self.world.num_gallery()}")
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
def update_gallery(self, pp, p):
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
@@ -328,13 +323,11 @@ class DistributedScript(scripts.Script):
p.batch_size = self.world.master_job().batch_size
self.master_start = time.time()
# generate images assigned to local machine
# p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
self.runs_since_init += 1
return
def postprocess_batch_list(self, p, pp, *args, **kwargs):
if p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
return
is_img2img = getattr(p, 'init_images', False)

View File

@@ -609,6 +609,7 @@ class Worker:
self.full_url("memory"),
timeout=3
)
self.response = response
return response.status_code == 200
except requests.exceptions.ConnectionError as e:

View File

@@ -21,6 +21,9 @@ from .worker import Worker, State
from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing
from modules import progress
from modules.scripts import PostprocessBatchListArgs
from torchvision.transforms import ToPILImage
from modules.images import image_grid
class NotBenchmarked(Exception):
@@ -562,15 +565,31 @@ class World:
# save original process_images_inner for later so we can restore once we're done
logger.debug(f"bypassing local generation completely")
def process_images_inner_bypass(p) -> processing.Processed:
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
pp = PostprocessBatchListArgs(images=[])
self.p.scripts.postprocess_batch_list(p, pp)
processed = processing.Processed(p, [], p.seed, info="")
processed.all_prompts = []
processed.all_seeds = []
processed.all_subseeds = []
processed.all_negative_prompts = []
processed.infotexts = []
processed.prompt = None
processed.all_prompts = p.prompts
processed.all_seeds = p.seeds
processed.all_subseeds = p.subseeds
processed.all_negative_prompts = p.negative_prompts
processed.images = pp.images
processed.infotexts = [''] * self.num_requested()
transform = ToPILImage()
for i, image in enumerate(processed.images):
processed.images[i] = transform(image)
self.p.scripts.postprocess(p, processed)
# generate grid if enabled
if shared.opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images))
processed.images.insert(0, grid)
processed.infotexts.insert(0, processed.infotexts[0])
return processed
processing.process_images_inner = process_images_inner_bypass
@@ -699,7 +718,7 @@ class World:
)
with open(self.config_path, 'w+') as config_file:
config_file.write(config.model_dump_json(indent=3))
config_file.write(config.json(indent=3))
logger.debug(f"config saved")
def ping_remotes(self, indiscriminate: bool = False):
@@ -749,6 +768,9 @@ class World:
worker.set_state(State.IDLE, expect_cycle=True)
else:
msg = f"worker '{worker.label}' is unreachable"
if worker.response.status_code is not None:
msg += f" <{worker.response.status_code}>"
logger.info(msg)
gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE)