refactor of progress monitoring

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3696/head
Vladimir Mandic 2025-01-05 13:22:19 -05:00
parent 669799bfcc
commit 0114b59470
24 changed files with 177 additions and 84 deletions

View File

@ -22,6 +22,7 @@
"default-case":"off",
"no-await-in-loop":"off",
"no-bitwise":"off",
"no-continue":"off",
"no-confusing-arrow":"off",
"no-console":"off",
"no-empty":"off",

View File

@ -20,6 +20,8 @@
- add explicit detailer steps setting
- **SysInfo**:
- update to collected data and benchmarks
- **Progress**:
- refactored progress monitoring, job updates and live preview
- **Metadata**:
- improved metadata save and restore
- **Fixes**:

View File

@ -20,29 +20,36 @@ function checkPaused(state) {
function setProgress(res) {
const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate', 'control_generate'];
const progress = (res?.progress || 0);
let job = res?.job || '';
job = job.replace('txt2img', 'Generate').replace('img2img', 'Generate');
const perc = res && (progress > 0) ? `${Math.round(100.0 * progress)}%` : '';
let sec = res?.eta || 0;
const progress = res?.progress || 0;
const job = res?.job || '';
let perc = '';
let eta = '';
if (res?.paused) eta = 'Paused';
else if (res?.completed || (progress > 0.99)) eta = 'Finishing';
else if (sec === 0) eta = 'Starting';
if (job === 'VAE') perc = 'Decode';
else {
const min = Math.floor(sec / 60);
sec %= 60;
eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`;
perc = res && (progress > 0) && (progress < 1) ? `${Math.round(100.0 * progress)}% ` : '';
let sec = res?.eta || 0;
if (res?.paused) eta = 'Paused';
else if (res?.completed || (progress > 0.99)) eta = 'Finishing';
else if (sec === 0) eta = 'Start';
else {
const min = Math.floor(sec / 60);
sec %= 60;
eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`;
}
}
document.title = `SD.Next ${perc}`;
for (const elId of elements) {
const el = document.getElementById(elId);
if (el) {
el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate');
const jobLabel = (res ? `${job} ${perc}${eta}` : 'Generate').trim();
el.innerText = jobLabel;
if (!window.waitForUiReady) {
el.style.background = res && (progress > 0)
? `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})`
: 'var(--button-primary-background-fill)';
const gradient = perc !== '' ? perc : '100%';
if (jobLabel === 'Generate') el.style.background = 'var(--primary-500)';
else if (jobLabel.endsWith('Decode')) continue;
else if (jobLabel.endsWith('Start') || jobLabel.endsWith('Finishing')) el.style.background = 'var(--primary-800)';
else if (res && progress > 0 && progress < 1) el.style.background = `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${gradient}, var(--neutral-700) ${gradient})`;
else el.style.background = 'var(--primary-500)';
}
}
}

View File

@ -15,8 +15,8 @@ def wrap_queued_call(func):
return f
def wrap_gradio_gpu_call(func, extra_outputs=None):
name = func.__name__
def wrap_gradio_gpu_call(func, extra_outputs=None, name=None):
name = name or func.__name__
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":

View File

@ -71,6 +71,7 @@ def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint:
img.already_saved_as = name
size = os.path.getsize(name)
shared.log.debug(f'Save temp: image="{name}" width={img.width} height={img.height} size={size}')
shared.state.image_history += 1
params = ', '.join([f'{k}: {v}' for k, v in img.info.items()])
params = params[12:] if params.startswith('parameters: ') else params
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:

View File

@ -62,6 +62,7 @@ class History():
return -1
def add(self, latent, preview=None, info=None, ops=[]):
shared.state.latent_history += 1
if shared.opts.latent_history == 0:
return
if torch.is_tensor(latent):

View File

@ -29,6 +29,7 @@ def atomically_save_image():
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
while True:
image, filename, extension, params, exifinfo, filename_txt = save_queue.get()
shared.state.image_history += 1
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
file.write(exifinfo)
fn = filename + extension
@ -49,6 +50,7 @@ def atomically_save_image():
shared.log.info(f'Save: text="{filename_txt}" len={len(exifinfo)}')
except Exception as e:
shared.log.warning(f'Save failed: description={filename_txt} {e}')
# actual save
if image_format == 'PNG':
pnginfo_data = PngImagePlugin.PngInfo()
@ -79,6 +81,7 @@ def atomically_save_image():
errors.display(e, 'Image save')
size = os.path.getsize(fn) if os.path.exists(fn) else 0
shared.log.info(f'Save: image="{fn}" type={image_format} width={image.width} height={image.height} size={size}')
if shared.opts.save_log_fn != '' and len(exifinfo) > 0:
fn = os.path.join(paths.data_path, shared.opts.save_log_fn)
if not fn.endswith('.json'):

View File

@ -265,7 +265,7 @@ def create_ui():
auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio])
extract.click(
fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]),
fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[], name='LoRA'),
inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite],
outputs=[status]
)

View File

@ -280,19 +280,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
output_images = []
process_init(p)
if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and not shared.native:
if not shared.native and os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings:
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False)
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
p.scripts.process(p)
ema_scope_context = p.sd_model.ema_scope if not shared.native else nullcontext
shared.state.job_count = p.n_iter
if not shared.native:
shared.state.job_count = p.n_iter
with devices.inference_context(), ema_scope_context():
t0 = time.time()
if not hasattr(p, 'skip_init'):
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
debug(f'Processing inner: args={vars(p)}')
for n in range(p.n_iter):
# if hasattr(p, 'skip_processing'):
# continue
pag.apply(p)
debug(f'Processing inner: iteration={n+1}/{p.n_iter}')
p.iteration = n

View File

@ -15,6 +15,7 @@ from modules.api import helpers
debug_enabled = os.environ.get('SD_DIFFUSERS_DEBUG', None)
debug_log = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
disable_pbar = os.environ.get('SD_DISABLE_PBAR', None) is not None
def task_specific_kwargs(p, model):
@ -107,7 +108,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
apply_circular(p.tiling, model)
if hasattr(model, "set_progress_bar_config"):
model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba')
model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba', disable=disable_pbar)
args = {}
has_vae = hasattr(model, 'vae') or (hasattr(model, 'pipe') and hasattr(model.pipe, 'vae'))
if hasattr(model, 'pipe') and not hasattr(model, 'no_recurse'): # recurse

View File

@ -56,8 +56,9 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
latents = kwargs.get('latents', None)
if debug:
debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}')
order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1
shared.state.sampling_step = step // order
shared.state.step()
# order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1
# shared.state.sampling_step = step // order
if shared.state.interrupted or shared.state.skipped:
raise AssertionError('Interrupted...')
if shared.state.paused:

View File

@ -581,7 +581,7 @@ class StableDiffusionProcessingControl(StableDiffusionProcessingImg2Img):
else:
self.hr_upscale_to_x, self.hr_upscale_to_y = self.hr_resize_x, self.hr_resize_y
# hypertile_set(self, hr=True)
shared.state.job_count = 2 * self.n_iter
# shared.state.job_count = 2 * self.n_iter
shared.log.debug(f'Control hires: upscaler="{self.hr_upscaler}" scale={scale} fixed={not use_scale} size={self.hr_upscale_to_x}x{self.hr_upscale_to_y}')

View File

@ -6,7 +6,7 @@ import torch
import torchvision.transforms.functional as TF
from PIL import Image
from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats, extra_networks
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled, get_job_name
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
from modules.lora import networks
@ -53,8 +53,9 @@ def restore_state(p: processing.StableDiffusionProcessing):
def process_base(p: processing.StableDiffusionProcessing):
use_refiner_start = is_txt2img() and is_refiner_enabled(p) and not p.is_hr_pass and p.refiner_start > 0 and p.refiner_start < 1
use_denoise_start = not is_txt2img() and p.refiner_start > 0 and p.refiner_start < 1
txt2img = is_txt2img()
use_refiner_start = txt2img and is_refiner_enabled(p) and not p.is_hr_pass and p.refiner_start > 0 and p.refiner_start < 1
use_denoise_start = not txt2img and p.refiner_start > 0 and p.refiner_start < 1
shared.sd_model = update_pipeline(shared.sd_model, p)
update_sampler(p, shared.sd_model)
@ -76,7 +77,8 @@ def process_base(p: processing.StableDiffusionProcessing):
clip_skip=p.clip_skip,
desc='Base',
)
shared.state.sampling_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None)
base_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None)
shared.state.update(get_job_name(p, shared.sd_model), base_steps, 1)
if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1:
p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta
output = None
@ -172,7 +174,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
p.ops.append('upscale')
if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'):
save_intermediate(p, latents=output.images, suffix="-before-hires")
shared.state.job = 'Upscale'
shared.state.update('Upscale', 0, 1)
output.images = resize_hires(p, latents=output.images)
sd_hijack_hypertile.hypertile_set(p, hr=True)
@ -190,7 +192,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
shared.log.warning('HiRes skip: denoising=0')
p.hr_force = False
if p.hr_force:
shared.state.job_count = 2 * p.n_iter
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
if 'Upscale' in shared.sd_model.__class__.__name__ or 'Flux' in shared.sd_model.__class__.__name__ or 'Kandinsky' in shared.sd_model.__class__.__name__:
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
@ -217,8 +218,8 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
strength=strength,
desc='Hires',
)
shared.state.job = 'HiRes'
shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None)
hires_steps = hires_args.get('prior_num_inference_steps', None) or p.hr_second_pass_steps or hires_args.get('num_inference_steps', None)
shared.state.update(get_job_name(p, shared.sd_model), hires_steps, 1)
try:
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
sd_models.move_model(shared.sd_model, devices.device)
@ -255,8 +256,6 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
# optional refiner pass or decode
if is_refiner_enabled(p):
prev_job = shared.state.job
shared.state.job = 'Refine'
shared.state.job_count +=1
if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'):
save_intermediate(p, latents=output.images, suffix="-before-refiner")
if shared.opts.diffusers_move_base:
@ -306,7 +305,8 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
prompt_attention='fixed',
desc='Refiner',
)
shared.state.sampling_steps = refiner_args.get('prior_num_inference_steps', None) or p.steps or refiner_args.get('num_inference_steps', None)
refiner_steps = refiner_args.get('prior_num_inference_steps', None) or p.steps or refiner_args.get('num_inference_steps', None)
shared.state.update(get_job_name(p, shared.sd_refiner), refiner_steps, 1)
try:
if 'requires_aesthetics_score' in shared.sd_refiner.config: # sdxl-model needs false and sdxl-refiner needs true
shared.sd_refiner.register_to_config(requires_aesthetics_score = getattr(shared.sd_refiner, 'tokenizer', None) is None)

View File

@ -584,3 +584,26 @@ def update_sampler(p, sd_model, second_pass=False):
sampler_options.append('low order')
if len(sampler_options) > 0:
p.extra_generation_params['Sampler options'] = '/'.join(sampler_options)
def get_job_name(p, model):
if hasattr(model, 'pipe'):
model = model.pipe
if hasattr(p, 'xyz'):
return 'Ignore' # xyz grid handles its own jobs
if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE:
return 'Text'
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE:
if p.is_refiner_pass:
return 'Refiner'
elif p.is_hr_pass:
return 'Hires'
else:
return 'Image'
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING:
if p.detailer:
return 'Detailer'
else:
return 'Inpaint'
else:
return 'Unknown'

View File

@ -64,23 +64,16 @@ def progressapi(req: ProgressRequest):
queued = req.id_task in pending_tasks
completed = req.id_task in finished_tasks
paused = shared.state.paused
shared.state.job_count = max(shared.state.frame_count, shared.state.job_count, shared.state.job_no)
batch_x = max(shared.state.job_no, 0)
batch_y = max(shared.state.job_count, 1)
step_x = max(shared.state.sampling_step, 0)
step_y = max(shared.state.sampling_steps, 1)
current = step_y * batch_x + step_x
total = step_y * batch_y
while total < current:
total += step_y
progress = min(1, abs(current / total) if total > 0 else 0)
step = max(shared.state.sampling_step, 0)
steps = max(shared.state.sampling_steps, 1)
progress = round(min(1, abs(step / steps) if steps > 0 else 0), 2)
elapsed = time.time() - shared.state.time_start if shared.state.time_start is not None else 0
predicted = elapsed / progress if progress > 0 else None
eta = predicted - elapsed if predicted is not None else None
id_live_preview = req.id_live_preview
live_preview = None
updated = shared.state.set_current_image()
debug_log(f'Preview: job={shared.state.job} active={active} progress={current}/{total} step={shared.state.current_image_sampling_step}/{step_x}/{step_y} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}')
debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}')
if not active:
return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, debug=debug, textinfo="Queued..." if queued else "Waiting...")
if shared.opts.live_previews_enable and (shared.state.id_live_preview != id_live_preview) and (shared.state.current_image is not None):

View File

@ -1,10 +1,18 @@
import os
import sys
import time
import datetime
from modules.errors import log, display
debug_output = os.environ.get('SD_STATE_DEBUG', None)
class State:
job_history = []
task_history = []
image_history = 0
latent_history = 0
skipped = False
interrupted = False
paused = False
@ -14,7 +22,7 @@ class State:
frame_count = 0
total_jobs = 0
job_timestamp = '0'
sampling_step = 0
_sampling_step = 0
sampling_steps = 0
current_latent = None
current_noise_pred = None
@ -32,29 +40,48 @@ class State:
need_restart = False
server_start = time.time()
oom = False
debug_output = os.environ.get('SD_STATE_DEBUG', None)
def __str__(self) -> str:
return f'State: job={self.job} {self.job_no}/{self.job_count} step={self.sampling_step}/{self.sampling_steps} skipped={self.skipped} interrupted={self.interrupted} paused={self.paused} info={self.textinfo}'
status = ' '
status += 'skipped ' if self.skipped else ''
status += 'interrupted ' if self.interrupted else ''
status += 'paused ' if self.paused else ''
status += 'restart ' if self.need_restart else ''
status += 'oom ' if self.oom else ''
status += 'api ' if self.api else ''
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}' # pylint: disable=protected-access
return f'State: ts={self.job_timestamp} job={self.job} jobs={self.job_no+1}/{self.job_count}/{self.total_jobs} step={self.sampling_step}/{self.sampling_steps} preview={self.preview_job}/{self.id_live_preview}/{self.current_image_sampling_step} status="{status.strip()}" fn={fn}'
@property
def sampling_step(self):
return self._sampling_step
@sampling_step.setter
def sampling_step(self, value):
self._sampling_step = value
if debug_output:
log.trace(f'State step: {self}')
def skip(self):
log.debug('Requested skip')
log.debug('State: skip requested')
self.skipped = True
def interrupt(self):
log.debug('Requested interrupt')
log.debug('State: interrupt requested')
self.interrupted = True
def pause(self):
self.paused = not self.paused
log.debug(f'Requested {"pause" if self.paused else "continue"}')
log.debug(f'State: {"pause" if self.paused else "continue"} requested')
def nextjob(self):
import modules.devices
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
# self.sampling_step = 0
self.current_image_sampling_step = 0
if debug_output:
log.trace(f'State next: {self}')
modules.devices.torch_gc()
def dict(self):
@ -104,6 +131,7 @@ class State:
def begin(self, title="", api=None):
import modules.devices
self.job_history.append(title)
self.total_jobs += 1
self.current_image = None
self.current_image_sampling_step = 0
@ -115,19 +143,20 @@ class State:
self.interrupted = False
self.preview_job = -1
self.job = title
self.job_count = -1
self.frame_count = -1
self.job_count = 0
self.frame_count = 0
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.paused = False
self.sampling_step = 0
self._sampling_step = 0
self.sampling_steps = 0
self.skipped = False
self.textinfo = None
self.prediction_type = "epsilon"
self.api = api or self.api
self.time_start = time.time()
if self.debug_output:
log.debug(f'State begin: {self.job}')
if debug_output:
log.trace(f'State begin: {self}')
modules.devices.torch_gc()
def end(self, api=None):
@ -136,6 +165,8 @@ class State:
# fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
# log.debug(f'Access state.end: {fn}') # pylint: disable=protected-access
self.time_start = time.time()
if debug_output:
log.trace(f'State end: {self}')
self.job = ""
self.job_count = 0
self.job_no = 0
@ -147,6 +178,24 @@ class State:
self.api = api or self.api
modules.devices.torch_gc()
def step(self, step:int=1):
self.sampling_step += step
def update(self, job:str, steps:int=0, jobs:int=0):
self.task_history.append(job)
# self._sampling_step = 0
if job == 'Ignore':
return
elif job == 'Grid':
self.sampling_steps = steps
self.job_count = jobs
else:
self.sampling_steps += steps * jobs
self.job_count += jobs
self.job = job
if debug_output:
log.trace(f'State update: {self} steps={steps} jobs={jobs}')
def set_current_image(self):
if self.job == 'VAE' or self.job == 'Upscale': # avoid generating preview while vae is running
return False

View File

@ -193,7 +193,7 @@ def create_ui():
override_settings,
]
img2img_dict = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', ''], name='Image'),
_js="submit_img2img",
inputs= img2img_args + img2img_script_inputs,
outputs=[

View File

@ -290,7 +290,7 @@ def create_ui():
beta_apply_preset.click(fn=load_presets, inputs=[beta_preset, beta_preset_lambda], outputs=[beta_base, beta_in_blocks, beta_mid_block, beta_out_blocks, tabs])
modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)], name='Models'),
_js='modelmerger',
inputs=[
dummy_component,

View File

@ -129,7 +129,7 @@ def create_ui():
)
submit.click(
_js="submit_postprocessing",
fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, '']),
fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, ''], name='Postprocess'),
inputs=[
tab_index,
extras_image,

View File

@ -77,7 +77,7 @@ def create_ui():
override_settings,
]
txt2img_dict = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', ''], name='Text'),
_js="submit_txt2img",
inputs=txt2img_args + txt2img_script_inputs,
outputs=[

View File

@ -253,6 +253,7 @@ class Script(scripts.Script):
ys = fix_axis_seeds(y_opt, ys)
zs = fix_axis_seeds(z_opt, zs)
total_jobs = len(xs) * len(ys) * len(zs)
if x_opt.label == 'Steps':
total_steps = sum(xs) * len(ys) * len(zs)
elif y_opt.label == 'Steps':
@ -260,7 +261,7 @@ class Script(scripts.Script):
elif z_opt.label == 'Steps':
total_steps = sum(zs) * len(xs) * len(ys)
else:
total_steps = p.steps * len(xs) * len(ys) * len(zs)
total_steps = p.steps * total_jobs
if isinstance(p, processing.StableDiffusionProcessingTxt2Img) and p.enable_hr:
if x_opt.label == "Hires steps":
total_steps += sum(xs) * len(ys) * len(zs)
@ -269,10 +270,12 @@ class Script(scripts.Script):
elif z_opt.label == "Hires steps":
total_steps += sum(zs) * len(xs) * len(ys)
elif p.hr_second_pass_steps:
total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
total_steps += p.hr_second_pass_steps * total_jobs
else:
total_steps *= 2
total_steps *= p.n_iter
shared.state.update('Grid', total_steps, total_jobs * p.n_iter)
image_cell_count = p.n_iter * p.batch_size
shared.log.info(f"XYZ grid: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} shape={len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}")
AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])

View File

@ -10,7 +10,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
z_texts = [[images.GridAnnotation(z)] for z in z_labels]
list_size = (len(xs) * len(ys) * len(zs))
processed_result = None
shared.state.job_count = list_size * p.n_iter
t0 = time.time()
i = 0
@ -22,7 +22,6 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
def index(ix, iy, iz):
return ix + iy * len(xs) + iz * len(xs) * len(ys)
shared.state.job = 'Grid'
p0 = time.time()
processed: processing.Processed = cell(x, y, z, ix, iy, iz)
p1 = time.time()
@ -63,7 +62,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
cell_mode = processed_result.images[0].mode
cell_size = processed_result.images[0].size
processed_result.images[idx] = Image.new(cell_mode, cell_size)
return
shared.state.nextjob()
if first_axes_processed == 'x':
for ix, x in enumerate(xs):
@ -129,5 +128,6 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
processed_result.infotexts.insert(0, processed_result.infotexts[0])
t2 = time.time()
shared.log.info(f'XYZ grid complete: images={list_size} size={grid.size if grid is not None else None} time={t1-t0:.2f} save={t2-t1:.2f}')
shared.log.info(f'XYZ grid complete: images={list_size} results={len(processed_result.images)}size={grid.size if grid is not None else None} time={t1-t0:.2f} save={t2-t1:.2f}')
p.skip_processing = True
return processed_result

View File

@ -18,7 +18,7 @@ import modules.ui_symbols as symbols
active = False
cache = None
xyz_results_cache = None
debug = shared.log.trace if os.environ.get('SD_XYZ_DEBUG', None) is not None else lambda *args, **kwargs: None
@ -188,8 +188,8 @@ class Script(scripts.Script):
include_time, include_text, margin_size,
create_video, video_type, video_duration, video_loop, video_pad, video_interpolate,
): # pylint: disable=W0221
global active, cache # pylint: disable=W0603
cache = None
global active, xyz_results_cache # pylint: disable=W0603
xyz_results_cache = None
if not enabled or active:
return
active = True
@ -266,6 +266,7 @@ class Script(scripts.Script):
ys = fix_axis_seeds(y_opt, ys)
zs = fix_axis_seeds(z_opt, zs)
total_jobs = len(xs) * len(ys) * len(zs)
if x_opt.label == 'Steps':
total_steps = sum(xs) * len(ys) * len(zs)
elif y_opt.label == 'Steps':
@ -273,8 +274,8 @@ class Script(scripts.Script):
elif z_opt.label == 'Steps':
total_steps = sum(zs) * len(xs) * len(ys)
else:
total_steps = p.steps * len(xs) * len(ys) * len(zs)
if isinstance(p, processing.StableDiffusionProcessingTxt2Img) and p.enable_hr:
total_steps = p.steps * total_jobs
if p.enable_hr:
if x_opt.label == "Hires steps":
total_steps += sum(xs) * len(ys) * len(zs)
elif y_opt.label == "Hires steps":
@ -282,10 +283,16 @@ class Script(scripts.Script):
elif z_opt.label == "Hires steps":
total_steps += sum(zs) * len(xs) * len(ys)
elif p.hr_second_pass_steps:
total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
total_steps += p.hr_second_pass_steps * total_jobs
else:
total_steps *= 2
if p.detailer:
total_steps += shared.opts.detailer_steps * total_jobs
total_steps *= p.n_iter
total_jobs *= p.n_iter
shared.state.update('Grid', total_steps, total_jobs)
image_cell_count = p.n_iter * p.batch_size
shared.log.info(f"XYZ grid start: images={len(xs)*len(ys)*len(zs)*image_cell_count} grid={len(zs)} shape={len(xs)}x{len(ys)} cells={len(zs)} steps={total_steps}")
AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
@ -360,7 +367,7 @@ class Script(scripts.Script):
return processed
with SharedSettingsStackHelper():
processed = draw_xyz_grid(
processed: processing.Processed = draw_xyz_grid(
p,
xs=xs,
ys=ys,
@ -418,19 +425,15 @@ class Script(scripts.Script):
p.do_not_save_samples = True
p.disable_extra_networks = True
active = False
cache = processed
xyz_results_cache = processed
return processed
def process_images(self, p, *args): # pylint: disable=W0221, W0613
if hasattr(cache, 'used'):
cache.images.clear()
cache.used = False
elif cache is not None and len(cache.images) > 0:
cache.used = True
if xyz_results_cache is not None and len(xyz_results_cache.images) > 0:
p.restore_faces = False
p.detailer = False
p.color_corrections = None
p.scripts = None
return cache
# p.scripts = None
return xyz_results_cache
return None

View File

@ -1,6 +1,7 @@
import io
import os
import sys
import time
import glob
import signal
import asyncio
@ -156,6 +157,7 @@ def initialize():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(_sig, _frame):
log.trace(f'State history: uptime={round(time.time() - shared.state.server_start)} jobs={len(shared.state.job_history)} tasks={len(shared.state.task_history)} latents={shared.state.latent_history} images={shared.state.image_history}')
log.info('Exiting')
try:
for f in glob.glob("*.lock"):
@ -176,9 +178,9 @@ def load_model():
thread_model.start()
thread_refiner = Thread(target=lambda: shared.sd_refiner)
thread_refiner.start()
shared.state.end()
thread_model.join()
thread_refiner.join()
shared.state.end()
timer.startup.record("checkpoint")
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='model')), call=False)
shared.opts.onchange("sd_model_refiner", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='refiner')), call=False)