update thin-client mode
parent
622827aab5
commit
64bc137d65
|
|
@ -116,19 +116,14 @@ class DistributedScript(scripts.Script):
|
|||
p.negative_prompts.extend(all_negative_prompts)
|
||||
p.prompts.extend(all_prompts)
|
||||
|
||||
num_local = self.world.p.n_iter * self.world.p.batch_size + opts.return_grid
|
||||
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
|
||||
num_injected = len(pp.images) - self.world.p.batch_size
|
||||
for i, image in enumerate(images):
|
||||
# modules.ui_common -> update_generation_info renders to html below gallery
|
||||
|
||||
# TODO probably shouldn't be here
|
||||
if self.world.thin_client_mode:
|
||||
p.all_negative_prompts = pp.all_negative_prompts
|
||||
|
||||
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
|
||||
job.gallery_map.append(gallery_index) # so we know where to edit infotext
|
||||
pp.images.append(image)
|
||||
logger.debug(f"image {gallery_index + 1}/{self.world.num_gallery()}")
|
||||
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
|
||||
|
||||
def update_gallery(self, pp, p):
|
||||
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
|
||||
|
|
@ -328,13 +323,11 @@ class DistributedScript(scripts.Script):
|
|||
p.batch_size = self.world.master_job().batch_size
|
||||
self.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
# p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
self.runs_since_init += 1
|
||||
return
|
||||
|
||||
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||
if p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
||||
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
||||
return
|
||||
|
||||
is_img2img = getattr(p, 'init_images', False)
|
||||
|
|
|
|||
|
|
@ -609,6 +609,7 @@ class Worker:
|
|||
self.full_url("memory"),
|
||||
timeout=3
|
||||
)
|
||||
self.response = response
|
||||
return response.status_code == 200
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ from .worker import Worker, State
|
|||
from modules.call_queue import wrap_queued_call, queue_lock
|
||||
from modules import processing
|
||||
from modules import progress
|
||||
from modules.scripts import PostprocessBatchListArgs
|
||||
from torchvision.transforms import ToPILImage
|
||||
from modules.images import image_grid
|
||||
|
||||
|
||||
class NotBenchmarked(Exception):
|
||||
|
|
@ -562,15 +565,31 @@ class World:
|
|||
# save original process_images_inner for later so we can restore once we're done
|
||||
logger.debug(f"bypassing local generation completely")
|
||||
def process_images_inner_bypass(p) -> processing.Processed:
|
||||
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
|
||||
pp = PostprocessBatchListArgs(images=[])
|
||||
self.p.scripts.postprocess_batch_list(p, pp)
|
||||
|
||||
processed = processing.Processed(p, [], p.seed, info="")
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
processed.all_prompts = p.prompts
|
||||
processed.all_seeds = p.seeds
|
||||
processed.all_subseeds = p.subseeds
|
||||
processed.all_negative_prompts = p.negative_prompts
|
||||
processed.images = pp.images
|
||||
processed.infotexts = [''] * self.num_requested()
|
||||
|
||||
transform = ToPILImage()
|
||||
for i, image in enumerate(processed.images):
|
||||
processed.images[i] = transform(image)
|
||||
|
||||
|
||||
self.p.scripts.postprocess(p, processed)
|
||||
|
||||
# generate grid if enabled
|
||||
if shared.opts.return_grid and len(processed.images) > 1:
|
||||
grid = image_grid(processed.images, len(processed.images))
|
||||
processed.images.insert(0, grid)
|
||||
processed.infotexts.insert(0, processed.infotexts[0])
|
||||
|
||||
return processed
|
||||
processing.process_images_inner = process_images_inner_bypass
|
||||
|
||||
|
|
@ -699,7 +718,7 @@ class World:
|
|||
)
|
||||
|
||||
with open(self.config_path, 'w+') as config_file:
|
||||
config_file.write(config.model_dump_json(indent=3))
|
||||
config_file.write(config.json(indent=3))
|
||||
logger.debug(f"config saved")
|
||||
|
||||
def ping_remotes(self, indiscriminate: bool = False):
|
||||
|
|
@ -749,6 +768,9 @@ class World:
|
|||
worker.set_state(State.IDLE, expect_cycle=True)
|
||||
else:
|
||||
msg = f"worker '{worker.label}' is unreachable"
|
||||
if worker.response.status_code is not None:
|
||||
msg += f" <{worker.response.status_code}>"
|
||||
|
||||
logger.info(msg)
|
||||
gradio.Warning("Distributed: "+msg)
|
||||
worker.set_state(State.UNAVAILABLE)
|
||||
|
|
|
|||
Loading…
Reference in New Issue