update thin-client mode
parent
622827aab5
commit
64bc137d65
|
|
@ -116,19 +116,14 @@ class DistributedScript(scripts.Script):
|
||||||
p.negative_prompts.extend(all_negative_prompts)
|
p.negative_prompts.extend(all_negative_prompts)
|
||||||
p.prompts.extend(all_prompts)
|
p.prompts.extend(all_prompts)
|
||||||
|
|
||||||
num_local = self.world.p.n_iter * self.world.p.batch_size + opts.return_grid
|
num_local = self.world.p.n_iter * self.world.p.batch_size + (opts.return_grid - self.world.thin_client_mode)
|
||||||
num_injected = len(pp.images) - self.world.p.batch_size
|
num_injected = len(pp.images) - self.world.p.batch_size
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
# modules.ui_common -> update_generation_info renders to html below gallery
|
# modules.ui_common -> update_generation_info renders to html below gallery
|
||||||
|
|
||||||
# TODO probably shouldn't be here
|
|
||||||
if self.world.thin_client_mode:
|
|
||||||
p.all_negative_prompts = pp.all_negative_prompts
|
|
||||||
|
|
||||||
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
|
gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
|
||||||
job.gallery_map.append(gallery_index) # so we know where to edit infotext
|
job.gallery_map.append(gallery_index) # so we know where to edit infotext
|
||||||
pp.images.append(image)
|
pp.images.append(image)
|
||||||
logger.debug(f"image {gallery_index + 1}/{self.world.num_gallery()}")
|
logger.debug(f"image {gallery_index + 1 + self.world.thin_client_mode}/{self.world.num_gallery()}")
|
||||||
|
|
||||||
def update_gallery(self, pp, p):
|
def update_gallery(self, pp, p):
|
||||||
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
|
"""adds all remotely generated images to the image gallery after waiting for all workers to finish"""
|
||||||
|
|
@ -328,13 +323,11 @@ class DistributedScript(scripts.Script):
|
||||||
p.batch_size = self.world.master_job().batch_size
|
p.batch_size = self.world.master_job().batch_size
|
||||||
self.master_start = time.time()
|
self.master_start = time.time()
|
||||||
|
|
||||||
# generate images assigned to local machine
|
|
||||||
# p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
|
||||||
self.runs_since_init += 1
|
self.runs_since_init += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||||
if p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
if not self.world.thin_client_mode and p.n_iter != kwargs['batch_number'] + 1: # skip if not the final batch
|
||||||
return
|
return
|
||||||
|
|
||||||
is_img2img = getattr(p, 'init_images', False)
|
is_img2img = getattr(p, 'init_images', False)
|
||||||
|
|
|
||||||
|
|
@ -609,6 +609,7 @@ class Worker:
|
||||||
self.full_url("memory"),
|
self.full_url("memory"),
|
||||||
timeout=3
|
timeout=3
|
||||||
)
|
)
|
||||||
|
self.response = response
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError as e:
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,9 @@ from .worker import Worker, State
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock
|
from modules.call_queue import wrap_queued_call, queue_lock
|
||||||
from modules import processing
|
from modules import processing
|
||||||
from modules import progress
|
from modules import progress
|
||||||
|
from modules.scripts import PostprocessBatchListArgs
|
||||||
|
from torchvision.transforms import ToPILImage
|
||||||
|
from modules.images import image_grid
|
||||||
|
|
||||||
|
|
||||||
class NotBenchmarked(Exception):
|
class NotBenchmarked(Exception):
|
||||||
|
|
@ -562,15 +565,31 @@ class World:
|
||||||
# save original process_images_inner for later so we can restore once we're done
|
# save original process_images_inner for later so we can restore once we're done
|
||||||
logger.debug(f"bypassing local generation completely")
|
logger.debug(f"bypassing local generation completely")
|
||||||
def process_images_inner_bypass(p) -> processing.Processed:
|
def process_images_inner_bypass(p) -> processing.Processed:
|
||||||
|
p.seeds, p.subseeds, p.negative_prompts, p.prompts = [], [], [], []
|
||||||
|
pp = PostprocessBatchListArgs(images=[])
|
||||||
|
self.p.scripts.postprocess_batch_list(p, pp)
|
||||||
|
|
||||||
processed = processing.Processed(p, [], p.seed, info="")
|
processed = processing.Processed(p, [], p.seed, info="")
|
||||||
processed.all_prompts = []
|
processed.all_prompts = p.prompts
|
||||||
processed.all_seeds = []
|
processed.all_seeds = p.seeds
|
||||||
processed.all_subseeds = []
|
processed.all_subseeds = p.subseeds
|
||||||
processed.all_negative_prompts = []
|
processed.all_negative_prompts = p.negative_prompts
|
||||||
processed.infotexts = []
|
processed.images = pp.images
|
||||||
processed.prompt = None
|
processed.infotexts = [''] * self.num_requested()
|
||||||
|
|
||||||
|
transform = ToPILImage()
|
||||||
|
for i, image in enumerate(processed.images):
|
||||||
|
processed.images[i] = transform(image)
|
||||||
|
|
||||||
|
|
||||||
self.p.scripts.postprocess(p, processed)
|
self.p.scripts.postprocess(p, processed)
|
||||||
|
|
||||||
|
# generate grid if enabled
|
||||||
|
if shared.opts.return_grid and len(processed.images) > 1:
|
||||||
|
grid = image_grid(processed.images, len(processed.images))
|
||||||
|
processed.images.insert(0, grid)
|
||||||
|
processed.infotexts.insert(0, processed.infotexts[0])
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
processing.process_images_inner = process_images_inner_bypass
|
processing.process_images_inner = process_images_inner_bypass
|
||||||
|
|
||||||
|
|
@ -699,7 +718,7 @@ class World:
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(self.config_path, 'w+') as config_file:
|
with open(self.config_path, 'w+') as config_file:
|
||||||
config_file.write(config.model_dump_json(indent=3))
|
config_file.write(config.json(indent=3))
|
||||||
logger.debug(f"config saved")
|
logger.debug(f"config saved")
|
||||||
|
|
||||||
def ping_remotes(self, indiscriminate: bool = False):
|
def ping_remotes(self, indiscriminate: bool = False):
|
||||||
|
|
@ -749,6 +768,9 @@ class World:
|
||||||
worker.set_state(State.IDLE, expect_cycle=True)
|
worker.set_state(State.IDLE, expect_cycle=True)
|
||||||
else:
|
else:
|
||||||
msg = f"worker '{worker.label}' is unreachable"
|
msg = f"worker '{worker.label}' is unreachable"
|
||||||
|
if worker.response.status_code is not None:
|
||||||
|
msg += f" <{worker.response.status_code}>"
|
||||||
|
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Warning("Distributed: "+msg)
|
gradio.Warning("Distributed: "+msg)
|
||||||
worker.set_state(State.UNAVAILABLE)
|
worker.set_state(State.UNAVAILABLE)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue