cleanup/refactor state history

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4204/head
Vladimir Mandic 2025-09-12 16:12:43 -04:00
parent a8b850adf4
commit 175e9cbe29
41 changed files with 172 additions and 171 deletions

View File

@ -1,8 +1,7 @@
from typing import Optional, List, Union from typing import Optional, List
from threading import Lock from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from modules import errors, shared, processing_helpers from modules import errors, shared, processing_helpers
from modules.processing import StableDiffusionProcessingControl
from modules.api import models, helpers from modules.api import models, helpers
from modules.control import run from modules.control import run
@ -180,7 +179,7 @@ class APIControl():
# run # run
with self.queue_lock: with self.queue_lock:
shared.state.begin('API-CTL', api=True) jobid = shared.state.begin('API-CTL', api=True)
output_images = [] output_images = []
output_processed = [] output_processed = []
output_info = '' output_info = ''
@ -198,7 +197,7 @@ class APIControl():
output_info += item output_info += item
else: else:
pass pass
shared.state.end(api=False) shared.state.end(jobid)
# return # return
b64images = list(map(helpers.encode_pil_to_base64, output_images)) if send_images else [] b64images = list(map(helpers.encode_pil_to_base64, output_images)) if send_images else []

View File

@ -109,7 +109,7 @@ class APIGenerate():
p.outpath_samples = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples p.outpath_samples = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples
for key, value in getattr(txt2imgreq, "extra", {}).items(): for key, value in getattr(txt2imgreq, "extra", {}).items():
setattr(p, key, value) setattr(p, key, value)
shared.state.begin('API TXT', api=True) jobid = shared.state.begin('API-TXT', api=True)
script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
p.script_args = tuple(script_args) # Need to pass args as tuple here p.script_args = tuple(script_args) # Need to pass args as tuple here
if selectable_scripts is not None: if selectable_scripts is not None:
@ -118,7 +118,7 @@ class APIGenerate():
processed = process_images(p) processed = process_images(p)
processed = scripts_manager.scripts_txt2img.after(p, processed, *script_args) processed = scripts_manager.scripts_txt2img.after(p, processed, *script_args)
p.close() p.close()
shared.state.end(api=False) shared.state.end(jobid)
if processed is None or processed.images is None or len(processed.images) == 0: if processed is None or processed.images is None or len(processed.images) == 0:
b64images = [] b64images = []
else: else:
@ -161,7 +161,7 @@ class APIGenerate():
p.outpath_samples = shared.opts.outdir_img2img_samples p.outpath_samples = shared.opts.outdir_img2img_samples
for key, value in getattr(img2imgreq, "extra", {}).items(): for key, value in getattr(img2imgreq, "extra", {}).items():
setattr(p, key, value) setattr(p, key, value)
shared.state.begin('API-IMG', api=True) jobid = shared.state.begin('API-IMG', api=True)
script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
p.script_args = tuple(script_args) # Need to pass args as tuple here p.script_args = tuple(script_args) # Need to pass args as tuple here
if selectable_scripts is not None: if selectable_scripts is not None:
@ -170,7 +170,7 @@ class APIGenerate():
processed = process_images(p) processed = process_images(p)
processed = scripts_manager.scripts_img2img.after(p, processed, *script_args) processed = scripts_manager.scripts_img2img.after(p, processed, *script_args)
p.close() p.close()
shared.state.end(api=False) shared.state.end(jobid)
if processed is None or processed.images is None or len(processed.images) == 0: if processed is None or processed.images is None or len(processed.images) == 0:
b64images = [] b64images = []
else: else:

View File

@ -339,9 +339,8 @@ class ResProgress(BaseModel):
class ResHistory(BaseModel): class ResHistory(BaseModel):
id: Union[int, str, None] = Field(title="ID", description="Task ID") id: Union[int, str, None] = Field(title="ID", description="Task ID")
job: str = Field(title="Job", description="Job name") job: str = Field(title="Job", description="Job name")
op: str = Field(title="Operation", description="Operation name") op: str = Field(title="Operation", description="Job state")
start: Union[float, None] = Field(title="Start", description="Start time") timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp")
end: Union[float, None] = Field(title="End", description="End time")
outputs: List[str] = Field(title="Outputs", description="List of filenames") outputs: List[str] = Field(title="Outputs", description="List of filenames")
class ResStatus(BaseModel): class ResStatus(BaseModel):

View File

@ -77,10 +77,10 @@ class APIProcess():
for k, v in req.params.items(): for k, v in req.params.items():
if k not in processors.config[processor.processor_id]['params']: if k not in processors.config[processor.processor_id]['params']:
return JSONResponse(status_code=400, content={"error": f"Processor invalid parameter: id={req.model} {k}={v}"}) return JSONResponse(status_code=400, content={"error": f"Processor invalid parameter: id={req.model} {k}={v}"})
shared.state.begin('API-PRE', api=True) jobid = shared.state.begin('API-PRE', api=True)
processed = processor(image, local_config=req.params) processed = processor(image, local_config=req.params)
image = encode_pil_to_base64(processed) image = encode_pil_to_base64(processed)
shared.state.end(api=False) shared.state.end(jobid)
return ResPreprocess(model=processor.processor_id, image=image) return ResPreprocess(model=processor.processor_id, image=image)
def get_mask(self): def get_mask(self):
@ -103,10 +103,10 @@ class APIProcess():
return JSONResponse(status_code=400, content={"error": f"Mask invalid parameter: {k}={v}"}) return JSONResponse(status_code=400, content={"error": f"Mask invalid parameter: {k}={v}"})
else: else:
setattr(masking.opts, k, v) setattr(masking.opts, k, v)
shared.state.begin('API-MASK', api=True) jobid = shared.state.begin('API-MASK', api=True)
with self.queue_lock: with self.queue_lock:
processed = masking.run_mask(input_image=image, input_mask=mask, return_type=req.type) processed = masking.run_mask(input_image=image, input_mask=mask, return_type=req.type)
shared.state.end(api=False) shared.state.end(jobid)
if processed is None: if processed is None:
return JSONResponse(status_code=400, content={"error": "Mask is none"}) return JSONResponse(status_code=400, content={"error": "Mask is none"})
image = encode_pil_to_base64(processed) image = encode_pil_to_base64(processed)
@ -115,7 +115,7 @@ class APIProcess():
def post_detect(self, req: ReqFace): def post_detect(self, req: ReqFace):
from modules.shared import yolo # pylint: disable=no-name-in-module from modules.shared import yolo # pylint: disable=no-name-in-module
image = decode_base64_to_image(req.image) image = decode_base64_to_image(req.image)
shared.state.begin('API-FACE', api=True) jobid = shared.state.begin('API-FACE', api=True)
images = [] images = []
scores = [] scores = []
classes = [] classes = []
@ -129,7 +129,7 @@ class APIProcess():
classes.append(item.cls) classes.append(item.cls)
labels.append(item.label) labels.append(item.label)
boxes.append(item.box) boxes.append(item.box)
shared.state.end(api=False) shared.state.end(jobid)
return ResFace(classes=classes, labels=labels, scores=scores, boxes=boxes, images=images) return ResFace(classes=classes, labels=labels, scores=scores, boxes=boxes, images=images)
def post_prompt_enhance(self, req: models.ReqPromptEnhance): def post_prompt_enhance(self, req: models.ReqPromptEnhance):

View File

@ -50,7 +50,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False, name=None):
task_id = args[0] task_id = args[0]
else: else:
task_id = 0 task_id = 0
shared.state.begin(job_name, task_id=task_id) jobid = shared.state.begin(job_name, task_id=task_id)
try: try:
if shared.cmd_opts.profile: if shared.cmd_opts.profile:
pr = cProfile.Profile() pr = cProfile.Profile()
@ -70,7 +70,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False, name=None):
if extra_outputs_array is None: if extra_outputs_array is None:
extra_outputs_array = [None, ''] extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"] res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
shared.state.end() shared.state.end(jobid)
if not add_stats: if not add_stats:
return tuple(res) return tuple(res)
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t

View File

@ -62,7 +62,7 @@ def download_civit_preview(model_path: str, preview_url: str):
block_size = 16384 # 16KB blocks block_size = 16384 # 16KB blocks
written = 0 written = 0
img = None img = None
shared.state.begin('CivitAI') jobid = shared.state.begin('Download')
if pbar is None: if pbar is None:
pbar = p.Progress(p.TextColumn('[cyan]Download'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), p.TextColumn('[yellow]{task.description}'), console=shared.console) pbar = p.Progress(p.TextColumn('[cyan]Download'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), p.TextColumn('[yellow]{task.description}'), console=shared.console)
try: try:
@ -82,8 +82,9 @@ def download_civit_preview(model_path: str, preview_url: str):
img = Image.open(preview_file) img = Image.open(preview_file)
except Exception as e: except Exception as e:
shared.log.error(f'CivitAI download error: url={preview_url} file="{preview_file}" written={written} {e}') shared.log.error(f'CivitAI download error: url={preview_url} file="{preview_file}" written={written} {e}')
shared.state.end(jobid)
return 500, '', str(e) return 500, '', str(e)
shared.state.end() shared.state.end(jobid)
if img is None: if img is None:
return 500, '', 'image is none' return 500, '', 'image is none'
shared.log.info(f'CivitAI download: url={preview_url} file="{preview_file}" size={total_size} image={img.size}') shared.log.info(f'CivitAI download: url={preview_url} file="{preview_file}" size={total_size} image={img.size}')
@ -138,7 +139,7 @@ def download_civit_model_thread(model_name: str, model_url: str, model_path: str
res += f' size={round((starting_pos + total_size)/1024/1024, 2)}Mb' res += f' size={round((starting_pos + total_size)/1024/1024, 2)}Mb'
shared.log.info(res) shared.log.info(res)
shared.state.begin('CivitAI') jobid = shared.state.begin('Download')
block_size = 16384 # 16KB blocks block_size = 16384 # 16KB blocks
written = starting_pos written = starting_pos
global pbar # pylint: disable=global-statement global pbar # pylint: disable=global-statement
@ -171,7 +172,7 @@ def download_civit_model_thread(model_name: str, model_url: str, model_path: str
elif os.path.exists(temp_file): elif os.path.exists(temp_file):
shared.log.debug(f'Model download complete: temp="{temp_file}" path="{model_file}"') shared.log.debug(f'Model download complete: temp="{temp_file}" path="{model_file}"')
os.rename(temp_file, model_file) os.rename(temp_file, model_file)
shared.state.end() shared.state.end(jobid)
if os.path.exists(model_file): if os.path.exists(model_file):
return model_file return model_file
else: else:

View File

@ -30,12 +30,12 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
shared.state.begin('Merge') jobid = shared.state.begin('Merge')
t0 = time.time() t0 = time.time()
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message
shared.state.end() shared.state.end(jobid)
return [*[gr.update() for _ in range(4)], message] return [*[gr.update() for _ in range(4)], message]
kwargs["models"] = { kwargs["models"] = {
@ -177,7 +177,7 @@ def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
if created_model: if created_model:
created_model.calculate_shorthash() created_model.calculate_shorthash()
devices.torch_gc(force=True, reason='merge') devices.torch_gc(force=True, reason='merge')
shared.state.end() shared.state.end(jobid)
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model saved to {output_modelname}"] return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model saved to {output_modelname}"]
@ -209,7 +209,7 @@ def run_model_modules(model_type:str, model_name:str, custom_name:str,
yield msg("input model not found", err=True) yield msg("input model not found", err=True)
return return
fn = checkpoint_info.filename fn = checkpoint_info.filename
shared.state.begin('Merge') jobid = shared.state.begin('Merge')
yield msg("modules merge starting") yield msg("modules merge starting")
yield msg("unload current model") yield msg("unload current model")
sd_models.unload_model_weights(op='model') sd_models.unload_model_weights(op='model')
@ -257,4 +257,4 @@ def run_model_modules(model_type:str, model_name:str, custom_name:str,
sd_models.set_diffuser_options(shared.sd_model, offload=False) sd_models.set_diffuser_options(shared.sd_model, offload=False)
sd_models.set_diffuser_offload(shared.sd_model) sd_models.set_diffuser_offload(shared.sd_model)
yield msg("pipeline loaded") yield msg("pipeline loaded")
shared.state.end() shared.state.end(jobid)

View File

@ -1,4 +1,3 @@
import copy
import hashlib import hashlib
import os.path import os.path
from rich import progress, errors from rich import progress, errors
@ -75,8 +74,7 @@ def sha256(filename, title, use_addnet_hash=False):
return None return None
if not os.path.isfile(filename): if not os.path.isfile(filename):
return None return None
orig_state = copy.deepcopy(shared.state) jobid = shared.state.begin("Hash")
shared.state.begin("Hash")
if use_addnet_hash: if use_addnet_hash:
if progress_ok: if progress_ok:
try: try:
@ -94,8 +92,7 @@ def sha256(filename, title, use_addnet_hash=False):
"mtime": os.path.getmtime(filename), "mtime": os.path.getmtime(filename),
"sha256": sha256_value "sha256": sha256_value
} }
shared.state.end() shared.state.end(jobid)
shared.state = orig_state
dump_cache() dump_cache()
return sha256_value return sha256_value

View File

@ -46,6 +46,7 @@ def atomically_save_image():
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
while True: while True:
image, filename, extension, params, exifinfo, filename_txt, is_grid = save_queue.get() image, filename, extension, params, exifinfo, filename_txt, is_grid = save_queue.get()
jobid = shared.state.begin('Save')
shared.state.image_history += 1 shared.state.image_history += 1
if len(exifinfo) > 2: if len(exifinfo) > 2:
with open(paths.params_path, "w", encoding="utf8") as file: with open(paths.params_path, "w", encoding="utf8") as file:
@ -126,6 +127,8 @@ def atomically_save_image():
entries.append(entry) entries.append(entry)
shared.writefile(entries, fn, mode='w', silent=True) shared.writefile(entries, fn, mode='w', silent=True)
shared.log.info(f'Save: json="{fn}" records={len(entries)}') shared.log.info(f'Save: json="{fn}" records={len(entries)}')
shared.state.outputs(filename)
shared.state.end(jobid)
save_queue.task_done() save_queue.task_done()
@ -206,7 +209,6 @@ def save_image(image,
exifinfo += params.pnginfo.get(pnginfo_section_name, '') exifinfo += params.pnginfo.get(pnginfo_section_name, '')
filename, extension = os.path.splitext(params.filename) filename, extension = os.path.splitext(params.filename)
filename_txt = f"{filename}.txt" if shared.opts.save_txt and len(exifinfo) > 0 else None filename_txt = f"{filename}.txt" if shared.opts.save_txt and len(exifinfo) > 0 else None
shared.state.outputs(params.filename)
save_queue.put((params.image, filename, extension, params, exifinfo, filename_txt, grid)) # actual save is executed in a thread that polls data from queue save_queue.put((params.image, filename, extension, params, exifinfo, filename_txt, grid)) # actual save is executed in a thread that polls data from queue
save_queue.join() save_queue.join()
if not hasattr(params.image, 'already_saved_as'): if not hasattr(params.image, 'already_saved_as'):

View File

@ -135,7 +135,7 @@ def interrogate(image, mode, caption=None):
def interrogate_image(image, clip_model, blip_model, mode): def interrogate_image(image, clip_model, blip_model, mode):
shared.state.begin('Interrogate') jobid = shared.state.begin('Interrogate')
try: try:
if shared.sd_loaded: if shared.sd_loaded:
from modules.sd_models import apply_balanced_offload # prevent circular import from modules.sd_models import apply_balanced_offload # prevent circular import
@ -148,7 +148,7 @@ def interrogate_image(image, clip_model, blip_model, mode):
prompt = f"Exception {type(e)}" prompt = f"Exception {type(e)}"
shared.log.error(f'Interrogate: {e}') shared.log.error(f'Interrogate: {e}')
errors.display(e, 'Interrogate') errors.display(e, 'Interrogate')
shared.state.end() shared.state.end(jobid)
return prompt return prompt
@ -164,7 +164,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
if len(files) == 0: if len(files) == 0:
shared.log.warning('Interrogate batch: type=clip no images') shared.log.warning('Interrogate batch: type=clip no images')
return '' return ''
shared.state.begin('Interrogate batch') jobid = shared.state.begin('Interrogate batch')
prompts = [] prompts = []
load_interrogator(clip_model, blip_model) load_interrogator(clip_model, blip_model)
@ -191,7 +191,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
writer.close() writer.close()
ci.config.quiet = False ci.config.quiet = False
unload_clip_model() unload_clip_model()
shared.state.end() shared.state.end(jobid)
return '\n\n'.join(prompts) return '\n\n'.join(prompts)

View File

@ -605,8 +605,7 @@ def sa2(question: str, image: Image.Image, repo: str = None):
def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False): def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False):
global quant_args # pylint: disable=global-statement global quant_args # pylint: disable=global-statement
if not quiet: jobid = shared.state.begin('Interrogate')
shared.state.begin('Interrogate')
t0 = time.time() t0 = time.time()
quant_args = model_quant.create_config(module='LLM') quant_args = model_quant.create_config(module='LLM')
model_name = model_name or shared.opts.interrogate_vlm_model model_name = model_name or shared.opts.interrogate_vlm_model
@ -691,7 +690,7 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
t1 = time.time() t1 = time.time()
if not quiet: if not quiet:
shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}') shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
shared.state.end() shared.state.end(jobid)
return answer return answer
@ -725,7 +724,7 @@ def batch(model_name, system_prompt, batch_files, batch_folder, batch_str, quest
if len(files) == 0: if len(files) == 0:
shared.log.warning('Interrogate batch: type=vlm no images') shared.log.warning('Interrogate batch: type=vlm no images')
return '' return ''
shared.state.begin('Interrogate batch') jobid = shared.state.begin('Interrogate batch')
prompts = [] prompts = []
if write: if write:
mode = 'w' if not append else 'a' mode = 'w' if not append else 'a'
@ -751,5 +750,5 @@ def batch(model_name, system_prompt, batch_files, batch_folder, batch_str, quest
if write: if write:
writer.close() writer.close()
shared.opts.interrogate_offload = orig_offload shared.opts.interrogate_offload = orig_offload
shared.state.end() shared.state.end(jobid)
return '\n\n'.join(prompts) return '\n\n'.join(prompts)

View File

@ -180,11 +180,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
if load_method == 'diffusers': if load_method == 'diffusers':
has_changed = False # diffusers handles its own loading has_changed = False # diffusers handles its own loading
if len(exclude) == 0: if len(exclude) == 0:
job = shared.state.job jobid = shared.state.begin('LoRA')
shared.state.job = 'LoRA'
lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load only on first call lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load only on first call
sd_models.set_diffuser_offload(shared.sd_model, op="model") sd_models.set_diffuser_offload(shared.sd_model, op="model")
shared.state.job = job shared.state.end(jobid)
elif load_method == 'nunchaku': elif load_method == 'nunchaku':
from modules.lora import lora_nunchaku from modules.lora import lora_nunchaku
has_changed = lora_nunchaku.load_nunchaku(names, unet_multipliers) has_changed = lora_nunchaku.load_nunchaku(names, unet_multipliers)
@ -192,16 +191,15 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load lora_load.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load
has_changed = self.changed(requested, include, exclude) has_changed = self.changed(requested, include, exclude)
if has_changed: if has_changed:
job = shared.state.job jobid = shared.state.begin('LoRA')
shared.state.job = 'LoRA'
if len(l.previously_loaded_networks) > 0: if len(l.previously_loaded_networks) > 0:
shared.log.info(f'Network unload: type=LoRA apply={[n.name for n in l.previously_loaded_networks]} mode={"fuse" if shared.opts.lora_fuse_diffusers else "backup"}') shared.log.info(f'Network unload: type=LoRA apply={[n.name for n in l.previously_loaded_networks]} mode={"fuse" if shared.opts.lora_fuse_diffusers else "backup"}')
networks.network_deactivate(include, exclude) networks.network_deactivate(include, exclude)
networks.network_activate(include, exclude) networks.network_activate(include, exclude)
if len(exclude) > 0: # only update on last activation if len(exclude) > 0: # only update on last activation
l.previously_loaded_networks = l.loaded_networks.copy() l.previously_loaded_networks = l.loaded_networks.copy()
shared.state.job = job
debug_log(f'Network load: type=LoRA previous={[n.name for n in l.previously_loaded_networks]} current={[n.name for n in l.loaded_networks]} changed') debug_log(f'Network load: type=LoRA previous={[n.name for n in l.previously_loaded_networks]} current={[n.name for n in l.loaded_networks]} changed')
shared.state.end(jobid)
if len(l.loaded_networks) > 0 and (len(networks.applied_layers) > 0 or load_method=='diffusers' or load_method=='nunchaku') and step == 0: if len(l.loaded_networks) > 0 and (len(networks.applied_layers) > 0 or load_method=='diffusers' or load_method=='nunchaku') and step == 0:
infotext(p) infotext(p)

View File

@ -135,7 +135,7 @@ def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
maxrank = int(maxrank) maxrank = int(maxrank)
rank_ratio = 1 if not auto_rank else rank_ratio rank_ratio = 1 if not auto_rank else rank_ratio
shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"') shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"')
shared.state.begin('LoRA extract') jobid = shared.state.begin('LoRA extract')
with rp.Progress(rp.TextColumn('[cyan]LoRA extract'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console) as progress: with rp.Progress(rp.TextColumn('[cyan]LoRA extract'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console) as progress:
@ -226,7 +226,7 @@ def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
yield msg yield msg
return return
shared.state.end() shared.state.end(jobid)
meta = make_meta(fn, maxrank, rank_ratio) meta = make_meta(fn, maxrank, rank_ratio)
shared.log.debug(f'LoRA metadata: {meta}') shared.log.debug(f'LoRA metadata: {meta}')
try: try:

View File

@ -28,7 +28,6 @@ def load_model(engine: str, model: str):
def load_upsample(upsample_pipe, upsample_repo_id): def load_upsample(upsample_pipe, upsample_repo_id):
if upsample_pipe is None: if upsample_pipe is None:
t0 = time.time() t0 = time.time()
shared.state.begin('Load')
from diffusers.pipelines.ltx.pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline from diffusers.pipelines.ltx.pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
shared.log.info(f'Video load: cls={LTXLatentUpsamplePipeline.__class__.__name__} repo="{upsample_repo_id}"') shared.log.info(f'Video load: cls={LTXLatentUpsamplePipeline.__class__.__name__} repo="{upsample_repo_id}"')
upsample_pipe = LTXLatentUpsamplePipeline.from_pretrained( upsample_pipe = LTXLatentUpsamplePipeline.from_pretrained(
@ -37,7 +36,6 @@ def load_upsample(upsample_pipe, upsample_repo_id):
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype, torch_dtype=devices.dtype,
) )
shared.state.end()
t1 = time.time() t1 = time.time()
timer.process.add('load', t1 - t0) timer.process.add('load', t1 - t0)
return upsample_pipe return upsample_pipe

View File

@ -53,7 +53,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
if hub_id is None or len(hub_id) == 0: if hub_id is None or len(hub_id) == 0:
return None return None
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
shared.state.begin('HuggingFace') jobid = shared.state.begin('Download')
if hub_id.startswith('huggingface/'): if hub_id.startswith('huggingface/'):
hub_id = hub_id.replace('huggingface/', '') hub_id = hub_id.replace('huggingface/', '')
if download_config is None: if download_config is None:
@ -98,9 +98,11 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
debug(f'Diffusers download error: id="{hub_id}" {e}') debug(f'Diffusers download error: id="{hub_id}" {e}')
if 'gated' in str(e): if 'gated' in str(e):
shared.log.error(f'Diffusers download error: id="{hub_id}" model access requires login') shared.log.error(f'Diffusers download error: id="{hub_id}" model access requires login')
shared.state.end(jobid)
return None return None
if pipeline_dir is None: if pipeline_dir is None:
shared.log.error(f'Diffusers download error: id="{hub_id}" {err}') shared.log.error(f'Diffusers download error: id="{hub_id}" {err}')
shared.state.end(jobid)
return None return None
try: try:
model_info_dict = hf.model_info(hub_id).cardData if pipeline_dir is not None else None model_info_dict = hf.model_info(hub_id).cardData if pipeline_dir is not None else None
@ -113,7 +115,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
f.write("True") f.write("True")
if pipeline_dir is not None: if pipeline_dir is not None:
shared.writefile(model_info_dict, os.path.join(pipeline_dir, "model_info.json")) shared.writefile(model_info_dict, os.path.join(pipeline_dir, "model_info.json"))
shared.state.end() shared.state.end(jobid)
return pipeline_dir return pipeline_dir

View File

@ -332,7 +332,6 @@ class YoloRestorer(Detailer):
mask_all = [] mask_all = []
p.state = '' p.state = ''
prev_state = shared.state.job
pc = copy(p) pc = copy(p)
pc.ops.append('detailer') pc.ops.append('detailer')
@ -348,8 +347,9 @@ class YoloRestorer(Detailer):
pc.image_mask = [item.mask] pc.image_mask = [item.mask]
pc.overlay_images = [] pc.overlay_images = []
pc.recursion = True pc.recursion = True
shared.state.job = 'Detailer' jobid = shared.state.begin('Detailer')
pp = processing.process_images_inner(pc) pp = processing.process_images_inner(pc)
shared.state.end(jobid)
del pc.recursion del pc.recursion
if pp is not None and pp.images is not None and len(pp.images) > 0: if pp is not None and pp.images is not None and len(pp.images) > 0:
image = pp.images[0] # update image to be reused for next item image = pp.images[0] # update image to be reused for next item
@ -369,7 +369,6 @@ class YoloRestorer(Detailer):
p.image_mask = orig_p.get('image_mask', None) p.image_mask = orig_p.get('image_mask', None)
p.state = orig_p.get('state', None) p.state = orig_p.get('state', None)
p.ops = orig_p.get('ops', []) p.ops = orig_p.get('ops', [])
shared.state.job = prev_state
shared.opts.data['mask_apply_overlay'] = orig_apply_overlay shared.opts.data['mask_apply_overlay'] = orig_apply_overlay
np_image = np.array(image) np_image = np.array(image)

View File

@ -369,7 +369,6 @@ def process_samples(p: StableDiffusionProcessing, samples):
def process_images_inner(p: StableDiffusionProcessing) -> Processed: def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list: if type(p.prompt) == list:
assert len(p.prompt) > 0 assert len(p.prompt) > 0
else: else:
@ -383,7 +382,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner): if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
p.scripts.process(p) p.scripts.process(p)
shared.state.begin('Process') jobid = shared.state.begin('Process')
shared.state.batch_count = p.n_iter shared.state.batch_count = p.n_iter
with devices.inference_context(): with devices.inference_context():
t0 = time.time() t0 = time.time()
@ -445,8 +444,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
batch_images, batch_infotexts = process_samples(p, samples) batch_images, batch_infotexts = process_samples(p, samples)
for batch_image, batch_infotext in zip(batch_images, batch_infotexts): for batch_image, batch_infotext in zip(batch_images, batch_infotexts):
output_images.append(batch_image) if batch_image is not None and batch_image not in output_images:
infotexts.append(batch_infotext) output_images.append(batch_image)
infotexts.append(batch_infotext)
if shared.cmd_opts.lowvram: if shared.cmd_opts.lowvram:
devices.torch_gc(force=True, reason='lowvram') devices.torch_gc(force=True, reason='lowvram')
@ -495,5 +495,5 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
devices.torch_gc(force=True, reason='final') devices.torch_gc(force=True, reason='final')
shared.state.end() shared.state.end(jobid)
return results return results

View File

@ -168,6 +168,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
'Chroma' in model.__class__.__name__ or 'Chroma' in model.__class__.__name__ or
'HiDreamImagePipeline' in model.__class__.__name__ 'HiDreamImagePipeline' in model.__class__.__name__
): ):
jobid = shared.state.begin('TE Encode')
try: try:
prompt_parser_diffusers.embedder = prompt_parser_diffusers.PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p) prompt_parser_diffusers.embedder = prompt_parser_diffusers.PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p)
parser = shared.opts.prompt_attention parser = shared.opts.prompt_attention
@ -176,6 +177,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
if os.environ.get('SD_PROMPT_DEBUG', None) is not None: if os.environ.get('SD_PROMPT_DEBUG', None) is not None:
errors.display(e, 'Prompt parser encode') errors.display(e, 'Prompt parser encode')
timer.process.record('prompt', reset=False) timer.process.record('prompt', reset=False)
shared.state.end(jobid)
else: else:
prompt_parser_diffusers.embedder = None prompt_parser_diffusers.embedder = None

View File

@ -118,7 +118,7 @@ def process_post(p: processing.StableDiffusionProcessing):
def process_base(p: processing.StableDiffusionProcessing): def process_base(p: processing.StableDiffusionProcessing):
shared.state.begin('Base') jobid = shared.state.begin('Base')
txt2img = is_txt2img() txt2img = is_txt2img()
use_refiner_start = is_refiner_enabled(p) and (not p.is_hr_pass) use_refiner_start = is_refiner_enabled(p) and (not p.is_hr_pass)
use_denoise_start = not txt2img and p.refiner_start > 0 and p.refiner_start < 1 use_denoise_start = not txt2img and p.refiner_start > 0 and p.refiner_start < 1
@ -164,7 +164,9 @@ def process_base(p: processing.StableDiffusionProcessing):
base_args['gate_step'] = p.gate_step base_args['gate_step'] = p.gate_step
output = shared.sd_model.tgate(**base_args) # pylint: disable=not-callable output = shared.sd_model.tgate(**base_args) # pylint: disable=not-callable
else: else:
taskid = shared.state.begin('Model')
output = shared.sd_model(**base_args) output = shared.sd_model(**base_args)
shared.state.end(taskid)
if isinstance(output, dict): if isinstance(output, dict):
output = SimpleNamespace(**output) output = SimpleNamespace(**output)
if isinstance(output, list): if isinstance(output, list):
@ -207,8 +209,8 @@ def process_base(p: processing.StableDiffusionProcessing):
finally: finally:
process_post(p) process_post(p)
shared.state.end(jobid)
shared.state.nextjob() shared.state.nextjob()
shared.state.end()
return output return output
@ -217,6 +219,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
if (output is None) or (output.images is None): if (output is None) or (output.images is None):
return output return output
if p.enable_hr: if p.enable_hr:
jobid = shared.state.begin('Hires')
p.is_hr_pass = True p.is_hr_pass = True
if hasattr(p, 'init_hr'): if hasattr(p, 'init_hr'):
p.init_hr(p.hr_scale, p.hr_upscaler, force=p.hr_force) p.init_hr(p.hr_scale, p.hr_upscaler, force=p.hr_force)
@ -228,7 +231,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
p.hr_resize_context = p.resize_context p.hr_resize_context = p.resize_context
p.hr_upscale_to_x = p.width * p.hr_scale if p.hr_resize_x == 0 else p.hr_resize_x p.hr_upscale_to_x = p.width * p.hr_scale if p.hr_resize_x == 0 else p.hr_resize_x
p.hr_upscale_to_y = p.height * p.hr_scale if p.hr_resize_y == 0 else p.hr_resize_y p.hr_upscale_to_y = p.height * p.hr_scale if p.hr_resize_y == 0 else p.hr_resize_y
prev_job = shared.state.job
# hires runs on original pipeline # hires runs on original pipeline
if hasattr(shared.sd_model, 'restore_pipeline') and (shared.sd_model.restore_pipeline is not None) and (not shared.opts.control_hires): if hasattr(shared.sd_model, 'restore_pipeline') and (shared.sd_model.restore_pipeline is not None) and (not shared.opts.control_hires):
@ -240,7 +242,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
p.ops.append('upscale') p.ops.append('upscale')
if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'): if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'):
save_intermediate(p, latents=output.images, suffix="-before-hires") save_intermediate(p, latents=output.images, suffix="-before-hires")
shared.state.update('Upscale', 0, 1)
output.images = resize_hires(p, latents=output.images) output.images = resize_hires(p, latents=output.images)
sd_hijack_hypertile.hypertile_set(p, hr=True) sd_hijack_hypertile.hypertile_set(p, hr=True)
elif torch.is_tensor(output.images) and output.images.shape[-1] == 3: # nhwc elif torch.is_tensor(output.images) and output.images.shape[-1] == 3: # nhwc
@ -295,7 +296,9 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
try: try:
if 'base' in p.skip: if 'base' in p.skip:
extra_networks.activate(p) extra_networks.activate(p)
taskid = shared.state.begin('Model')
output = shared.sd_model(**hires_args) # pylint: disable=not-callable output = shared.sd_model(**hires_args) # pylint: disable=not-callable
shared.state.end(taskid)
if isinstance(output, dict): if isinstance(output, dict):
output = SimpleNamespace(**output) output = SimpleNamespace(**output)
if hasattr(output, 'images'): if hasattr(output, 'images'):
@ -314,7 +317,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
if orig_image is not None: if orig_image is not None:
p.task_args['image'] = orig_image p.task_args['image'] = orig_image
p.denoising_strength = orig_denoise p.denoising_strength = orig_denoise
shared.state.job = prev_job shared.state.end(jobid)
shared.state.nextjob() shared.state.nextjob()
p.is_hr_pass = False p.is_hr_pass = False
timer.process.record('hires') timer.process.record('hires')
@ -326,7 +329,6 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
if (output is None) or (output.images is None): if (output is None) or (output.images is None):
return output return output
if is_refiner_enabled(p): if is_refiner_enabled(p):
prev_job = shared.state.job
if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'): if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'):
save_intermediate(p, latents=output.images, suffix="-before-refiner") save_intermediate(p, latents=output.images, suffix="-before-refiner")
if shared.opts.diffusers_move_base: if shared.opts.diffusers_move_base:
@ -335,6 +337,7 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
if shared.state.interrupted or shared.state.skipped: if shared.state.interrupted or shared.state.skipped:
shared.sd_model = orig_pipeline shared.sd_model = orig_pipeline
return output return output
jobid = shared.state.begin('Refine')
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
if shared.opts.diffusers_move_refiner: if shared.opts.diffusers_move_refiner:
sd_models.move_model(shared.sd_refiner, devices.device) sd_models.move_model(shared.sd_refiner, devices.device)
@ -400,7 +403,7 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
elif shared.opts.diffusers_move_refiner: elif shared.opts.diffusers_move_refiner:
shared.log.debug('Moving to CPU: model=refiner') shared.log.debug('Moving to CPU: model=refiner')
sd_models.move_model(shared.sd_refiner, devices.cpu) sd_models.move_model(shared.sd_refiner, devices.cpu)
shared.state.job = prev_job shared.state.end(jobid)
shared.state.nextjob() shared.state.nextjob()
p.is_refiner_pass = False p.is_refiner_pass = False
timer.process.record('refine') timer.process.record('refine')

View File

@ -205,8 +205,6 @@ def decode_first_stage(model, x):
shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}') shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device) x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
return x_sample return x_sample
prev_job = shared.state.job
shared.state.job = 'VAE'
with devices.autocast(disable = x.dtype==devices.dtype_vae): with devices.autocast(disable = x.dtype==devices.dtype_vae):
try: try:
if hasattr(model, 'decode_first_stage'): if hasattr(model, 'decode_first_stage'):
@ -220,7 +218,6 @@ def decode_first_stage(model, x):
except Exception as e: except Exception as e:
x_sample = x x_sample = x
shared.log.error(f'Decode VAE: {e}') shared.log.error(f'Decode VAE: {e}')
shared.state.job = prev_job
return x_sample return x_sample
@ -301,17 +298,21 @@ def resize_init_images(p):
def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler else latent def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler else latent
jobid = shared.state.begin('Resize')
if not torch.is_tensor(latents): if not torch.is_tensor(latents):
shared.log.warning('Hires: input is not tensor') shared.log.warning('Hires: input is not tensor')
decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height) decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
shared.state.end(jobid)
return decoded return decoded
if (p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0) and hasattr(p, 'init_hr'): if (p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0) and hasattr(p, 'init_hr'):
shared.log.error('Hires: missing upscaling dimensions') shared.log.error('Hires: missing upscaling dimensions')
shared.state.end(jobid)
return decoded return decoded
if p.hr_upscaler.lower().startswith('latent'): if p.hr_upscaler.lower().startswith('latent'):
resized = images.resize_image(p.hr_resize_mode, latents, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context) resized = images.resize_image(p.hr_resize_mode, latents, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context)
shared.state.end(jobid)
return resized return resized
decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height) decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
@ -320,6 +321,7 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler
resize = images.resize_image(p.hr_resize_mode, image, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context) resize = images.resize_image(p.hr_resize_mode, image, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context)
resized.append(resize) resized.append(resize)
devices.torch_gc() devices.torch_gc()
shared.state.end(jobid)
return resized return resized

View File

@ -268,17 +268,16 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
model = model.pipe model = model.pipe
if latents is None or not torch.is_tensor(latents): # already decoded if latents is None or not torch.is_tensor(latents): # already decoded
return latents return latents
prev_job = shared.state.job
if vae_type == 'Remote': if vae_type == 'Remote':
shared.state.job = 'Remote VAE' jobid = shared.state.begin('Remote VAE')
from modules.sd_vae_remote import remote_decode from modules.sd_vae_remote import remote_decode
tensors = remote_decode(latents=latents, width=width, height=height) tensors = remote_decode(latents=latents, width=width, height=height)
shared.state.job = prev_job shared.state.end(jobid)
if tensors is not None and len(tensors) > 0: if tensors is not None and len(tensors) > 0:
return vae_postprocess(tensors, model, output_type) return vae_postprocess(tensors, model, output_type)
shared.state.job = 'VAE' jobid = shared.state.begin('VAE Decode')
if latents.shape[0] == 0: if latents.shape[0] == 0:
shared.log.error(f'VAE nothing to decode: {latents.shape}') shared.log.error(f'VAE nothing to decode: {latents.shape}')
return [] return []
@ -308,11 +307,11 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
decoded = 2.0 * decoded - 1.0 # typical normalized range decoded = 2.0 * decoded - 1.0 # typical normalized range
images = vae_postprocess(decoded, model, output_type) images = vae_postprocess(decoded, model, output_type)
shared.state.job = prev_job
if shared.cmd_opts.profile or debug: if shared.cmd_opts.profile or debug:
t1 = time.time() t1 = time.time()
shared.log.debug(f'Profile: VAE decode: {t1-t0:.2f}') shared.log.debug(f'Profile: VAE decode: {t1-t0:.2f}')
devices.torch_gc() devices.torch_gc()
shared.state.end(jobid)
return images return images

View File

@ -96,13 +96,14 @@ class ScriptPostprocessingRunner:
def run(self, pp: PostprocessedImage, args): def run(self, pp: PostprocessedImage, args):
for script in self.scripts_in_preferred_order(): for script in self.scripts_in_preferred_order():
shared.state.job = script.name jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args): for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value process_args[name] = value
shared.log.debug(f'Process: script={script.name} args={process_args}') shared.log.debug(f'Process: script={script.name} args={process_args}')
script.process(pp, **process_args) script.process(pp, **process_args)
shared.state.end(jobid)
def create_args_for_run(self, scripts_args): def create_args_for_run(self, scripts_args):
if not self.ui_created: if not self.ui_created:
@ -125,10 +126,11 @@ class ScriptPostprocessingRunner:
for script in self.scripts_in_preferred_order(): for script in self.scripts_in_preferred_order():
if not hasattr(script, 'postprocess'): if not hasattr(script, 'postprocess'):
continue continue
shared.state.job = script.name jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args): for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value process_args[name] = value
shared.log.debug(f'Postprocess: script={script.name} args={process_args}') shared.log.debug(f'Postprocess: script={script.name} args={process_args}')
script.postprocess(filenames, **process_args) script.postprocess(filenames, **process_args)
shared.state.end(jobid)

View File

@ -4,7 +4,7 @@ from modules import shared, errors, timer, sd_models
def hijack_encode_prompt(*args, **kwargs): def hijack_encode_prompt(*args, **kwargs):
shared.state.begin('TE') jobid = shared.state.begin('TE Encode')
t0 = time.time() t0 = time.time()
if 'max_sequence_length' in kwargs and kwargs['max_sequence_length'] is not None: if 'max_sequence_length' in kwargs and kwargs['max_sequence_length'] is not None:
kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], os.environ.get('HIDREAM_MAX_SEQUENCE_LENGTH', 256)) kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], os.environ.get('HIDREAM_MAX_SEQUENCE_LENGTH', 256))
@ -22,7 +22,7 @@ def hijack_encode_prompt(*args, **kwargs):
# if hasattr(shared.sd_model, "maybe_free_model_hooks"): # if hasattr(shared.sd_model, "maybe_free_model_hooks"):
# shared.sd_model.maybe_free_model_hooks() # shared.sd_model.maybe_free_model_hooks()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
shared.state.end() shared.state.end(jobid)
return res return res

View File

@ -8,7 +8,7 @@ debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None e
def hijack_vae_decode(*args, **kwargs): def hijack_vae_decode(*args, **kwargs):
shared.state.begin('VAE') jobid = shared.state.begin('VAE Decode')
t0 = time.time() t0 = time.time()
res = None res = None
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae']) shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae'])
@ -27,12 +27,12 @@ def hijack_vae_decode(*args, **kwargs):
res = None res = None
t1 = time.time() t1 = time.time()
timer.process.add('vae', t1-t0) timer.process.add('vae', t1-t0)
shared.state.end() shared.state.end(jobid)
return res return res
def hijack_vae_encode(*args, **kwargs): def hijack_vae_encode(*args, **kwargs):
shared.state.begin('VAE') jobid = shared.state.begin('VAE Encode')
t0 = time.time() t0 = time.time()
res = None res = None
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae']) shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae'])
@ -51,7 +51,7 @@ def hijack_vae_encode(*args, **kwargs):
res = None res = None
t1 = time.time() t1 = time.time()
timer.process.add('vae', t1-t0) timer.process.add('vae', t1-t0)
shared.state.end() shared.state.end(jobid)
return res return res

View File

@ -1122,8 +1122,7 @@ def reload_model_weights(sd_model=None, info=None, op='model', force=False, revi
if checkpoint_info is None: if checkpoint_info is None:
unload_model_weights(op=op) unload_model_weights(op=op)
return None return None
orig_state = copy.deepcopy(shared.state) jobid = shared.state.begin('Load')
shared.state.begin('Load')
if sd_model is None: if sd_model is None:
sd_model = model_data.sd_model if op == 'model' or op == 'dict' else model_data.sd_refiner sd_model = model_data.sd_model if op == 'model' or op == 'dict' else model_data.sd_refiner
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
@ -1131,6 +1130,7 @@ def reload_model_weights(sd_model=None, info=None, op='model', force=False, revi
else: else:
current_checkpoint_info = getattr(sd_model, 'sd_checkpoint_info', None) current_checkpoint_info = getattr(sd_model, 'sd_checkpoint_info', None)
if current_checkpoint_info is not None and checkpoint_info is not None and current_checkpoint_info.filename == checkpoint_info.filename and not force: if current_checkpoint_info is not None and checkpoint_info is not None and current_checkpoint_info.filename == checkpoint_info.filename and not force:
shared.state.end(jobid)
return None return None
else: else:
move_model(sd_model, devices.cpu) move_model(sd_model, devices.cpu)
@ -1142,14 +1142,14 @@ def reload_model_weights(sd_model=None, info=None, op='model', force=False, revi
if sd_model is None or force: if sd_model is None or force:
sd_model = None sd_model = None
load_diffuser(checkpoint_info, op=op, revision=revision) load_diffuser(checkpoint_info, op=op, revision=revision)
shared.state.end() shared.state.end(jobid)
shared.state = orig_state
if op == 'model': if op == 'model':
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
return model_data.sd_model return model_data.sd_model
else: else:
shared.opts.data["sd_model_refiner"] = checkpoint_info.title shared.opts.data["sd_model_refiner"] = checkpoint_info.title
return model_data.sd_refiner return model_data.sd_refiner
shared.state.end(jobid)
return None # should not be here return None # should not be here

View File

@ -12,9 +12,9 @@ debug_history = debug_output or os.environ.get('SD_STATE_HISTORY', None)
class State: class State:
job_history = []
task_history = []
state_history = [] state_history = []
job_history = 0
task_history = 0
image_history = 0 image_history = 0
latent_history = 0 latent_history = 0
id = 0 id = 0
@ -45,7 +45,6 @@ class State:
disable_preview = False disable_preview = False
preview_job = -1 preview_job = -1
time_start = None time_start = None
time_end = None
need_restart = False need_restart = False
server_start = time.time() server_start = time.time()
oom = False oom = False
@ -142,8 +141,14 @@ class State:
res.status = 'running' if self.job != '' else 'idle' res.status = 'running' if self.job != '' else 'idle'
return res return res
def history(self, op:str): def find(self, task_id:str):
job = { 'id': self.id, 'job': self.job.lower(), 'op': op.lower(), 'start': self.time_start, 'end': self.time_end, 'outputs': self.results } for job in reversed(self.state_history):
if job['id'] == task_id:
return job
return None
def history(self, op:str, task_id:str=None, results:list=[]):
job = { 'id': task_id or self.id, 'job': self.job.lower(), 'op': op.lower(), 'timestamp': self.time_start, 'outputs': results }
self.state_history.append(job) self.state_history.append(job)
l = len(self.state_history) l = len(self.state_history)
if l > 10000: if l > 10000:
@ -156,6 +161,8 @@ class State:
self.results += results self.results += results
else: else:
self.results.append(results) self.results.append(results)
if len(self.results) > 0:
self.history('output', self.id, results=self.results)
def get_id(self, task_id:str=None): def get_id(self, task_id:str=None):
if task_id is None or task_id == 0: if task_id is None or task_id == 0:
@ -165,52 +172,7 @@ class State:
match = re.search(r'\((.*?)\)', task_id) match = re.search(r'\((.*?)\)', task_id)
return match.group(1) if match else task_id return match.group(1) if match else task_id
def begin(self, title="", task_id=0, api=None): def clear(self):
import modules.devices
self.job_history.append(title)
self.total_jobs += 1
self.current_image = None
self.current_image_sampling_step = 0
self.current_latent = None
self.current_noise_pred = None
self.current_sigma = None
self.current_sigma_next = None
self.id_live_preview = 0
self.interrupted = False
self.preview_job = -1
self.results = []
self.id = self.get_id(task_id)
self.job = title
self.job_count = 1 # cannot be less than 1 on new job
self.frame_count = 0
self.batch_no = 0
self.batch_count = 0
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.paused = False
self._sampling_step = 0
self.sampling_steps = 0
self.skipped = False
self.textinfo = None
self.prediction_type = "epsilon"
self.api = api or self.api
self.time_start = time.time()
self.time_end = None
self.history('begin')
if debug_output:
log.trace(f'State begin: {self}')
modules.devices.torch_gc()
def end(self, api=None):
import modules.devices
if self.time_start is None: # someone called end before being
# fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
# log.debug(f'Access state.end: {fn}') # pylint: disable=protected-access
self.time_start = time.time()
if debug_output:
log.trace(f'State end: {self}')
self.time_end = time.time()
self.history('end')
self.id = '' self.id = ''
self.job = '' self.job = ''
self.job_count = 0 self.job_count = 0
@ -220,14 +182,57 @@ class State:
self.paused = False self.paused = False
self.interrupted = False self.interrupted = False
self.skipped = False self.skipped = False
self.results = []
def begin(self, title="", task_id=0, api=None):
import modules.devices
self.clear()
self.job_history += 1
self.total_jobs += 1
self.current_image = None
self.current_image_sampling_step = 0
self.current_latent = None
self.current_noise_pred = None
self.current_sigma = None
self.current_sigma_next = None
self.id_live_preview = 0
self.id = self.get_id(task_id)
self.job = title
self.job_count = 1 # cannot be less than 1 on new job
self.batch_no = 0
self.batch_count = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self._sampling_step = 0
self.sampling_steps = 0
self.textinfo = None
self.prediction_type = "epsilon"
self.api = api or self.api self.api = api or self.api
self.time_start = time.time()
self.history('begin', self.id)
if debug_output:
log.trace(f'State begin: {self}')
modules.devices.torch_gc()
return self.id
def end(self, task_id=None):
import modules.devices
if debug_output:
log.trace(f'State end: {self}')
if task_id is not None:
prev_job = self.find(task_id)
if prev_job is not None:
self.id = prev_job['id']
self.job = prev_job['job']
self.time_start = time.time()
self.history('end', task_id or self.id)
self.clear()
modules.devices.torch_gc() modules.devices.torch_gc()
def step(self, step:int=1): def step(self, step:int=1):
self.sampling_step += step self.sampling_step += step
def update(self, job:str, steps:int=0, jobs:int=0): def update(self, job:str, steps:int=0, jobs:int=0):
self.task_history.append(job) self.task_history += 1
# self._sampling_step = 0 # self._sampling_step = 0
if job == 'Ignore': if job == 'Ignore':
return return
@ -237,8 +242,7 @@ class State:
else: else:
self.sampling_steps += (steps * jobs) self.sampling_steps += (steps * jobs)
self.job_count += jobs self.job_count += jobs
self.job = job # self.job = job
self.history('update')
if debug_output: if debug_output:
log.trace(f'State update: {self} steps={steps} jobs={jobs}') log.trace(f'State update: {self} steps={steps} jobs={jobs}')

View File

@ -101,6 +101,7 @@ def generate_click(job_id: str, state: str, active_tab: str, *args):
with call_queue.queue_lock: with call_queue.queue_lock:
yield [None, None, None, None, 'Control: starting', ''] yield [None, None, None, None, 'Control: starting', '']
shared.mem_mon.reset() shared.mem_mon.reset()
jobid = shared.state.begin('Control')
progress.start_task(job_id) progress.start_task(job_id)
try: try:
t = time.perf_counter() t = time.perf_counter()
@ -112,7 +113,7 @@ def generate_click(job_id: str, state: str, active_tab: str, *args):
errors.display(e, 'Control') errors.display(e, 'Control')
yield [None, None, None, None, f'Control: Exception: {e}', ''] yield [None, None, None, None, f'Control: Exception: {e}', '']
progress.finish_task(job_id) progress.finish_task(job_id)
shared.state.end() shared.state.end(jobid)
def create_ui(_blocks: gr.Blocks=None): def create_ui(_blocks: gr.Blocks=None):

View File

@ -6,7 +6,7 @@ from modules import shared
def refresh(): def refresh():
def ts(t): def ts(t):
try: try:
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t)) return time.strftime('%Y-%m-%d %H:%M:%S.%f', time.localtime(t))
except Exception: except Exception:
return '' return ''
@ -16,8 +16,7 @@ def refresh():
item['id'], item['id'],
item['job'], item['job'],
item['op'], item['op'],
ts(item['start']), ts(item['timestamp']),
ts(item['end']),
len(item['outputs']), len(item['outputs']),
]) ])
shared.log.info(f"History: records={len(items)}") shared.log.info(f"History: records={len(items)}")
@ -42,7 +41,7 @@ def create_ui():
with gr.Row(): with gr.Row():
history_table = gr.DataFrame( history_table = gr.DataFrame(
value=None, value=None,
headers=['ID', 'Job', 'Op', 'Start', 'End', 'Outputs'], headers=['ID', 'Job', 'Op', 'Timestamp', 'Outputs'],
label='History data', label='History data',
show_label=True, show_label=True,
interactive=False, interactive=False,

View File

@ -90,8 +90,7 @@ class Upscaler:
return img return img
def upscale(self, img: Image, scale, selected_model: str = None): def upscale(self, img: Image, scale, selected_model: str = None):
orig_state = copy.deepcopy(shared.state) jobid = shared.state.begin('Upscale')
shared.state.begin('Upscale')
self.scale = scale self.scale = scale
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
dest_w = int(img.width * scale) dest_w = int(img.width * scale)
@ -111,8 +110,7 @@ class Upscaler:
break break
if img.width != dest_w or img.height != dest_h: if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=Image.Resampling.LANCZOS) img = img.resize((int(dest_w), int(dest_h)), resample=Image.Resampling.LANCZOS)
shared.state.end() shared.state.end(jobid)
shared.state = orig_state
return img return img
@abstractmethod @abstractmethod

View File

@ -8,7 +8,6 @@ loaded_model = None
def load_model(selected: models_def.Model): def load_model(selected: models_def.Model):
shared.state.begin('Load')
if selected is None: if selected is None:
return '' return ''
global loaded_model # pylint: disable=global-statement global loaded_model # pylint: disable=global-statement
@ -16,6 +15,7 @@ def load_model(selected: models_def.Model):
return '' return ''
sd_models.unload_model_weights() sd_models.unload_model_weights()
t0 = time.time() t0 = time.time()
jobid = shared.state.begin('Load')
video_cache.apply_teacache_patch(selected.dit_cls) video_cache.apply_teacache_patch(selected.dit_cls)
@ -111,5 +111,5 @@ def load_model(selected: models_def.Model):
loaded_model = selected.name loaded_model = selected.name
msg = f'Video load: cls={shared.sd_model.__class__.__name__} model="{selected.name}" time={t1-t0:.2f}' msg = f'Video load: cls={shared.sd_model.__class__.__name__} model="{selected.name}" time={t1-t0:.2f}'
shared.log.info(msg) shared.log.info(msg)
shared.state.end() shared.state.end(jobid)
return msg return msg

View File

@ -184,7 +184,6 @@ class Script(scripts_manager.Script):
# auto-executed by the script-callback # auto-executed by the script-callback
def run(self, p: processing.StableDiffusionProcessing, model, sampler, frames, guidance, offload, override, video_type, duration, loop, pad, interpolate, image, video): # pylint: disable=arguments-differ, unused-argument def run(self, p: processing.StableDiffusionProcessing, model, sampler, frames, guidance, offload, override, video_type, duration, loop, pad, interpolate, image, video): # pylint: disable=arguments-differ, unused-argument
shared.state.begin('CogVideoX')
processing.fix_seed(p) processing.fix_seed(p)
p.extra_generation_params['CogVideoX'] = model p.extra_generation_params['CogVideoX'] = model
p.do_not_save_grid = True p.do_not_save_grid = True
@ -206,7 +205,6 @@ class Script(scripts_manager.Script):
frames = self.generate(p, model) frames = self.generate(p, model)
devices.torch_gc() devices.torch_gc()
processed = processing.get_processed(p, images_list=frames) processed = processing.get_processed(p, images_list=frames)
shared.state.end()
return processed return processed
# auto-executed by the script-callback # auto-executed by the script-callback

View File

@ -67,7 +67,6 @@ class Script(scripts_manager.Script):
p.do_not_save_grid = True p.do_not_save_grid = True
if opts.img2img_color_correction: if opts.img2img_color_correction:
p.color_corrections = initial_color_corrections p.color_corrections = initial_color_corrections
state.job = f"loopback iteration {i+1}/{loops} batch {n+1}/{initial_batch_count}"
processed = processing.process_images(p) processed = processing.process_images(p)
if processed is None: if processed is None:
log.error("Loopback: processing output is none") log.error("Loopback: processing output is none")

View File

@ -208,7 +208,6 @@ class Script(scripts_manager.Script):
all_processed_images = [] all_processed_images = []
for i in range(batch_count): for i in range(batch_count):
imgs = [init_img] * batch_size imgs = [init_img] * batch_size
state.job = f"outpainting batch {i+1}/{batch_count}"
if left > 0: if left > 0:
imgs = expand(imgs, batch_size, left, is_left=True) imgs = expand(imgs, batch_size, left, is_left=True)
if right > 0: if right > 0:

View File

@ -90,7 +90,6 @@ class Script(scripts_manager.Script):
p.init_images = [work[i]] p.init_images = [work[i]]
p.image_mask = work_mask[i] p.image_mask = work_mask[i]
p.latent_mask = work_latent_mask[i] p.latent_mask = work_latent_mask[i]
state.job = f"outpainting batch {i+1}/{batch_count}"
processed = process_images(p) processed = process_images(p)
if initial_seed is None: if initial_seed is None:
initial_seed = processed.seed initial_seed = processed.seed

View File

@ -514,7 +514,7 @@ class Script(scripts_manager.Script):
p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
shared.prompt_styles.apply_styles_to_extra(p) shared.prompt_styles.apply_styles_to_extra(p)
p.styles = [] p.styles = []
shared.state.begin('LLM') jobid = shared.state.begin('LLM')
p.prompt = self.enhance( p.prompt = self.enhance(
prompt=p.prompt, prompt=p.prompt,
seed=p.seed, seed=p.seed,
@ -532,4 +532,4 @@ class Script(scripts_manager.Script):
) )
timer.process.record('prompt') timer.process.record('prompt')
p.extra_generation_params['LLM'] = llm_model p.extra_generation_params['LLM'] = llm_model
shared.state.end() shared.state.end(jobid)

View File

@ -136,7 +136,6 @@ class Script(scripts_manager.Script):
all_negative = [] all_negative = []
infotexts = [] infotexts = []
for args in jobs: for args in jobs:
state.job = f"{state.job_no + 1} out of {state.job_count}"
copy_p = copy.copy(p) copy_p = copy.copy(p)
for k, v in args.items(): for k, v in args.items():
setattr(copy_p, k, v) setattr(copy_p, k, v)

View File

@ -252,7 +252,7 @@ class Script(scripts_manager.Script):
p.seed = processing_helpers.get_fixed_seed(p.seed) p.seed = processing_helpers.get_fixed_seed(p.seed)
if direct: # run pipeline directly if direct: # run pipeline directly
shared.state.begin('PuLID') jobid = shared.state.begin('PuLID')
processing.fix_seed(p) processing.fix_seed(p)
p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles) p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
@ -273,7 +273,7 @@ class Script(scripts_manager.Script):
)[0] )[0]
info = processing.create_infotext(p) info = processing.create_infotext(p)
processed = processing.get_processed(p, [output], info=info) processed = processing.get_processed(p, [output], info=info)
shared.state.end() shared.state.end(jobid)
else: # let processing run the pipeline else: # let processing run the pipeline
p.task_args['id_embedding'] = id_embedding p.task_args['id_embedding'] = id_embedding
p.task_args['uncond_id_embedding'] = uncond_id_embedding p.task_args['uncond_id_embedding'] = uncond_id_embedding

View File

@ -70,7 +70,6 @@ class Script(scripts_manager.Script):
for i in range(batch_count): for i in range(batch_count):
p.batch_size = batch_size p.batch_size = batch_size
p.init_images = work[i * batch_size:(i + 1) * batch_size] p.init_images = work[i * batch_size:(i + 1) * batch_size]
state.job = f"upscale batch {i+1+n*batch_count}/{state.job_count}"
processed = processing.process_images(p) processed = processing.process_images(p)
if initial_info is None: if initial_info is None:
initial_info = processed.info initial_info = processed.info

View File

@ -167,6 +167,7 @@ class Script(scripts_manager.Script):
include_time, include_text, margin_size, include_time, include_text, margin_size,
create_video, video_type, video_duration, video_loop, video_pad, video_interpolate, create_video, video_type, video_duration, video_loop, video_pad, video_interpolate,
): # pylint: disable=W0221 ): # pylint: disable=W0221
jobid = shared.state.begin('XYZ Grid')
if not no_fixed_seeds: if not no_fixed_seeds:
processing.fix_seed(p) processing.fix_seed(p)
if not shared.opts.return_grid: if not shared.opts.return_grid:
@ -348,7 +349,7 @@ class Script(scripts_manager.Script):
return processed, t1-t0 return processed, t1-t0
with SharedSettingsStackHelper(): with SharedSettingsStackHelper():
processed = draw_xyz_grid( processed: processing.Processed = draw_xyz_grid(
p, p,
xs=xs, xs=xs,
ys=ys, ys=ys,
@ -404,4 +405,5 @@ class Script(scripts_manager.Script):
if create_video and video_type != 'None' and not shared.state.interrupted: if create_video and video_type != 'None' and not shared.state.interrupted:
images.save_video(p, filename=None, images=have_images, video_type=video_type, duration=video_duration, loop=video_loop, pad=video_pad, interpolate=video_interpolate) images.save_video(p, filename=None, images=have_images, video_type=video_type, duration=video_duration, loop=video_loop, pad=video_pad, interpolate=video_interpolate)
shared.state.end(jobid)
return processed return processed

View File

@ -184,6 +184,7 @@ class Script(scripts_manager.Script):
processing.fix_seed(p) processing.fix_seed(p)
if not shared.opts.return_grid: if not shared.opts.return_grid:
p.batch_size = 1 p.batch_size = 1
jobid = shared.state.begin('XYZ Grid')
def process_axis(opt, vals, vals_dropdown): def process_axis(opt, vals, vals_dropdown):
if opt.label == 'Nothing': if opt.label == 'Nothing':
@ -430,6 +431,7 @@ class Script(scripts_manager.Script):
p.disable_extra_networks = True p.disable_extra_networks = True
active = False active = False
xyz_results_cache = processed xyz_results_cache = processed
shared.state.end(jobid)
return processed return processed

View File

@ -138,7 +138,7 @@ def initialize():
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(_sig, _frame): def sigint_handler(_sig, _frame):
log.trace(f'State history: uptime={round(time.time() - shared.state.server_start)} jobs={len(shared.state.job_history)} tasks={len(shared.state.task_history)} latents={shared.state.latent_history} images={shared.state.image_history}') log.trace(f'State history: uptime={round(time.time() - shared.state.server_start)} jobs={shared.state.job_history} tasks={shared.state.task_history} latents={shared.state.latent_history} images={shared.state.image_history}')
log.info('Exiting') log.info('Exiting')
try: try:
for f in glob.glob("*.lock"): for f in glob.glob("*.lock"):
@ -155,14 +155,14 @@ def load_model():
if not shared.opts.sd_checkpoint_autoload and shared.cmd_opts.ckpt is None: if not shared.opts.sd_checkpoint_autoload and shared.cmd_opts.ckpt is None:
log.info('Model: autoload=False') log.info('Model: autoload=False')
else: else:
shared.state.begin('Load') jobid = shared.state.begin('Load')
thread_model = Thread(target=lambda: shared.sd_model) thread_model = Thread(target=lambda: shared.sd_model)
thread_model.start() thread_model.start()
thread_refiner = Thread(target=lambda: shared.sd_refiner) thread_refiner = Thread(target=lambda: shared.sd_refiner)
thread_refiner.start() thread_refiner.start()
thread_model.join() thread_model.join()
thread_refiner.join() thread_refiner.join()
shared.state.end() shared.state.end(jobid)
timer.startup.record("checkpoint") timer.startup.record("checkpoint")
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='model')), call=False) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='model')), call=False)
shared.opts.onchange("sd_model_refiner", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='refiner')), call=False) shared.opts.onchange("sd_model_refiner", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='refiner')), call=False)