refactor: split legacy loaders

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3704/head
Vladimir Mandic 2025-01-13 13:00:30 -05:00
parent 1c10e69276
commit 0c8044070a
17 changed files with 543 additions and 522 deletions

View File

@ -22,7 +22,8 @@
- refactored progress monitoring, job updates and live preview
- improved metadata save and restore
- startup tracing and optimizations
- threading load locks on model loads
- refactor native vs legacy model loader
- **Schedulers**:
- [TDD](https://github.com/RedAIGC/Target-Driven-Distillation) new super-fast scheduler that can generate images in 4-8 steps
recommended to use with [TDD LoRA](https://huggingface.co/RED-AIGC/TDD/tree/main)
@ -40,7 +41,7 @@
- **XYZ Grid**: add prompt search&replace options: *primary, refine, detailer, all*
- **SysInfo**: update to collected data and benchmarks
- [Wiki/Docs](https://vladmandic.github.io/sdnext-docs/):
- updated: Detailer, Install, Debug, Control-HowTo, ZLUDA
- updated: Detailer, Install, Update, Debug, Control-HowTo, ZLUDA
- **Fixes**:
- explicit clear caches on model load
- lock adetailer commit: `#a89c01d`
@ -61,6 +62,8 @@
- restore args after batch run
- flux controlnet
- zluda installer
- control inherit parent pipe settings
- control logging
## Update for 2024-12-31

@ -1 +1 @@
Subproject commit cd878626f3b4f9a0c7c45c7d70b73a6168f612a4
Subproject commit a33753321b914c6122df96d1dc0b5117d38af680

View File

@ -150,6 +150,11 @@ def setup_logging():
log.addHandler(rb)
log.buffer = rb.buffer
def quiet_log(quiet: bool=False, *args, **kwargs): # pylint: disable=redefined-outer-name,keyword-arg-before-vararg
if not quiet:
log.debug(*args, **kwargs)
log.quiet = quiet_log
# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)

View File

@ -78,7 +78,6 @@ def compatibility_args():
group_compat.add_argument("--disable-queue", default=os.environ.get("SD_DISABLEQUEUE", False), action='store_true', help=argparse.SUPPRESS)
def settings_args(opts, args):
# removed args are added here as hidden in fixed format for compatibility reasons
group_compat = parser.add_argument_group('Compatibility options')

View File

@ -17,10 +17,11 @@ from modules import devices, shared, errors, processing, images, sd_models, scri
from modules.processing_class import StableDiffusionProcessingControl
from modules.ui_common import infotext_to_html
from modules.api import script
from modules.timer import process as process_timer
debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: CONTROL')
debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None
debug_log = shared.log.trace if debug else lambda *args, **kwargs: None
pipe = None
instance = None
original_pipeline = None
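
The old pattern bound `debug` directly to a trace function; the refactor keeps a plain boolean `debug` for branching plus a `debug_log` callable that is a no-op unless `SD_CONTROL_DEBUG` is set. A self-contained sketch of the same pattern, using a hypothetical `SD_EXAMPLE_DEBUG` variable and the standard logging module:

```python
import os
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger('example')

# boolean flag usable in conditionals, plus a gated logger callable
debug = os.environ.get('SD_EXAMPLE_DEBUG', None) is not None
debug_log = log.debug if debug else lambda *args, **kwargs: None

debug_log('emitted only when SD_EXAMPLE_DEBUG is set')
if debug:
    log.debug('extra diagnostics that only make sense in debug runs')
```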
@ -32,7 +33,7 @@ def restore_pipeline():
if instance is not None and hasattr(instance, 'restore'):
instance.restore()
if original_pipeline is not None and (original_pipeline.__class__.__name__ != shared.sd_model.__class__.__name__):
debug(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}')
debug_log(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}')
shared.sd_model = original_pipeline
pipe = None
instance = None
@ -109,7 +110,7 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str
p.strength = active_strength[0]
pipe = shared.sd_model
instance = None
debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}')
debug_log(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}')
return pipe
@ -124,14 +125,14 @@ def check_active(p, unit_type, units):
if u.type != unit_type:
continue
num_units += 1
debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}')
debug_log(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}')
if not u.enabled:
if u.controlnet is not None and u.controlnet.model is not None:
debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}')
debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}')
sd_models.move_model(u.controlnet.model, devices.cpu)
continue
if u.controlnet is not None and u.controlnet.model is not None:
debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}')
debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}')
sd_models.move_model(u.controlnet.model, devices.device)
if unit_type == 't2i adapter' and u.adapter.model is not None:
active_process.append(u.process)
@ -176,7 +177,7 @@ def check_active(p, unit_type, units):
active_process.append(u.process)
shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}')
active_strength.append(float(u.strength))
debug(f'Control active: process={len(active_process)} model={len(active_model)}')
debug_log(f'Control active: process={len(active_process)} model={len(active_model)}')
return active_process, active_model, active_strength, active_start, active_end
@ -213,7 +214,7 @@ def control_set(kwargs):
if kwargs:
global p_extra_args # pylint: disable=global-statement
p_extra_args = {}
debug(f'Control extra args: {kwargs}')
debug_log(f'Control extra args: {kwargs}')
for k, v in kwargs.items():
p_extra_args[k] = v
@ -254,7 +255,7 @@ def control_run(state: str = '',
u.process.override = u.override
global pipe, original_pipeline # pylint: disable=global-statement
debug(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}')
debug_log(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}')
if inputs is None or (type(inputs) is list and len(inputs) == 0):
inputs = [None]
output_images: List[Image.Image] = [] # output images
@ -402,7 +403,7 @@ def control_run(state: str = '',
p.is_tile = p.is_tile and has_models
pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits)
debug(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}')
debug_log(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}')
t1, t2, t3 = time.time(), 0, 0
status = True
frame = None
@ -420,7 +421,7 @@ def control_run(state: str = '',
shared.sd_model = pipe
sd_models.move_model(shared.sd_model, shared.device)
shared.sd_model.to(dtype=devices.dtype)
debug(f'Control device={devices.device} dtype={devices.dtype}')
debug_log(f'Control device={devices.device} dtype={devices.dtype}')
sd_models.copy_diffuser_options(shared.sd_model, original_pipeline) # copy options from original pipeline
sd_models.set_diffuser_options(shared.sd_model)
else:
@ -458,12 +459,12 @@ def control_run(state: str = '',
while status:
if pipe is None: # pipe may have been reset externally
pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits)
debug(f'Control pipeline reinit: class={pipe.__class__.__name__}')
debug_log(f'Control pipeline reinit: class={pipe.__class__.__name__}')
processed_image = None
if frame is not None:
inputs = [Image.fromarray(frame)] # cv2 to pil
for i, input_image in enumerate(inputs):
debug(f'Control Control image: {i + 1} of {len(inputs)}')
debug_log(f'Control Control image: {i + 1} of {len(inputs)}')
if shared.state.skipped:
shared.state.skipped = False
continue
@ -481,20 +482,20 @@ def control_run(state: str = '',
continue
# match init input
if input_type == 1:
debug('Control Init image: same as control')
debug_log('Control Init image: same as control')
init_image = input_image
elif inits is None:
debug('Control Init image: none')
debug_log('Control Init image: none')
init_image = None
elif isinstance(inits[i], str):
debug(f'Control: init image: {inits[i]}')
debug_log(f'Control: init image: {inits[i]}')
try:
init_image = Image.open(inits[i])
except Exception as e:
shared.log.error(f'Control: image open failed: path={inits[i]} type=init error={e}')
continue
else:
debug(f'Control Init image: {i % len(inits) + 1} of {len(inits)}')
debug_log(f'Control Init image: {i % len(inits) + 1} of {len(inits)}')
init_image = inits[i % len(inits)]
if video is not None and index % (video_skip_frames + 1) != 0:
index += 1
@ -507,18 +508,18 @@ def control_run(state: str = '',
width_before, height_before = int(input_image.width * scale_by_before), int(input_image.height * scale_by_before)
if input_image is not None:
p.extra_generation_params["Control resize"] = f'{resize_name_before}'
debug(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"')
debug_log(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"')
input_image = images.resize_image(resize_mode_before, input_image, width_before, height_before, resize_name_before, context=resize_context_before)
if input_image is not None and init_image is not None and init_image.size != input_image.size:
debug(f'Control resize init: image={init_image} target={input_image}')
debug_log(f'Control resize init: image={init_image} target={input_image}')
init_image = images.resize_image(resize_mode=1, im=init_image, width=input_image.width, height=input_image.height)
if input_image is not None and p.override is not None and p.override.size != input_image.size:
debug(f'Control resize override: image={p.override} target={input_image}')
debug_log(f'Control resize override: image={p.override} target={input_image}')
p.override = images.resize_image(resize_mode=1, im=p.override, width=input_image.width, height=input_image.height)
if input_image is not None:
p.width = input_image.width
p.height = input_image.height
debug(f'Control: input image={input_image}')
debug_log(f'Control: input image={input_image}')
processed_images = []
if mask is not None:
@ -533,7 +534,7 @@ def control_run(state: str = '',
else:
masked_image = input_image
for i, process in enumerate(active_process): # list[image]
debug(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}')
debug_log(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}')
processed_image = process(
image_input=masked_image,
mode='RGB',
@ -548,7 +549,7 @@ def control_run(state: str = '',
processors.config[process.processor_id]['dirty'] = True # to force reload
process.model = None
debug(f'Control processed: {len(processed_images)}')
debug_log(f'Control processed: {len(processed_images)}')
if len(processed_images) > 0:
try:
if len(p.extra_generation_params["Control process"]) == 0:
@ -574,7 +575,7 @@ def control_run(state: str = '',
blended_image = util.blend(blended_image) # blend all processed images into one
blended_image = Image.fromarray(blended_image)
if isinstance(selected_models, list) and len(processed_images) == len(selected_models):
debug(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}')
debug_log(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}')
p.init_images = processed_images
elif isinstance(selected_models, list) and len(processed_images) != len(selected_models):
if is_generator:
@ -583,14 +584,14 @@ def control_run(state: str = '',
elif selected_models is not None:
p.init_images = processed_image
else:
debug('Control processed: using input direct')
debug_log('Control processed: using input direct')
processed_image = input_image
if unit_type == 'reference' and has_models:
p.ref_image = p.override or input_image
p.task_args.pop('image', None)
p.task_args['ref_image'] = p.ref_image
debug(f'Control: process=None image={p.ref_image}')
debug_log(f'Control: process=None image={p.ref_image}')
if p.ref_image is None:
if is_generator:
yield terminate('Attempting reference mode but image is none')
@ -625,7 +626,7 @@ def control_run(state: str = '',
if is_generator:
image_txt = f'{blended_image.width}x{blended_image.height}' if blended_image is not None else 'None'
msg = f'process | {index} of {frames if video is not None else len(inputs)} | {"Image" if video is None else "Frame"} {image_txt}'
debug(f'Control yield: {msg}')
debug_log(f'Control yield: {msg}')
if is_generator:
yield (None, blended_image, f'Control {msg}')
t2 += time.time() - t2
@ -684,7 +685,7 @@ def control_run(state: str = '',
if selected_scale_tab_mask == 1:
width_mask, height_mask = int(input_image.width * scale_by_mask), int(input_image.height * scale_by_mask)
p.width, p.height = width_mask, height_mask
debug(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} context="{resize_context_mask}"')
debug_log(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} context="{resize_context_mask}"')
# pipeline
output = None
@ -693,9 +694,9 @@ def control_run(state: str = '',
if not hasattr(pipe, 'restore_pipeline') and video is None:
pipe.restore_pipeline = restore_pipeline
shared.sd_model.restore_pipeline = restore_pipeline
debug(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}')
# debug(f'Control exec pipeline: p={vars(p)}')
# debug(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}')
debug_log(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}')
# debug_log(f'Control exec pipeline: p={vars(p)}')
# debug_log(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}')
if sd_models.get_diffusers_task(pipe) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: # force vae back to gpu if not in txt2img mode
sd_models.move_model(pipe.vae, devices.device)
@ -741,7 +742,7 @@ def control_run(state: str = '',
width_after = int(output_image.width * scale_by_after)
height_after = int(output_image.height * scale_by_after)
if resize_mode_after != 0 and resize_name_after != 'None' and not is_grid:
debug(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"')
debug_log(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"')
output_image = images.resize_image(resize_mode_after, output_image, width_after, height_after, resize_name_after, context=resize_context_after)
output_images.append(output_image)
@ -761,14 +762,16 @@ def control_run(state: str = '',
status, frame = video.read()
if status:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
debug(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}')
debug_log(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}')
else:
status = False
if video is not None:
video.release()
shared.log.info(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}')
debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}')
process_timer.add('init', t1-t0)
process_timer.add('proc', t2-t1)
except Exception as e:
shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}')
errors.display(e, 'Control')
@ -789,7 +792,7 @@ def control_run(state: str = '',
p.close()
restore_pipeline()
debug(f'Ready: {image_txt}')
debug_log(f'Ready: {image_txt}')
html_txt = f'<p>Ready {image_txt}</p>' if image_txt != '' else ''
if len(info_txt) > 0:

View File

@ -411,13 +411,14 @@ class ControlNetPipeline():
if dtype is not None:
self.pipeline = self.pipeline.to(dtype)
sd_models.copy_diffuser_options(self.pipeline, pipeline)
if opts.diffusers_offload_mode == 'none':
sd_models.move_model(self.pipeline, devices.device)
from modules.sd_models import set_diffuser_offload
set_diffuser_offload(self.pipeline, 'model')
t1 = time.time()
log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
debug_log(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
def restore(self):
self.pipeline.unload_lora_weights()

View File

@ -27,6 +27,18 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
cache[0] = (required_prompts, steps)
return cache[1]
def check_rollback_vae():
if shared.cmd_opts.rollback_vae:
if not torch.cuda.is_available():
shared.log.error("Rollback VAE functionality requires compatible GPU")
shared.cmd_opts.rollback_vae = False
elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'):
shared.log.error("Rollback VAE functionality requires Torch 2.1 or higher")
shared.cmd_opts.rollback_vae = False
elif 0 < torch.cuda.get_device_capability()[0] < 8:
shared.log.error('Rollback VAE functionality requires GPU with compute capability 8 or higher')
shared.cmd_opts.rollback_vae = False
def process_original(p: processing.StableDiffusionProcessing):
cached_uc = [None, None]
@ -42,6 +54,7 @@ def process_original(p: processing.StableDiffusionProcessing):
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
except devices.NansException as e:
check_rollback_vae()
if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae:
shared.log.warning('Tensor with all NaNs was produced in VAE')
devices.dtype_vae = torch.bfloat16
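
Putting the two pieces together: when VAE decode produces all-NaN tensors, `check_rollback_vae` first verifies that the hardware and Torch version can actually run the bfloat16 fallback (disabling the option otherwise), and only then is the VAE dtype switched for a retry. A hedged sketch of the flow, reusing the module's own `devices`/`shared` helpers; `decode_latents` is a hypothetical stand-in for the real decode call:

```python
import torch

def decode_with_rollback(latents):
    try:
        x_samples = decode_latents(latents)  # hypothetical VAE decode
        for x in x_samples:
            devices.test_for_nans(x, "vae")
        return x_samples
    except devices.NansException:
        check_rollback_vae()  # may disable rollback_vae if unsupported
        if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae:
            shared.log.warning('Tensor with all NaNs was produced in VAE')
            devices.dtype_vae = torch.bfloat16  # retry decode in bf16
            return decode_latents(latents)
        raise
```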

View File

@ -49,12 +49,6 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
elif (size > 20000 and size < 40000):
guess = 'FLUX'
# guess by name
"""
if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper():
if shared.backend == shared.Backend.ORIGINAL:
warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB')
guess = 'Latent Consistency Model'
"""
if 'instaflow' in f.lower():
guess = 'InstaFlow'
if 'segmoe' in f.lower():

View File

@ -1,25 +1,22 @@
import io
import sys
import time
import json
import copy
import inspect
import logging
import contextlib
import os.path
from enum import Enum
import diffusers
import diffusers.loaders.single_file_utils
from rich import progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
from omegaconf import OmegaConf
from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
from modules.timer import Timer, process as process_timer
from modules.memstats import memory_stats
from modules.modeldata import model_data
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
from modules.sd_offload import set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import
from modules.sd_models_utils import NoWatermark, get_signature, get_call, path_to_repo, patch_diffuser_config, convert_to_faketensors, read_state_dict, get_state_dict_from_checkpoint # pylint: disable=unused-import
model_dir = "Stable-diffusion"
@ -35,165 +32,6 @@ diffusers_version = int(diffusers.__version__.split('.')[1])
checkpoint_tiles = checkpoint_titles # legacy compatibility
class NoWatermark:
def apply_watermark(self, img):
return img
def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument
if not os.path.isfile(checkpoint_file):
shared.log.error(f'Load dict: path="{checkpoint_file}" not a file')
return None
try:
pl_sd = None
with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
return None
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=buffered')
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=mmap')
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=direct')
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
except Exception as e:
errors.display(e, f'Load model: {checkpoint_file}')
sd = None
return sd
def get_state_dict_from_checkpoint(pl_sd):
checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
pl_sd.clear()
pl_sd.update(sd)
return pl_sd
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if not os.path.isfile(checkpoint_info.filename):
return None
"""
if checkpoint_info in checkpoints_loaded:
shared.log.info("Load model: cache")
checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
return checkpoints_loaded[checkpoint_info]
"""
res = read_state_dict(checkpoint_info.filename, what='model')
"""
if shared.opts.sd_checkpoint_cache > 0 and not shared.native:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = res
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
"""
timer.record("load")
return res
def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer):
_pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model')
shared.log.debug(f'Load model: memory={memory_stats()}')
timer.record("hash")
if model_data.sd_dict == 'None':
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
try:
model.load_state_dict(state_dict, strict=False)
except Exception as e:
shared.log.error(f'Load model: path="{checkpoint_info.filename}"')
shared.log.error(' '.join(str(e).splitlines()[:2]))
return False
del state_dict
timer.record("apply")
if shared.opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("channels")
if not shared.opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.opts.upcast_sampling and depth_model:
model.depth_model = None
model.half()
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
if shared.opts.cuda_cast_unet:
devices.dtype_unet = model.model.diffusion_model.dtype
else:
model.model.diffusion_model.to(devices.dtype_unet)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = checkpoint_info.calculate_shorthash()
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.is_sdxl = False # a1111 compatibility item
model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("vae")
return True
def repair_config(sd_config):
if "use_ema" not in sd_config.model.params:
sd_config.model.params.use_ema = False
if shared.opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def change_backend():
shared.log.info(f'Backend changed: from={shared.backend} to={shared.opts.sd_backend}')
shared.log.warning('Full server restart required to apply all changes')
@ -223,21 +61,21 @@ def copy_diffuser_options(new_pipe, orig_pipe):
set_accelerate(new_pipe)
def set_vae_options(sd_model, vae = None, op: str = 'model'):
def set_vae_options(sd_model, vae=None, op:str='model', quiet:bool=False):
if hasattr(sd_model, "vae"):
if vae is not None:
sd_model.vae = vae
shared.log.debug(f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"')
shared.log.quiet(quiet, f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"')
if shared.opts.diffusers_vae_upcast != 'default':
sd_model.vae.config.force_upcast = True if shared.opts.diffusers_vae_upcast == 'true' else False
shared.log.debug(f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}')
shared.log.quiet(quiet, f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}')
if shared.opts.no_half_vae:
devices.dtype_vae = torch.float32
sd_model.vae.to(devices.dtype_vae)
shared.log.debug(f'Setting {op}: component=VAE no-half=True')
shared.log.quiet(quiet, f'Setting {op}: component=VAE no-half=True')
if hasattr(sd_model, "enable_vae_slicing"):
if shared.opts.diffusers_vae_slicing:
shared.log.debug(f'Setting {op}: component=VAE slicing=True')
shared.log.quiet(quiet, f'Setting {op}: component=VAE slicing=True')
sd_model.enable_vae_slicing()
else:
sd_model.disable_vae_slicing()
@ -249,18 +87,18 @@ def set_vae_options(sd_model, vae = None, op: str = 'model'):
sd_model.vae.tile_latent_min_size = int(sd_model.vae.config.sample_size / (2 ** (len(sd_model.vae.config.block_out_channels) - 1)))
if shared.opts.diffusers_vae_tile_overlap != 0.25:
sd_model.vae.tile_overlap_factor = float(shared.opts.diffusers_vae_tile_overlap)
shared.log.debug(f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}')
shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}')
else:
shared.log.debug(f'Setting {op}: component=VAE tiling=True')
shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True')
sd_model.enable_vae_tiling()
else:
sd_model.disable_vae_tiling()
if hasattr(sd_model, "vqvae"):
shared.log.debug(f'Setting {op}: component=VQVAE upcast=True')
shared.log.quiet(quiet, f'Setting {op}: component=VQVAE upcast=True')
sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16
def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True):
def set_diffuser_options(sd_model, vae=None, op:str='model', offload:bool=True, quiet:bool=False):
if sd_model is None:
shared.log.warning(f'{op} is not loaded')
return
@ -271,19 +109,19 @@ def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True):
sd_model.has_accelerate = False
clear_caches()
set_vae_options(sd_model, vae, op)
set_diffusers_attention(sd_model)
set_vae_options(sd_model, vae, op, quiet)
set_diffusers_attention(sd_model, quiet)
if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'):
try:
sd_model.fuse_qkv_projections()
shared.log.debug(f'Setting {op}: fused-qkv=True')
shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True')
except Exception as e:
shared.log.error(f'Setting {op}: fused-qkv=True {e}')
if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'fuse_qkv_projections'):
try:
sd_model.transformer.fuse_qkv_projections()
shared.log.debug(f'Setting {op}: fused-qkv=True')
shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True')
except Exception as e:
shared.log.error(f'Setting {op}: fused-qkv=True {e}')
if shared.opts.diffusers_eval:
@ -297,11 +135,11 @@ def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True):
sd_model = sd_models_compile.torchao_quantization(sd_model)
if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'):
shared.log.debug(f'Setting {op}: channels-last=True')
shared.log.quiet(quiet, f'Setting {op}: channels-last=True')
sd_model.unet.to(memory_format=torch.channels_last)
if offload:
set_diffuser_offload(sd_model, op)
set_diffuser_offload(sd_model, op, quiet)
def move_model(model, device=None, force=False):
@ -401,50 +239,6 @@ def move_base(model, device):
return R
def patch_diffuser_config(sd_model, model_file):
def load_config(fn, k):
model_file = os.path.splitext(fn)[0]
cfg_file = f'{model_file}_{k}.json'
try:
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json'
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception:
pass
return {}
if sd_model is None:
return sd_model
if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower():
if debug_load:
shared.log.debug('Model config patch: type=inpaint')
sd_model.unet.config.in_channels = 9
if not hasattr(sd_model, '_internal_dict'):
return sd_model
for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access
component = getattr(sd_model, c, None)
if hasattr(component, 'config'):
if debug_load:
shared.log.debug(f'Model config: component={c} config={component.config}')
override = load_config(model_file, c)
updated = {}
for k, v in override.items():
if k.startswith('_'):
continue
if v != component.config.get(k, None):
if hasattr(component.config, '__frozen'):
component.config.__frozen = False # pylint: disable=protected-access
component.config[k] = v
updated[k] = v
if updated and debug_load:
shared.log.debug(f'Model config: component={c} override={updated}')
return sd_model
def load_diffuser_initial(diffusers_load_config, op='model'):
sd_model = None
checkpoint_info = None
@ -833,18 +627,6 @@ def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType:
return DiffusersTaskType.TEXT_2_IMAGE
def get_signature(cls):
signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True)
return signature.parameters
def get_call(cls):
if cls is None:
return []
signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True)
return signature.parameters
def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}):
"""
args:
@ -1071,7 +853,7 @@ def set_diffuser_pipe(pipe, new_pipe_type):
return pipe
def set_diffusers_attention(pipe):
def set_diffusers_attention(pipe, quiet:bool=False):
import diffusers.models.attention_processor as p
def set_attn(pipe, attention):
@ -1102,7 +884,7 @@ def set_diffusers_attention(pipe):
if 'ControlNet' in pipe.__class__.__name__: # do not replace attention in ControlNet pipelines
return
shared.log.debug(f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
shared.log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
if shared.opts.cross_attention_optimization == "Disabled":
pass # do nothing
elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
@ -1146,103 +928,6 @@ def get_native(pipe: diffusers.DiffusionPipeline):
return size
def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
from ldm.util import instantiate_from_config
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint(op=op)
if checkpoint_info is None:
return
if op == 'model' or op == 'dict':
if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
return
else:
if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
return
shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}')
if timer is None:
timer = Timer()
current_checkpoint_info = None
if op == 'model' or op == 'dict':
if model_data.sd_model is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
else:
if model_data.sd_refiner is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
if not shared.native:
from modules import sd_hijack_inpainting
sd_hijack_inpainting.do_inpainting_hijack()
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
if state_dict is None or checkpoint_config is None:
shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"')
if current_checkpoint_info is not None:
shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore')
load_model(current_checkpoint_info, None)
return
shared.log.debug(f'Model dict loaded: {memory_stats()}')
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
timer.record("config")
shared.log.debug(f'Model config loaded: {memory_stats()}')
sd_model = None
stdout = io.StringIO()
if os.environ.get('SD_LDM_DEBUG', None) is not None:
sd_model = instantiate_from_config(sd_config.model)
else:
with contextlib.redirect_stdout(stdout):
sd_model = instantiate_from_config(sd_config.model)
for line in stdout.getvalue().splitlines():
if len(line) > 0:
shared.log.info(f'LDM: {line.strip()}')
shared.log.debug(f"Model created from config: {checkpoint_config}")
sd_model.used_config = checkpoint_config
sd_model.has_accelerate = False
timer.record("create")
ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if not ok:
model_data.sd_model = sd_model
current_checkpoint_info = None
unload_model_weights(op=op)
shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}')
if op == 'refiner':
# shared.opts.data['sd_model_refiner'] = 'None'
shared.opts.sd_model_refiner = 'None'
return
else:
shared.log.debug(f'Model weights loaded: {memory_stats()}')
timer.record("load")
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
move_model(sd_model, devices.device)
timer.record("move")
shared.log.debug(f'Model weights moved: {memory_stats()}')
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval()
if op == 'refiner':
model_data.sd_refiner = sd_model
else:
model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
timer.record("embeddings")
script_callbacks.model_loaded_callback(sd_model)
timer.record("callbacks")
shared.log.info(f"Model loaded in {timer.summary()}")
current_checkpoint_info = None
devices.torch_gc(force=True)
shared.log.info(f'Model load finished: {memory_stats()}')
def reload_text_encoder(initial=False):
if initial and (shared.opts.sd_text_encoder is None or shared.opts.sd_text_encoder == 'None'):
return # dont unload
@ -1342,35 +1027,6 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
return sd_model
def convert_to_faketensors(tensor):
try:
fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
if hasattr(tensor, "weight"):
tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
return tensor
except Exception:
pass
return tensor
def disable_offload(sd_model):
from accelerate.hooks import remove_hook_from_module
if not getattr(sd_model, 'has_accelerate', False):
return
if hasattr(sd_model, "_internal_dict"):
keys = sd_model._internal_dict.keys() # pylint: disable=protected-access
else:
keys = get_signature(sd_model).keys()
for module_name in keys: # pylint: disable=protected-access
module = getattr(sd_model, module_name, None)
if isinstance(module, torch.nn.Module):
network_layer_name = getattr(module, "network_layer_name", None)
module = remove_hook_from_module(module, recurse=True)
if network_layer_name:
module.network_layer_name = network_layer_name
sd_model.has_accelerate = False
def clear_caches():
# shared.log.debug('Cache clear')
if not shared.opts.lora_legacy:
@ -1411,16 +1067,3 @@ def unload_model_weights(op='model'):
model_data.sd_refiner = None
devices.torch_gc(force=True)
shared.log.debug(f'Unload weights {op}: {memory_stats()}')
def path_to_repo(fn: str = ''):
if isinstance(fn, CheckpointInfo):
fn = fn.name
repo_id = fn.replace('\\', '/')
if 'models--' in repo_id:
repo_id = repo_id.split('models--')[-1]
repo_id = repo_id.split('/')[0]
repo_id = repo_id.split('/')
repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id)
repo_id = repo_id.replace('models--', '').replace('--', '/')
return repo_id

modules/sd_models_legacy.py (new file, 207 lines)
View File

@ -0,0 +1,207 @@
import io
import os
import sys
import contextlib
from modules import shared
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def get_checkpoint_state_dict(checkpoint_info, timer):
from modules.sd_models_utils import read_state_dict
if not os.path.isfile(checkpoint_info.filename):
return None
"""
if checkpoint_info in checkpoints_loaded:
shared.log.info("Load model: cache")
checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
return checkpoints_loaded[checkpoint_info]
"""
res = read_state_dict(checkpoint_info.filename, what='model')
"""
if shared.opts.sd_checkpoint_cache > 0 and not shared.native:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = res
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
"""
timer.record("load")
return res
def repair_config(sd_config):
from modules import paths
if "use_ema" not in sd_config.model.params:
sd_config.model.params.use_ema = False
if shared.opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
def load_model_weights(model, checkpoint_info, state_dict, timer):
# _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model')
from modules.modeldata import model_data
from modules.memstats import memory_stats
from modules import devices, sd_vae
shared.log.debug(f'Load model: memory={memory_stats()}')
timer.record("hash")
if model_data.sd_dict == 'None':
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
try:
model.load_state_dict(state_dict, strict=False)
except Exception as e:
shared.log.error(f'Load model: path="{checkpoint_info.filename}"')
shared.log.error(' '.join(str(e).splitlines()[:2]))
return False
del state_dict
timer.record("apply")
if shared.opts.opt_channelslast:
import torch
model.to(memory_format=torch.channels_last)
timer.record("channels")
if not shared.opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
if shared.opts.no_half_vae: # remove VAE from model when doing half() to prevent its weights from being converted to float16
model.first_stage_model = None
if shared.opts.upcast_sampling and depth_model: # with --upcast-sampling, don't convert the depth model weights to float16
model.depth_model = None
model.half()
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
if shared.opts.cuda_cast_unet:
devices.dtype_unet = model.model.diffusion_model.dtype
else:
model.model.diffusion_model.to(devices.dtype_unet)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = checkpoint_info.calculate_shorthash()
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.is_sdxl = False # a1111 compatibility item
model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("vae")
return True
def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from modules import devices, lowvram, sd_hijack, sd_models_config, script_callbacks
from modules.timer import Timer
from modules.memstats import memory_stats
from modules.modeldata import model_data
from modules.sd_models import unload_model_weights, move_model
from modules.sd_checkpoint import select_checkpoint
checkpoint_info = checkpoint_info or select_checkpoint(op=op)
if checkpoint_info is None:
return
if op == 'model' or op == 'dict':
if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
return
else:
if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
return
shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}')
if timer is None:
timer = Timer()
current_checkpoint_info = None
if op == 'model' or op == 'dict':
if model_data.sd_model is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
else:
if model_data.sd_refiner is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
if not shared.native:
from modules import sd_hijack_inpainting
sd_hijack_inpainting.do_inpainting_hijack()
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
if state_dict is None or checkpoint_config is None:
shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"')
if current_checkpoint_info is not None:
shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore')
load_model(current_checkpoint_info, None)
return
shared.log.debug(f'Model dict loaded: {memory_stats()}')
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
timer.record("config")
shared.log.debug(f'Model config loaded: {memory_stats()}')
sd_model = None
stdout = io.StringIO()
if os.environ.get('SD_LDM_DEBUG', None) is not None:
sd_model = instantiate_from_config(sd_config.model)
else:
with contextlib.redirect_stdout(stdout):
sd_model = instantiate_from_config(sd_config.model)
for line in stdout.getvalue().splitlines():
if len(line) > 0:
shared.log.info(f'LDM: {line.strip()}')
shared.log.debug(f"Model created from config: {checkpoint_config}")
sd_model.used_config = checkpoint_config
sd_model.has_accelerate = False
timer.record("create")
ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if not ok:
model_data.sd_model = sd_model
current_checkpoint_info = None
unload_model_weights(op=op)
shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}')
if op == 'refiner':
# shared.opts.data['sd_model_refiner'] = 'None'
shared.opts.sd_model_refiner = 'None'
return
else:
shared.log.debug(f'Model weights loaded: {memory_stats()}')
timer.record("load")
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
move_model(sd_model, devices.device)
timer.record("move")
shared.log.debug(f'Model weights moved: {memory_stats()}')
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval()
if op == 'refiner':
model_data.sd_refiner = sd_model
else:
model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
timer.record("embeddings")
script_callbacks.model_loaded_callback(sd_model)
timer.record("callbacks")
shared.log.info(f"Model loaded in {timer.summary()}")
current_checkpoint_info = None
devices.torch_gc(force=True)
shared.log.info(f'Model load finished: {memory_stats()}')
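
The split avoids a circular import by keeping dependencies one-directional at import time: `modules/sd_models.py` re-exports the legacy loaders at module level, while `sd_models_legacy.py` imports back from `sd_models` only inside function bodies, deferring resolution to call time. In outline (elided bodies marked with `...`):

```python
# modules/sd_models.py -- module-level re-export keeps existing call sites working
from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config  # pylint: disable=unused-import

# modules/sd_models_legacy.py -- reverse imports are deferred into the function,
# so they resolve at call time rather than at import time
def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
    from modules.sd_models import unload_model_weights, move_model
    ...
```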

modules/sd_models_utils.py (new file, 151 lines)
View File

@ -0,0 +1,151 @@
import io
import json
import inspect
import os.path
from rich import progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
from modules import paths, shared, errors
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import
class NoWatermark:
def apply_watermark(self, img):
return img
def get_signature(cls):
signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True)
return signature.parameters
def get_call(cls):
if cls is None:
return []
signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True)
return signature.parameters
def path_to_repo(fn: str = ''):
if isinstance(fn, CheckpointInfo):
fn = fn.name
repo_id = fn.replace('\\', '/')
if 'models--' in repo_id:
repo_id = repo_id.split('models--')[-1]
repo_id = repo_id.split('/')[0]
repo_id = repo_id.split('/')
repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id)
repo_id = repo_id.replace('models--', '').replace('--', '/')
return repo_id
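# Example (hypothetical paths): the HuggingFace hub cache layout encodes the
# repo id as 'models--<org>--<name>', which the replacements above invert:
#   path_to_repo('models--runwayml--stable-diffusion-v1-5/snapshots/abc123')
#     -> 'runwayml/stable-diffusion-v1-5'
#   path_to_repo('D:\\models\\Stable-diffusion\\my-model')
#     -> 'Stable-diffusion/my-model'  (plain path: keep the last two segments)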
def convert_to_faketensors(tensor):
try:
fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
if hasattr(tensor, "weight"):
tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
return tensor
except Exception:
pass
return tensor
def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument
if not os.path.isfile(checkpoint_file):
shared.log.error(f'Load dict: path="{checkpoint_file}" not a file')
return None
try:
pl_sd = None
with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
return None
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=buffered')
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=mmap')
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=direct')
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
except Exception as e:
errors.display(e, f'Load model: {checkpoint_file}')
sd = None
return sd
def get_state_dict_from_checkpoint(pl_sd):
checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
pl_sd.clear()
pl_sd.update(sd)
return pl_sd
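# Example of the remapping above: the nested 'state_dict' wrapper is unwrapped
# first, legacy CLIP keys gain the 'text_model.' segment, and all other keys
# pass through unchanged:
#   {'state_dict': {'cond_stage_model.transformer.embeddings.position_ids': t}}
#     -> {'cond_stage_model.transformer.text_model.embeddings.position_ids': t}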
def patch_diffuser_config(sd_model, model_file):
def load_config(fn, k):
model_file = os.path.splitext(fn)[0]
cfg_file = f'{model_file}_{k}.json'
try:
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json'
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception:
pass
return {}
if sd_model is None:
return sd_model
if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower():
sd_model.unet.config.in_channels = 9
if not hasattr(sd_model, '_internal_dict'):
return sd_model
for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access
component = getattr(sd_model, c, None)
if hasattr(component, 'config'):
override = load_config(model_file, c)
updated = {}
for k, v in override.items():
if k.startswith('_'):
continue
if v != component.config.get(k, None):
if hasattr(component.config, '__frozen'):
component.config.__frozen = False # pylint: disable=protected-access
component.config[k] = v
updated[k] = v
return sd_model

View File

@ -4,6 +4,7 @@ import time
import inspect
import torch
import accelerate
from modules import shared, devices, errors
from modules.timer import process as process_timer
@ -18,6 +19,24 @@ def get_signature(cls):
return signature.parameters
def disable_offload(sd_model):
from accelerate.hooks import remove_hook_from_module
if not getattr(sd_model, 'has_accelerate', False):
return
if hasattr(sd_model, "_internal_dict"):
keys = sd_model._internal_dict.keys() # pylint: disable=protected-access
else:
keys = get_signature(sd_model).keys()
for module_name in keys: # pylint: disable=protected-access
module = getattr(sd_model, module_name, None)
if isinstance(module, torch.nn.Module):
network_layer_name = getattr(module, "network_layer_name", None)
module = remove_hook_from_module(module, recurse=True)
if network_layer_name:
module.network_layer_name = network_layer_name
sd_model.has_accelerate = False
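# Usage sketch (assumed, not part of this diff): strip the hooks that
# accelerate attached for offloading, e.g. before moving or reloading the
# model manually, while preserving any network_layer_name set by LoRA code:
#   disable_offload(shared.sd_model)
#   shared.sd_model.to(devices.device)  # safe to move once hooks are removed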
def set_accelerate(sd_model):
def set_accelerate_to_module(model):
if hasattr(model, "pipe"):
@ -36,7 +55,7 @@ def set_accelerate(sd_model):
set_accelerate_to_module(sd_model.decoder_pipe)
def set_diffuser_offload(sd_model, op: str = 'model'):
def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False):
t0 = time.time()
if not shared.native:
shared.log.warning('Attempting to use offload with backend=original')
@ -50,13 +69,13 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
if shared.sd_model_type in should_offload:
shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model')
else:
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if hasattr(sd_model, 'maybe_free_model_hooks'):
sd_model.maybe_free_model_hooks()
sd_model.has_accelerate = False
if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"):
try:
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
shared.opts.diffusers_move_base = False
shared.opts.diffusers_move_unet = False

View File

@ -211,14 +211,12 @@ def writefile(data, filename, mode='w', silent=False, atomic=False):
# early select backend
default_backend = 'diffusers'
early_opts = readfile(cmd_opts.config, silent=True)
early_backend = early_opts.get('sd_backend', default_backend)
backend = Backend.DIFFUSERS if early_backend.lower() == 'diffusers' else Backend.ORIGINAL
early_backend = early_opts.get('sd_backend', 'diffusers')
backend = Backend.ORIGINAL if early_backend.lower() == 'original' else Backend.DIFFUSERS
if cmd_opts.backend is not None: # override with args
backend = Backend.DIFFUSERS if cmd_opts.backend.lower() == 'diffusers' else Backend.ORIGINAL
backend = Backend.ORIGINAL if cmd_opts.backend.lower() == 'original' else Backend.DIFFUSERS
if cmd_opts.use_openvino: # override for openvino
backend = Backend.DIFFUSERS
from modules.intel.openvino import get_device_list as get_openvino_device_list # pylint: disable=ungrouped-imports
elif cmd_opts.use_ipex or devices.has_xpu():
from modules.intel.ipex import ipex_init
@@ -226,15 +224,14 @@ elif cmd_opts.use_ipex or devices.has_xpu():
if not ok:
log.error(f'IPEX initialization failed: {e}')
elif cmd_opts.use_directml:
name = 'directml'
from modules.dml import directml_init
ok, e = directml_init()
if not ok:
log.error(f'DirectML initialization failed: {e}')
devices.backend = devices.get_backend(cmd_opts)
devices.device = devices.get_optimal_device()
cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2)
mem_stat = memory_stats()
cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2)
gpu_memory = mem_stat['gpu']['total'] if "gpu" in mem_stat else 0
native = backend == Backend.DIFFUSERS
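Condensed, the backend resolution above now reads: config value (defaulting to diffusers), overridden by the CLI flag, with OpenVINO forcing diffusers; a sketch of that precedence:

backend = Backend.ORIGINAL if early_opts.get('sd_backend', 'diffusers').lower() == 'original' else Backend.DIFFUSERS
if cmd_opts.backend is not None:  # explicit CLI flag wins over config
    backend = Backend.ORIGINAL if cmd_opts.backend.lower() == 'original' else Backend.DIFFUSERS
if cmd_opts.use_openvino:         # openvino only runs on the diffusers path
    backend = Backend.DIFFUSERS
native = backend == Backend.DIFFUSERS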
if not files_cache.do_cache_folders:
@@ -475,7 +472,7 @@ def get_default_modes():
startup_offload_mode, startup_cross_attention, startup_sdp_options = get_default_modes()
options_templates.update(options_section(('sd', "Models & Loading"), {
"sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }),
"sd_backend": OptionInfo('diffusers', "Execution backend", gr.Radio, {"choices": ['diffusers', 'original'] }),
"diffusers_pipeline": OptionInfo('Autodetect', 'Model pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()), "visible": native}),
"sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints),
"sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
@@ -513,7 +510,7 @@ options_templates.update(options_section(('vae_encoder', "Variable Auto Encoder"
"diffusers_vae_tile_overlap": OptionInfo(0.25, "VAE tile overlap", gr.Slider, {"minimum": 0, "maximum": 0.95, "step": 0.05 }),
"sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}),
"nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox),
"rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"),
"rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values", gr.Checkbox, {"visible": not native}),
}))
options_templates.update(options_section(('text_encoder', "Text Encoder"), {


@@ -6,8 +6,8 @@ from modules import shared, scripts, masking # pylint: disable=ungrouped-imports
gr_height = None
max_units = shared.opts.control_max_units
debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: CONTROL')
debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None
debug_log = shared.log.trace if debug else lambda *args, **kwargs: None
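Splitting the boolean flag from the logger lets call sites test `debug` cheaply while `debug_log` stays a no-op unless tracing is enabled; sketched usage (`gather_diagnostics` is hypothetical):

# enable with: SD_CONTROL_DEBUG=1 before launching
if debug:
    payload = gather_diagnostics()  # hypothetical: only worth computing when tracing
    debug_log(f'Control trace: {payload}')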
# state variables
busy = False # used to synchronize select_input and generate_click
@@ -127,7 +127,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i
busy = False
# debug('Control input: none')
return [gr.Tabs.update(), None, '']
debug(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}')
debug_log(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}')
input_type = type(selected_input)
input_mask = None
status = 'Control input | Unknown'
@@ -168,7 +168,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i
res = [gr.Tabs.update(selected='out-gallery'), input_mask, status]
else: # unknown
input_source = None
shared.log.debug(f'Control input: type={input_type} input={input_source}')
debug_log(f'Control input: type={input_type} input={input_source}')
# init inputs: optional
if init_type == 0: # Control only
input_init = None
@@ -176,7 +176,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i
input_init = None
elif init_type == 2: # Separate init image
input_init = [init_image]
debug(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}')
debug_log(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}')
busy = False
return res
@@ -191,7 +191,7 @@ def video_type_change(video_type):
def copy_input(mode_from, mode_to, input_image, input_resize, input_inpaint):
debug(f'Control transfer input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}')
debug_log(f'Control transfer input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}')
def getimg(ctrl):
if ctrl is None:
return None


@@ -3,10 +3,8 @@ import os
import time
import contextlib
import gradio as gr
import numpy as np
from PIL import Image
from modules import shared, devices, errors, scripts, processing, processing_helpers, sd_models
from modules.api.api import decode_base64_to_image
debug = os.environ.get('SD_PULID_DEBUG', None) is not None
@@ -59,12 +57,16 @@ class Script(scripts.Script):
xyz_classes.axis_options.append(option)
def decode_image(self, b64):
from modules.api.api import decode_base64_to_image
return decode_base64_to_image(b64)
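Moving the import inside `decode_image` defers loading `modules.api.api` until first use, so the script imports cleanly even if the API module would otherwise create a circular import at load time; usage is unchanged:

image = self.decode_image(b64_string)  # b64_string: a base64-encoded image payload, per decode_base64_to_image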
def load_images(self, files):
uploaded_images.clear()
for file in files or []:
try:
if isinstance(file, str):
image = decode_base64_to_image(file)
image = self.decode_image(file)
elif isinstance(file, Image.Image):
image = file
elif isinstance(file, dict) and 'name' in file:
@@ -113,16 +115,17 @@ class Script(scripts.Script):
version: str = 'v1.1'
): # pylint: disable=arguments-differ, unused-argument
images = []
import numpy as np
try:
if gallery is None or (isinstance(gallery, list) and len(gallery) == 0):
images = getattr(p, 'pulid_images', uploaded_images)
images = [decode_base64_to_image(image) if isinstance(image, str) else image for image in images]
images = [self.decode_image(image) if isinstance(image, str) else image for image in images]
elif isinstance(gallery[0], dict):
images = [Image.open(f['name']) for f in gallery]
elif isinstance(gallery, str):
images = [decode_base64_to_image(gallery)]
images = [self.decode_image(gallery)]
elif isinstance(gallery[0], str):
images = [decode_base64_to_image(f) for f in gallery]
images = [self.decode_image(f) for f in gallery]
else:
images = gallery
images = [np.array(image) for image in images]
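Condensed, the branches above normalize every accepted gallery shape to a list of numpy arrays:

# accepted inputs (from the branches above):
#   None or []           -> p.pulid_images or previously uploaded images
#   [{'name': path}]     -> opened via PIL
#   'base64...'          -> decoded via self.decode_image
#   ['base64...', ...]   -> each string decoded
#   [PIL.Image, ...]     -> used as-is
images = [np.array(image) for image in images]  # final: list of HxWxC arrays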

webui.py

@@ -11,13 +11,9 @@ import contextlib
from threading import Thread
import modules.hashes
import modules.loader
import torch # pylint: disable=wrong-import-order
from modules import timer, errors, paths # pylint: disable=unused-import
from installer import log, git_commit, custom_excepthook
# import ldm.modules.encoders.modules # pylint: disable=unused-import, wrong-import-order
from modules import shared, extensions, gr_tempdir, modelloader # pylint: disable=ungrouped-imports
from modules import extra_networks, ui_extra_networks # pylint: disable=ungrouped-imports
from modules.paths import create_paths
from modules import timer, paths, shared, extensions, gr_tempdir, modelloader
from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import
import modules.devices
import modules.sd_checkpoint
@@ -33,23 +29,28 @@ import modules.ui
import modules.txt2img
import modules.img2img
import modules.upscaler
import modules.extra_networks
import modules.ui_extra_networks
import modules.textual_inversion.textual_inversion
import modules.hypernetworks.hypernetwork
import modules.script_callbacks
from modules.api.middleware import setup_middleware
from modules.shared import cmd_opts, opts # pylint: disable=unused-import
import modules.api.middleware
if not modules.loader.initialized:
timer.startup.record("libraries")
import modules.sd_hijack # runs conditional load of ldm if not shared.native
timer.startup.record("ldm")
modules.loader.initialized = True
sys.excepthook = custom_excepthook
local_url = None
state = shared.state
backend = shared.backend
if not modules.loader.initialized:
timer.startup.record("libraries")
if cmd_opts.server_name:
server_name = cmd_opts.server_name
if shared.cmd_opts.server_name:
server_name = shared.cmd_opts.server_name
else:
server_name = "0.0.0.0" if cmd_opts.listen else None
server_name = "0.0.0.0" if shared.cmd_opts.listen else None
fastapi_args = {
"version": f'0.0.{git_commit}',
"title": "SD.Next",
@@ -60,30 +61,12 @@ fastapi_args = {
# "redoc_url": "/redocs" if cmd_opts.docs else None,
}
import modules.sd_hijack
timer.startup.record("ldm")
modules.loader.initialized = True
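The one-time guard relies on a module-level flag in `modules/loader.py`, so the heavy `sd_hijack` import (which conditionally loads ldm) runs once per process even when this file is re-executed on an in-place restart; a minimal sketch of the pattern under that assumption:

# modules/loader.py (assumption: flag defined at module scope)
initialized = False

# webui.py, at import time
if not modules.loader.initialized:
    import modules.sd_hijack             # heavy: conditionally loads ldm when not shared.native
    timer.startup.record("ldm")
    modules.loader.initialized = True    # later re-executions skip the import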
def check_rollback_vae():
if shared.cmd_opts.rollback_vae:
if not torch.cuda.is_available():
log.error("Rollback VAE functionality requires compatible GPU")
shared.cmd_opts.rollback_vae = False
elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'):
log.error("Rollback VAE functionality requires Torch 2.1 or higher")
shared.cmd_opts.rollback_vae = False
elif 0 < torch.cuda.get_device_capability()[0] < 8:
log.error('Rollback VAE functionality device capabilities not met')
shared.cmd_opts.rollback_vae = False
def initialize():
log.debug('Initializing')
modules.sd_checkpoint.init_metadata()
modules.hashes.init_cache()
check_rollback_vae()
log.debug(f'Huggingface cache: path="{shared.opts.hfcache_dir}"')
@@ -136,20 +119,20 @@ def initialize():
shared.reload_hypernetworks()
timer.startup.record("hypernetworks")
ui_extra_networks.initialize()
ui_extra_networks.register_pages()
extra_networks.initialize()
extra_networks.register_default_extra_networks()
modules.ui_extra_networks.initialize()
modules.ui_extra_networks.register_pages()
modules.extra_networks.initialize()
modules.extra_networks.register_default_extra_networks()
timer.startup.record("networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None:
if shared.cmd_opts.tls_keyfile is not None and shared.cmd_opts.tls_certfile is not None:
try:
if not os.path.exists(cmd_opts.tls_keyfile):
if not os.path.exists(shared.cmd_opts.tls_keyfile):
log.error("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
log.error(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
if not os.path.exists(shared.cmd_opts.tls_certfile):
log.error(f"Invalid path to TLS certfile: '{shared.cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
shared.cmd_opts.tls_keyfile = shared.cmd_opts.tls_certfile = None
log.error("TLS setup invalid, running webui without TLS")
else:
log.info("Running with TLS")
@@ -231,7 +214,7 @@ def start_common():
log.info(f'Using data path: {shared.cmd_opts.data_dir}')
if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models':
log.info(f'Models path: {shared.cmd_opts.models_dir}')
create_paths(shared.opts)
paths.create_paths(shared.opts)
async_policy()
initialize()
try:
@@ -251,20 +234,20 @@ def start_ui():
timer.startup.record("before-ui")
shared.demo = modules.ui.create_ui(timer.startup)
timer.startup.record("ui")
if cmd_opts.disable_queue:
if shared.cmd_opts.disable_queue:
log.info('Server queues disabled')
shared.demo.progress_tracking = False
else:
shared.demo.queue(concurrency_count=64)
gradio_auth_creds = []
if cmd_opts.auth:
gradio_auth_creds += [x.strip() for x in cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()]
if cmd_opts.auth_file:
if not os.path.exists(cmd_opts.auth_file):
log.error(f"Invalid path to auth file: '{cmd_opts.auth_file}'")
if shared.cmd_opts.auth:
gradio_auth_creds += [x.strip() for x in shared.cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()]
if shared.cmd_opts.auth_file:
if not os.path.exists(shared.cmd_opts.auth_file):
log.error(f"Invalid path to auth file: '{shared.cmd_opts.auth_file}'")
else:
with open(cmd_opts.auth_file, 'r', encoding="utf8") as file:
with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
if len(gradio_auth_creds) > 0:
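From the parsing above, credentials come either from the auth option as a quoted comma-separated string or from an auth file holding comma-separated `user:pass` pairs, one or more per line; a hypothetical file and its parse result:

# hypothetical auth file contents:
#   alice:secret1, bob:secret2
#   carol:secret3
gradio_auth_creds = ['alice:secret1', 'bob:secret2', 'carol:secret3']
# later turned into (user, password) tuples via cred.split(':') when passed to demo.launch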
@@ -273,19 +256,19 @@ def start_ui():
global local_url # pylint: disable=global-statement
stdout = io.StringIO()
allowed_paths = [os.path.dirname(__file__)]
if cmd_opts.data_dir is not None and os.path.isdir(cmd_opts.data_dir):
allowed_paths.append(cmd_opts.data_dir)
if cmd_opts.allowed_paths is not None:
allowed_paths += [p for p in cmd_opts.allowed_paths if os.path.isdir(p)]
if shared.cmd_opts.data_dir is not None and os.path.isdir(shared.cmd_opts.data_dir):
allowed_paths.append(shared.cmd_opts.data_dir)
if shared.cmd_opts.allowed_paths is not None:
allowed_paths += [p for p in shared.cmd_opts.allowed_paths if os.path.isdir(p)]
shared.log.debug(f'Root paths: {allowed_paths}')
with contextlib.redirect_stdout(stdout):
app, local_url, share_url = shared.demo.launch( # app is FastAPI(Starlette) instance
share=cmd_opts.share,
share=shared.cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port if cmd_opts.port != 7860 else None,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=not cmd_opts.tls_selfsign,
server_port=shared.cmd_opts.port if shared.cmd_opts.port != 7860 else None,
ssl_keyfile=shared.cmd_opts.tls_keyfile,
ssl_certfile=shared.cmd_opts.tls_certfile,
ssl_verify=not shared.cmd_opts.tls_selfsign,
debug=False,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
prevent_thread_lock=True,
@@ -295,24 +278,24 @@ def start_ui():
favicon_path='html/favicon.svg',
allowed_paths=allowed_paths,
app_kwargs=fastapi_args,
_frontend=True and cmd_opts.share,
_frontend=True and shared.cmd_opts.share,
)
if cmd_opts.data_dir is not None:
gr_tempdir.register_tmp_file(shared.demo, os.path.join(cmd_opts.data_dir, 'x'))
if shared.cmd_opts.data_dir is not None:
gr_tempdir.register_tmp_file(shared.demo, os.path.join(shared.cmd_opts.data_dir, 'x'))
shared.log.info(f'Local URL: {local_url}')
if cmd_opts.docs:
if shared.cmd_opts.docs:
shared.log.info(f'API Docs: {local_url[:-1]}/docs') # pylint: disable=unsubscriptable-object
shared.log.info(f'API ReDocs: {local_url[:-1]}/redocs') # pylint: disable=unsubscriptable-object
if share_url is not None:
shared.log.info(f'Share URL: {share_url}')
# shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}')
shared.demo.server.wants_restart = False
setup_middleware(app, cmd_opts)
modules.api.middleware.setup_middleware(app, shared.cmd_opts)
if cmd_opts.subpath:
if shared.cmd_opts.subpath:
import gradio
gradio.mount_gradio_app(app, shared.demo, path=f"/{cmd_opts.subpath}")
shared.log.info(f'Redirector mounted: /{cmd_opts.subpath}')
gradio.mount_gradio_app(app, shared.demo, path=f"/{shared.cmd_opts.subpath}")
shared.log.info(f'Redirector mounted: /{shared.cmd_opts.subpath}')
timer.startup.record("launch")
@@ -320,7 +303,7 @@ def start_ui():
shared.api = create_api(app)
timer.startup.record("api")
ui_extra_networks.init_api(app)
modules.ui_extra_networks.init_api(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
timer.startup.record("app-started")
@@ -345,7 +328,7 @@ def webui(restart=False):
modules.sd_models.write_metadata()
load_model()
shared.opts.save(shared.config_filename)
if cmd_opts.profile:
if shared.cmd_opts.profile:
for k, v in modules.script_callbacks.callback_map.items():
shared.log.debug(f'Registered callbacks: {k}={len(v)} {[c.script for c in v]}')
debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
@@ -357,7 +340,7 @@ def webui(restart=False):
debug(f' {m}')
modules.script_callbacks.print_timers()
if cmd_opts.profile:
if shared.cmd_opts.profile:
log.info(f"Launch time: {timer.launch.summary(min_time=0)}")
log.info(f"Installer time: {timer.init.summary(min_time=0)}")
log.info(f"Startup time: {timer.startup.summary(min_time=0)}")
@@ -374,8 +357,8 @@ def webui(restart=False):
continue
logger.handlers = log.handlers
# autolaunch only on initial start
if (shared.opts.autolaunch or cmd_opts.autolaunch) and local_url is not None:
cmd_opts.autolaunch = False
if (shared.opts.autolaunch or shared.cmd_opts.autolaunch) and local_url is not None:
shared.cmd_opts.autolaunch = False
shared.log.info('Launching browser')
import webbrowser
webbrowser.open(local_url, new=2, autoraise=True)
@@ -390,7 +373,7 @@ def api_only():
start_common()
from fastapi import FastAPI
app = FastAPI(**fastapi_args)
setup_middleware(app, cmd_opts)
modules.api.middleware.setup_middleware(app, shared.cmd_opts)
shared.api = create_api(app)
shared.api.wants_restart = False
modules.script_callbacks.app_started_callback(None, app)
@@ -401,7 +384,7 @@ def api_only():
if __name__ == "__main__":
if cmd_opts.api_only:
if shared.cmd_opts.api_only:
api_only()
else:
webui()

wiki

@@ -1 +1 @@
Subproject commit 7c2400e9dc5dee3c52eac6bbfa88352f7815454a
Subproject commit 29e37ad766904bc04f9e9701c2503a3f0898964a