update thin-client mode

master^2
papuSpartan 2024-10-25 21:55:49 -05:00
parent 622827aab5
commit 64bc137d65
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 33 additions and 17 deletions

View File

@ -116,19 +116,14 @@ class DistributedScript(scripts.Script):
p.negative_prompts.extend(all_negative_prompts) p.negative_prompts.extend(all_negative_prompts)
p.prompts.extend(all_prompts) p.prompts.extend(all_prompts)
num_local = self.world.p.n_iter * self.world.p.batch_size + opts.return_grid num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
num_injected = len(pp.images) - self.world.p.batch_size num_injected = len(pp.images) - self.world.p.batch_size
for i, image in enumerate(images): for i, image in enumerate(images):
# modules.ui_common -> update_generation_info renders to html below gallery # modules.ui_common -> update_generation_info renders to html below gallery
# TODO probably shouldn't be here
if self.world.thin_client_mode:
p.all_negative_prompts = pp.all_negative_prompts
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
job.gallery_map.append(gallery_index) # so we know where to edit infotext job.gallery_map.append(gallery_index) # so we know where to edit infotext
pp.images.append(image) pp.images.append(image)
logger.debug(f"image {gallery_index + 1}/{self.world.num_gallery()}") logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
def update_gallery(self, pp, p): def update_gallery(self, pp, p):
"""adds all remotely generated images to the image gallery after waiting for all workers to finish""" """adds all remotely generated images to the image gallery after waiting for all workers to finish"""
@ -328,13 +323,11 @@ class DistributedScript(scripts.Script):
p.batch_size = self.world.master_job().batch_size p.batch_size = self.world.master_job().batch_size
self.master_start = time.time() self.master_start = time.time()
# generate images assigned to local machine
# p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
self.runs_since_init += 1 self.runs_since_init += 1
return return
def postprocess_batch_list(self, p, pp, *args, **kwargs): def postprocess_batch_list(self, p, pp, *args, **kwargs):
if p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
return return
is_img2img = getattr(p, 'init_images', False) is_img2img = getattr(p, 'init_images', False)

View File

@ -609,6 +609,7 @@ class Worker:
self.full_url("memory"), self.full_url("memory"),
timeout=3 timeout=3
) )
self.response = response
return response.status_code == 200 return response.status_code == 200
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:

View File

@ -21,6 +21,9 @@ from .worker import Worker, State
from modules.call_queue import wrap_queued_call, queue_lock from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing from modules import processing
from modules import progress from modules import progress
from modules.scripts import PostprocessBatchListArgs
from torchvision.transforms import ToPILImage
from modules.images import image_grid
class NotBenchmarked(Exception): class NotBenchmarked(Exception):
@ -562,15 +565,31 @@ class World:
# save original process_images_inner for later so we can restore once we're done # save original process_images_inner for later so we can restore once we're done
logger.debug(f"bypassing local generation completely") logger.debug(f"bypassing local generation completely")
def process_images_inner_bypass(p) -> processing.Processed: def process_images_inner_bypass(p) -> processing.Processed:
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
pp = PostprocessBatchListArgs(images=[])
self.p.scripts.postprocess_batch_list(p, pp)
processed = processing.Processed(p, [], p.seed, info="") processed = processing.Processed(p, [], p.seed, info="")
processed.all_prompts = [] processed.all_prompts = p.prompts
processed.all_seeds = [] processed.all_seeds = p.seeds
processed.all_subseeds = [] processed.all_subseeds = p.subseeds
processed.all_negative_prompts = [] processed.all_negative_prompts = p.negative_prompts
processed.infotexts = [] processed.images = pp.images
processed.prompt = None processed.infotexts = [''] * self.num_requested()
transform = ToPILImage()
for i, image in enumerate(processed.images):
processed.images[i] = transform(image)
self.p.scripts.postprocess(p, processed) self.p.scripts.postprocess(p, processed)
# generate grid if enabled
if shared.opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images))
processed.images.insert(0, grid)
processed.infotexts.insert(0, processed.infotexts[0])
return processed return processed
processing.process_images_inner = process_images_inner_bypass processing.process_images_inner = process_images_inner_bypass
@ -699,7 +718,7 @@ class World:
) )
with open(self.config_path, 'w+') as config_file: with open(self.config_path, 'w+') as config_file:
config_file.write(config.model_dump_json(indent=3)) config_file.write(config.json(indent=3))
logger.debug(f"config saved") logger.debug(f"config saved")
def ping_remotes(self, indiscriminate: bool = False): def ping_remotes(self, indiscriminate: bool = False):
@ -749,6 +768,9 @@ class World:
worker.set_state(State.IDLE, expect_cycle=True) worker.set_state(State.IDLE, expect_cycle=True)
else: else:
msg = f"worker '{worker.label}' is unreachable" msg = f"worker '{worker.label}' is unreachable"
if worker.response.status_code is not None:
msg += f" <{worker.response.status_code}>"
logger.info(msg) logger.info(msg)
gradio.Warning("Distributed: "+msg) gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE) worker.set_state(State.UNAVAILABLE)