# automatic/modules/sd_models.py


import io
import sys
import time
import json
import copy
import inspect
import logging
import contextlib
import os.path
from enum import Enum
import diffusers
import diffusers.loaders.single_file_utils
from rich import progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
from modules.timer import Timer, process as process_timer
from modules.memstats import memory_stats
from modules.modeldata import model_data
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
sd_metadata = None
sd_metadata_pending = 0
sd_metadata_timer = 0
debug_move = shared.log.trace if os.environ.get('SD_MOVE_DEBUG', None) is not None else lambda *args, **kwargs: None
debug_load = os.environ.get('SD_LOAD_DEBUG', None)
debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
diffusers_version = int(diffusers.__version__.split('.')[1]) # minor version only, e.g. 0.28.x -> 28
checkpoint_tiles = checkpoint_titles # legacy compatibility
class NoWatermark:
def apply_watermark(self, img):
return img
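# Read a checkpoint state dict from disk; supports .safetensors and pickled .ckpt files,
# using buffered reads when shared.opts.stream_load is set, otherwise mmap/direct loading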
def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument
if not os.path.isfile(checkpoint_file):
shared.log.error(f'Load dict: path="{checkpoint_file}" not a file')
return None
try:
pl_sd = None
with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
return None
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=buffered')
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
# shared.log.debug('Model weights loading: type=safetensors mode=mmap')
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
# shared.log.debug('Model weights loading: type=checkpoint mode=direct')
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
except Exception as e:
errors.display(e, f'Load model: {checkpoint_file}')
sd = None
return sd
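# Unwrap a nested "state_dict" entry and remap legacy cond_stage_model key prefixes to the transformers text_model layout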
def get_state_dict_from_checkpoint(pl_sd):
checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
pl_sd.clear()
pl_sd.update(sd)
return pl_sd
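# Read the state dict for a checkpoint; the quoted blocks below are the disabled in-memory checkpoint cache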
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if not os.path.isfile(checkpoint_info.filename):
return None
"""
if checkpoint_info in checkpoints_loaded:
shared.log.info("Load model: cache")
checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
return checkpoints_loaded[checkpoint_info]
"""
res = read_state_dict(checkpoint_info.filename, what='model')
"""
if shared.opts.sd_checkpoint_cache > 0 and not shared.native:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = res
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
"""
timer.record("load")
return res
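# Original-backend weight loader: apply the state dict to the ldm model, handle half/channels-last conversion and load the matching VAE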
def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer):
_pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model')
shared.log.debug(f'Load model: memory={memory_stats()}')
timer.record("hash")
if model_data.sd_dict == 'None':
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
try:
model.load_state_dict(state_dict, strict=False)
except Exception as e:
shared.log.error(f'Load model: path="{checkpoint_info.filename}"')
shared.log.error(' '.join(str(e).splitlines()[:2]))
return False
del state_dict
timer.record("apply")
if shared.opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("channels")
if not shared.opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.opts.upcast_sampling and depth_model:
model.depth_model = None
model.half()
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
if shared.opts.cuda_cast_unet:
devices.dtype_unet = model.model.diffusion_model.dtype
else:
model.model.diffusion_model.to(devices.dtype_unet)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = checkpoint_info.calculate_shorthash()
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.is_sdxl = False # a1111 compatibility item
model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("vae")
return True
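# Patch the ldm OmegaConf config: use_ema default, fp16 flags, xformers attention fallback and UnCLIP karlo path override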
def repair_config(sd_config):
if "use_ema" not in sd_config.model.params:
sd_config.model.params.use_ema = False
if shared.opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def change_backend():
shared.log.info(f'Backend changed: from={shared.backend} to={shared.opts.sd_backend}')
shared.log.warning('Full server restart required to apply all changes')
unload_model_weights()
shared.backend = shared.Backend.ORIGINAL if shared.opts.sd_backend == 'original' else shared.Backend.DIFFUSERS
shared.native = shared.backend == shared.Backend.DIFFUSERS
from modules.sd_samplers import list_samplers
list_samplers()
list_models()
from modules.sd_vae import refresh_vae_list
refresh_vae_list()
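# Copy bookkeeping attributes (checkpoint info, hash, embedding db, attention name, default scheduler, compatibility flags) from one pipeline to another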
def copy_diffuser_options(new_pipe, orig_pipe):
new_pipe.sd_checkpoint_info = getattr(orig_pipe, 'sd_checkpoint_info', None)
new_pipe.sd_model_checkpoint = getattr(orig_pipe, 'sd_model_checkpoint', None)
new_pipe.embedding_db = getattr(orig_pipe, 'embedding_db', None)
new_pipe.sd_model_hash = getattr(orig_pipe, 'sd_model_hash', None)
new_pipe.has_accelerate = getattr(orig_pipe, 'has_accelerate', False)
new_pipe.current_attn_name = getattr(orig_pipe, 'current_attn_name', None)
new_pipe.default_scheduler = getattr(orig_pipe, 'default_scheduler', None)
new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item
new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False)
new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True)
if new_pipe.has_accelerate:
set_accelerate(new_pipe)
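# Apply VAE, attention, eval, quantization and channels-last options to a loaded diffusers pipeline, optionally followed by offload setup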
def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True):
if sd_model is None:
shared.log.warning(f'{op} is not loaded')
return
if hasattr(sd_model, "watermark"):
sd_model.watermark = NoWatermark()
if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
sd_model.has_accelerate = False
if hasattr(sd_model, "vae"):
if vae is not None:
sd_model.vae = vae
shared.log.debug(f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"')
if shared.opts.diffusers_vae_upcast != 'default':
sd_model.vae.config.force_upcast = True if shared.opts.diffusers_vae_upcast == 'true' else False
shared.log.debug(f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}')
if shared.opts.no_half_vae:
devices.dtype_vae = torch.float32
sd_model.vae.to(devices.dtype_vae)
shared.log.debug(f'Setting {op}: component=VAE no-half=True')
if hasattr(sd_model, "enable_vae_slicing"):
if shared.opts.diffusers_vae_slicing:
shared.log.debug(f'Setting {op}: component=VAE slicing=True')
sd_model.enable_vae_slicing()
else:
sd_model.disable_vae_slicing()
if hasattr(sd_model, "enable_vae_tiling"):
if shared.opts.diffusers_vae_tiling:
shared.log.debug(f'Setting {op}: component=VAE tiling=True')
sd_model.enable_vae_tiling()
else:
sd_model.disable_vae_tiling()
if hasattr(sd_model, "vqvae"):
shared.log.debug(f'Setting {op}: component=VQVAE upcast=True')
sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16
set_diffusers_attention(sd_model)
if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'):
try:
sd_model.fuse_qkv_projections()
shared.log.debug(f'Setting {op}: fused-qkv=True')
except Exception as e:
shared.log.error(f'Setting {op}: fused-qkv=True {e}')
if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'fuse_qkv_projections'):
try:
sd_model.transformer.fuse_qkv_projections()
shared.log.debug(f'Setting {op}: fused-qkv=True')
except Exception as e:
shared.log.error(f'Setting {op}: fused-qkv=True {e}')
if shared.opts.diffusers_eval:
def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
if hasattr(model, "requires_grad_"):
model.requires_grad_(False)
model.eval()
return model
sd_model = sd_models_compile.apply_compile_to_model(sd_model, eval_model, ["Model", "VAE", "Text Encoder"], op="eval")
if len(shared.opts.torchao_quantization) > 0 and shared.opts.torchao_quantization_mode != 'post':
sd_model = sd_models_compile.torchao_quantization(sd_model)
if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'):
shared.log.debug(f'Setting {op}: channels-last=True')
sd_model.unet.to(memory_format=torch.channels_last)
if offload:
set_diffuser_offload(sd_model, op)
def set_accelerate_to_module(model):
if hasattr(model, "pipe"):
set_accelerate_to_module(model.pipe)
if hasattr(model, "_internal_dict"):
for k in model._internal_dict.keys(): # pylint: disable=protected-access
component = getattr(model, k, None)
if isinstance(component, torch.nn.Module):
component.has_accelerate = True
def set_accelerate(sd_model):
sd_model.has_accelerate = True
set_accelerate_to_module(sd_model)
if hasattr(sd_model, "prior_pipe"):
set_accelerate_to_module(sd_model.prior_pipe)
if hasattr(sd_model, "decoder_pipe"):
set_accelerate_to_module(sd_model.decoder_pipe)
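# Apply the configured offload mode (none/model/sequential/balanced) to a diffusers pipeline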
def set_diffuser_offload(sd_model, op: str = 'model'):
if not shared.native:
shared.log.warning('Attempting to use offload with backend=original')
return
if sd_model is None:
shared.log.warning(f'{op} is not loaded')
return
if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
sd_model.has_accelerate = False
if hasattr(sd_model, 'maybe_free_model_hooks') and shared.opts.diffusers_offload_mode == "none":
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
sd_model.maybe_free_model_hooks()
sd_model.has_accelerate = False
if hasattr(sd_model, "enable_model_cpu_offload") and shared.opts.diffusers_offload_mode == "model":
try:
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
shared.opts.diffusers_move_base = False
shared.opts.diffusers_move_unet = False
shared.opts.diffusers_move_refiner = False
shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled')
if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access
sd_model.enable_model_cpu_offload(device=devices.device)
else:
sd_model.maybe_free_model_hooks()
set_accelerate(sd_model)
except Exception as e:
shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
if hasattr(sd_model, "enable_sequential_cpu_offload") and shared.opts.diffusers_offload_mode == "sequential":
try:
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
shared.opts.diffusers_move_base = False
shared.opts.diffusers_move_unet = False
shared.opts.diffusers_move_refiner = False
shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled')
if sd_model.has_accelerate:
if op == "vae": # reapply sequential offload to vae
from accelerate import cpu_offload
sd_model.vae.to("cpu")
cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access
else:
pass # do nothing if offload is already applied
else:
sd_model.enable_sequential_cpu_offload(device=devices.device)
set_accelerate(sd_model)
except Exception as e:
shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
if shared.opts.diffusers_offload_mode == "balanced":
try:
shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} threshold={shared.opts.diffusers_offload_max_gpu_memory} limit={shared.opts.cuda_mem_fraction}')
sd_model = apply_balanced_offload(sd_model)
except Exception as e:
shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
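# Balanced offload: move pipeline modules to CPU and attach an accelerate hook that dispatches each module
# back to the execution device on demand, within the configured GPU/CPU memory limits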
def apply_balanced_offload(sd_model):
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.hooks import add_hook_to_module, remove_hook_from_module, ModelHook
excluded = ['OmniGenPipeline']
if sd_model.__class__.__name__ in excluded:
return sd_model
class dispatch_from_cpu_hook(ModelHook):
def init_hook(self, module):
return module
def pre_forward(self, module, *args, **kwargs):
if devices.normalize_device(module.device) != devices.normalize_device(devices.device):
device_index = torch.device(devices.device).index
if device_index is None:
device_index = 0
max_memory = {
device_index: f"{shared.opts.diffusers_offload_max_gpu_memory}GiB",
"cpu": f"{shared.opts.diffusers_offload_max_cpu_memory}GiB",
}
device_map = infer_auto_device_map(module, max_memory=max_memory)
module = remove_hook_from_module(module, recurse=True)
offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
module = dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
return args, kwargs
def post_forward(self, module, output):
return output
def detach_hook(self, module):
return module
def apply_balanced_offload_to_module(pipe):
if hasattr(pipe, "pipe"):
apply_balanced_offload_to_module(pipe.pipe)
if hasattr(pipe, "_internal_dict"):
keys = pipe._internal_dict.keys() # pylint: disable=protected-access
else:
keys = get_signature(shared.sd_model).keys()
for module_name in keys: # pylint: disable=protected-access
module = getattr(pipe, module_name, None)
if isinstance(module, torch.nn.Module):
checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None
if checkpoint_name is None:
checkpoint_name = pipe.__class__.__name__
offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
module = remove_hook_from_module(module, recurse=True)
try:
module = module.to("cpu")
module.offload_dir = offload_dir
network_layer_name = getattr(module, "network_layer_name", None)
module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
if network_layer_name:
module.network_layer_name = network_layer_name
except Exception as e:
if 'bitsandbytes' not in str(e):
shared.log.error(f'Balanced offload: module={module_name} {e}')
devices.torch_gc(fast=True)
apply_balanced_offload_to_module(sd_model)
if hasattr(sd_model, "pipe"):
apply_balanced_offload_to_module(sd_model.pipe)
if hasattr(sd_model, "prior_pipe"):
apply_balanced_offload_to_module(sd_model.prior_pipe)
if hasattr(sd_model, "decoder_pipe"):
apply_balanced_offload_to_module(sd_model.decoder_pipe)
set_accelerate(sd_model)
return sd_model
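# Move a model or pipeline to the target device, handling accelerate hooks, meta tensors and quantized params, and record the time spent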
def move_model(model, device=None, force=False):
if model is None or device is None:
return
if not shared.native:
if type(model).__name__ == 'LatentDiffusion':
model = model.to(device)
if hasattr(model, 'model'):
model.model = model.model.to(device)
if hasattr(model, 'first_stage_model'):
model.first_stage_model = model.first_stage_model.to(device)
if hasattr(model, 'cond_stage_model'):
model.cond_stage_model = model.cond_stage_model.to(device)
devices.torch_gc()
return
if hasattr(model, 'pipe'):
move_model(model.pipe, device, force)
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
if getattr(model, 'vae', None) is not None and get_diffusers_task(model) != DiffusersTaskType.TEXT_2_IMAGE:
if device == devices.device and model.vae.device.type != "meta": # force vae back to gpu if not in txt2img mode
model.vae.to(device)
if hasattr(model.vae, '_hf_hook'):
debug_move(f'Model move: to={device} class={model.vae.__class__} fn={fn}') # pylint: disable=protected-access
model.vae._hf_hook.execution_device = device # pylint: disable=protected-access
if hasattr(model, "components"): # accelerate patch
for name, m in model.components.items():
if not hasattr(m, "_hf_hook"): # not accelerate hook
break
if not isinstance(m, torch.nn.Module) or name in model._exclude_from_cpu_offload: # pylint: disable=protected-access
continue
for module in m.modules():
if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None): # pylint: disable=protected-access
try:
module._hf_hook.execution_device = device # pylint: disable=protected-access
except Exception as e:
if os.environ.get('SD_MOVE_DEBUG', None):
shared.log.error(f'Model move execution device: device={device} {e}')
if getattr(model, 'has_accelerate', False) and not force:
return
if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device):
return
try:
t0 = time.time()
try:
if hasattr(model, 'to'):
model.to(device)
if hasattr(model, "prior_pipe"):
model.prior_pipe.to(device)
except Exception as e0:
if 'Cannot copy out of meta tensor' in str(e0) or 'must be Tensor, not NoneType' in str(e0):
if hasattr(model, "components"):
for _name, component in model.components.items():
if hasattr(component, 'modules'):
for module in component.modules():
try:
if hasattr(module, 'to'):
module.to(device)
except Exception as e2:
if 'Cannot copy out of meta tensor' in str(e2):
if os.environ.get('SD_MOVE_DEBUG', None):
shared.log.warning(f'Model move meta: module={module.__class__}')
module.to_empty(device=device)
elif 'enable_sequential_cpu_offload' in str(e0):
pass # ignore model move if sequential offload is enabled
elif 'Params4bit' in str(e0) or 'Params8bit' in str(e0):
pass # ignore model move if quantization is enabled
else:
raise e0
t1 = time.time()
except Exception as e1:
t1 = time.time()
shared.log.error(f'Model move: device={device} {e1}')
if 'move' not in process_timer.records:
process_timer.records['move'] = 0
process_timer.records['move'] += t1 - t0
if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 1:
shared.log.debug(f'Model move: device={device} class={model.__class__.__name__} accelerate={getattr(model, "has_accelerate", False)} fn={fn} time={t1-t0:.2f}') # pylint: disable=protected-access
devices.torch_gc()
def move_base(model, device):
if hasattr(model, 'transformer'):
key = 'transformer'
elif hasattr(model, 'unet'):
key = 'unet'
else:
shared.log.warning(f'Model move: model={model.__class__} device={device} key=unknown')
return None
shared.log.debug(f'Model move: module={key} device={device}')
model = getattr(model, key)
R = model.device
move_model(model, device)
return R
def patch_diffuser_config(sd_model, model_file):
def load_config(fn, k):
model_file = os.path.splitext(fn)[0]
cfg_file = f'{model_file}_{k}.json'
try:
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json'
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception:
pass
return {}
if sd_model is None:
return sd_model
if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower():
if debug_load:
shared.log.debug('Model config patch: type=inpaint')
sd_model.unet.config.in_channels = 9
if not hasattr(sd_model, '_internal_dict'):
return sd_model
for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access
component = getattr(sd_model, c, None)
if hasattr(component, 'config'):
if debug_load:
shared.log.debug(f'Model config: component={c} config={component.config}')
override = load_config(model_file, c)
updated = {}
for k, v in override.items():
if k.startswith('_'):
continue
if v != component.config.get(k, None):
if hasattr(component.config, '__frozen'):
component.config.__frozen = False # pylint: disable=protected-access
component.config[k] = v
updated[k] = v
if updated and debug_load:
shared.log.debug(f'Model config: component={c} override={updated}')
return sd_model
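# Initial load when --ckpt points at a diffusers folder: resolve or download the hub model and load it with DiffusionPipeline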
def load_diffuser_initial(diffusers_load_config, op='model'):
sd_model = None
checkpoint_info = None
ckpt_basename = os.path.basename(shared.cmd_opts.ckpt)
model_name = modelloader.find_diffuser(ckpt_basename)
if model_name is not None:
shared.log.info(f'Load model {op}: path="{model_name}"')
model_file = modelloader.download_diffusers_model(hub_id=model_name, variant=diffusers_load_config.get('variant', None))
try:
shared.log.debug(f'Load {op}: config={diffusers_load_config}')
sd_model = diffusers.DiffusionPipeline.from_pretrained(model_file, **diffusers_load_config)
except Exception as e:
shared.log.error(f'Failed loading model: {model_file} {e}')
errors.display(e, f'Load {op}: path="{model_file}"')
return None, None
list_models() # rescan for downloaded model
checkpoint_info = CheckpointInfo(model_name)
return sd_model, checkpoint_info
def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='model'):
sd_model = None
try:
if model_type in ['Stable Cascade']: # forced pipeline
from modules.model_stablecascade import load_cascade_combined
sd_model = load_cascade_combined(checkpoint_info, diffusers_load_config)
elif model_type in ['InstaFlow']: # forced pipeline
pipeline = diffusers.utils.get_class_from_dynamic_module('instaflow_one_step', module_file='pipeline.py')
sd_model = pipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
elif model_type in ['SegMoE']: # forced pipeline
from modules.segmoe.segmoe_model import SegMoEPipeline
sd_model = SegMoEPipeline(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model = sd_model.pipe # segmoe pipe does its stuff in __init__ and __call__ is the original pipeline
elif model_type in ['PixArt-Sigma']: # forced pipeline
from modules.model_pixart import load_pixart
sd_model = load_pixart(checkpoint_info, diffusers_load_config)
elif model_type in ['Lumina-Next']: # forced pipeline
from modules.model_lumina import load_lumina
sd_model = load_lumina(checkpoint_info, diffusers_load_config)
elif model_type in ['Kolors']: # forced pipeline
from modules.model_kolors import load_kolors
sd_model = load_kolors(checkpoint_info, diffusers_load_config)
elif model_type in ['AuraFlow']: # forced pipeline
from modules.model_auraflow import load_auraflow
sd_model = load_auraflow(checkpoint_info, diffusers_load_config)
elif model_type in ['FLUX']:
from modules.model_flux import load_flux
sd_model = load_flux(checkpoint_info, diffusers_load_config)
elif model_type in ['Stable Diffusion 3']:
from modules.model_sd3 import load_sd3
shared.log.debug(f'Load {op}: model="Stable Diffusion 3"')
shared.opts.scheduler = 'Default'
sd_model = load_sd3(checkpoint_info, cache_dir=shared.opts.diffusers_dir, config=diffusers_load_config.get('config', None))
elif model_type in ['Meissonic']: # forced pipeline
from modules.model_meissonic import load_meissonic
sd_model = load_meissonic(checkpoint_info, diffusers_load_config)
elif model_type in ['OmniGen']: # forced pipeline
from modules.model_omnigen import load_omnigen
sd_model = load_omnigen(checkpoint_info, diffusers_load_config)
except Exception as e:
shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
if debug_load:
errors.display(e, 'Load')
return None
return sd_model
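# Load a diffusers folder-style model, trying AutoPipelineForText2Image, then DiffusionPipeline, then StableDiffusionPipeline as fallbacks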
def load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_config, op='model'):
sd_model = None
files = shared.walk_files(checkpoint_info.path, ['.safetensors', '.bin', '.ckpt'])
if 'variant' not in diffusers_load_config and any('diffusion_pytorch_model.fp16' in f for f in files): # deal with diffusers lack of variant fallback when loading
diffusers_load_config['variant'] = 'fp16'
if model_type is not None and pipeline is not None and 'ONNX' in model_type: # forced pipeline
try:
sd_model = pipeline.from_pretrained(checkpoint_info.path)
except Exception as e:
shared.log.error(f'Load {op}: type=ONNX path="{checkpoint_info.path}" {e}')
if debug_load:
errors.display(e, 'Load')
return None
else:
err1, err2, err3 = None, None, None
if os.path.exists(checkpoint_info.path) and os.path.isdir(checkpoint_info.path):
if os.path.exists(os.path.join(checkpoint_info.path, 'unet', 'diffusion_pytorch_model.bin')):
shared.log.debug(f'Load {op}: type=pickle')
diffusers_load_config['use_safetensors'] = False
if debug_load:
shared.log.debug(f'Load {op}: args={diffusers_load_config}')
try: # 1 - autopipeline, best choice but not all pipelines are available
try:
sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
except ValueError as e:
if 'no variant default' in str(e):
shared.log.warning(f'Load {op}: variant={diffusers_load_config["variant"]} model="{checkpoint_info.path}" using default variant')
diffusers_load_config.pop('variant', None)
sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
elif 'safetensors found in directory' in str(e):
shared.log.warning(f'Load {op}: type=pickle')
diffusers_load_config['use_safetensors'] = False
sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
else:
raise ValueError from e # reraise
except Exception as e:
err1 = e
if debug_load:
errors.display(e, 'Load AutoPipeline')
# shared.log.error(f'AutoPipeline: {e}')
try: # 2 - diffusion pipeline, works for most non-linked pipelines
if err1 is not None:
sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
except Exception as e:
err2 = e
if debug_load:
errors.display(e, "Load DiffusionPipeline")
# shared.log.error(f'DiffusionPipeline: {e}')
try: # 3 - try basic pipeline just in case
if err2 is not None:
sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
except Exception as e:
err3 = e # ignore last error
shared.log.error(f"StableDiffusionPipeline: {e}")
if debug_load:
errors.display(e, "Load StableDiffusionPipeline")
if err3 is not None:
shared.log.error(f'Load {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
return None
return sd_model
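# Load a single-file safetensors checkpoint using the detected pipeline's from_single_file (or legacy from_ckpt)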
def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_config, op='model'):
sd_model = None
diffusers_load_config["local_files_only"] = diffusers_version < 28 # must be true for old diffusers, otherwise false but we override config for sd15/sdxl
diffusers_load_config["extract_ema"] = shared.opts.diffusers_extract_ema
if pipeline is None:
shared.log.error(f'Load {op}: pipeline={shared.opts.diffusers_pipeline} not initialized')
return None
try:
if model_type.startswith('Stable Diffusion'):
if shared.opts.diffusers_force_zeros:
diffusers_load_config['force_zeros_for_empty_prompt'] = shared.opts.diffusers_force_zeros
else:
model_config = sd_detect.get_load_config(checkpoint_info.path, model_type, config_type='json')
if model_config is not None:
if debug_load:
shared.log.debug(f'Load {op}: config="{model_config}"')
diffusers_load_config['config'] = model_config
if model_type.startswith('Stable Diffusion 3'):
from modules.model_sd3 import load_sd3
sd_model = load_sd3(checkpoint_info=checkpoint_info, cache_dir=shared.opts.diffusers_dir, config=diffusers_load_config.get('config', None))
elif hasattr(pipeline, 'from_single_file'):
diffusers.loaders.single_file_utils.CHECKPOINT_KEY_NAMES["clip"] = "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight" # patch for diffusers==0.28.0
diffusers_load_config['use_safetensors'] = True
diffusers_load_config['cache_dir'] = shared.opts.hfcache_dir # use hfcache instead of diffusers dir as this is for config only in case of single-file
if shared.opts.disable_accelerate:
from diffusers.utils import import_utils
import_utils._accelerate_available = False # pylint: disable=protected-access
if shared.opts.diffusers_to_gpu and model_type.startswith('Stable Diffusion'):
shared.log.debug(f'Diffusers accelerate: direct={shared.opts.diffusers_to_gpu}')
sd_hijack_accelerate.hijack_accelerate()
else:
sd_hijack_accelerate.restore_accelerate()
sd_model = pipeline.from_single_file(checkpoint_info.path, **diffusers_load_config)
# sd_model = patch_diffuser_config(sd_model, checkpoint_info.path)
elif hasattr(pipeline, 'from_ckpt'):
diffusers_load_config['cache_dir'] = shared.opts.hfcache_dir
sd_model = pipeline.from_ckpt(checkpoint_info.path, **diffusers_load_config)
else:
shared.log.error(f'Load {op}: file="{checkpoint_info.path}" {shared.opts.diffusers_pipeline} cannot load safetensor model')
return None
if shared.opts.diffusers_vae_upcast != 'default' and model_type in ['Stable Diffusion', 'Stable Diffusion XL']:
diffusers_load_config['force_upcast'] = True if shared.opts.diffusers_vae_upcast == 'true' else False
# if debug_load:
# shared.log.debug(f'Model args: {diffusers_load_config}')
if sd_model is not None:
diffusers_load_config.pop('vae', None)
diffusers_load_config.pop('safety_checker', None)
diffusers_load_config.pop('requires_safety_checker', None)
diffusers_load_config.pop('config_files', None)
diffusers_load_config.pop('local_files_only', None)
shared.log.debug(f'Setting {op}: pipeline={sd_model.__class__.__name__} config={diffusers_load_config}') # pylint: disable=protected-access
except Exception as e:
shared.log.error(f'Load {op}: file="{checkpoint_info.path}" pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__} config={diffusers_load_config} {e}')
if 'Weights for this component appear to be missing in the checkpoint' in str(e):
shared.log.error(f'Load {op}: file="{checkpoint_info.path}" is not a complete model')
else:
errors.display(e, 'Load')
return None
return sd_model
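# Top-level diffusers loader: resolve the checkpoint, detect the pipeline, preload the VAE,
# try forced/folder/single-file loaders, then apply options, offload and optional compile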
def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model', revision=None): # pylint: disable=unused-argument
if timer is None:
timer = Timer()
logging.getLogger("diffusers").setLevel(logging.ERROR)
timer.record("diffusers")
diffusers_load_config = {
"low_cpu_mem_usage": True,
"torch_dtype": devices.dtype,
"load_connected_pipeline": True,
"safety_checker": None, # sd15 specific but we cant know ahead of time
"requires_safety_checker": False, # sd15 specific but we cant know ahead of time
# "use_safetensors": True,
}
if revision is not None:
diffusers_load_config['revision'] = revision
if shared.opts.diffusers_model_load_variant != 'default':
diffusers_load_config['variant'] = shared.opts.diffusers_model_load_variant
if shared.opts.diffusers_pipeline == 'Custom Diffusers Pipeline' and len(shared.opts.custom_diffusers_pipeline) > 0:
shared.log.debug(f'Model pipeline: pipeline="{shared.opts.custom_diffusers_pipeline}"')
diffusers_load_config['custom_pipeline'] = shared.opts.custom_diffusers_pipeline
if shared.opts.data.get('sd_model_checkpoint', '') == 'model.safetensors' or shared.opts.data.get('sd_model_checkpoint', '') == '':
shared.opts.data['sd_model_checkpoint'] = "stabilityai/stable-diffusion-xl-base-1.0"
if (op == 'model' or op == 'dict'):
if (model_data.sd_model is not None) and (checkpoint_info is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
return
else:
if (model_data.sd_refiner is not None) and (checkpoint_info is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
return
sd_model = None
try:
# initial load only
if sd_model is None:
if shared.cmd_opts.ckpt is not None and os.path.isdir(shared.cmd_opts.ckpt) and model_data.initial:
sd_model, checkpoint_info = load_diffuser_initial(diffusers_load_config, op)
# unload current model
checkpoint_info = checkpoint_info or select_checkpoint(op=op)
if checkpoint_info is None:
unload_model_weights(op=op)
return
# detect pipeline
pipeline, model_type = sd_detect.detect_pipeline(checkpoint_info.path, op)
# preload vae so it can be used as param
vae = None
sd_vae.loaded_vae_file = None
if model_type is None:
shared.log.error(f'Load {op}: pipeline={shared.opts.diffusers_pipeline} not detected')
return
vae_file = None
if model_type.startswith('Stable Diffusion') and (op == 'model' or op == 'refiner'): # preload vae for sd models
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
vae = sd_vae.load_vae_diffusers(checkpoint_info.path, vae_file, vae_source)
if vae is not None:
diffusers_load_config["vae"] = vae
timer.record("vae")
# load with custom loader
if sd_model is None:
sd_model = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
# load from hf folder-style
if sd_model is None:
if os.path.isdir(checkpoint_info.path) or checkpoint_info.type == 'huggingface' or checkpoint_info.type == 'transformer':
sd_model = load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_config, op)
# load from single-file
if sd_model is None:
if os.path.isfile(checkpoint_info.path) and checkpoint_info.path.lower().endswith('.safetensors'):
sd_model = load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_config, op)
if sd_model is None:
shared.log.error(f'Load {op}: name="{checkpoint_info.name if checkpoint_info is not None else None}" not loaded')
return
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init
sd_model.sd_checkpoint_info = checkpoint_info # pylint: disable=attribute-defined-outside-init
sd_model.sd_model_checkpoint = checkpoint_info.filename # pylint: disable=attribute-defined-outside-init
if hasattr(sd_model, "prior_pipe"):
sd_model.default_scheduler = copy.deepcopy(sd_model.prior_pipe.scheduler) if hasattr(sd_model.prior_pipe, "scheduler") else None
else:
sd_model.default_scheduler = copy.deepcopy(sd_model.scheduler) if hasattr(sd_model, "scheduler") else None
sd_model.is_sdxl = False # a1111 compatibility item
sd_model.is_sd2 = hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model') # a1111 compatibility item
sd_model.is_sd1 = not sd_model.is_sd2 # a1111 compatibility item
sd_model.logvar = sd_model.logvar.to(devices.device) if hasattr(sd_model, 'logvar') else None # fix for training
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
if hasattr(sd_model, "set_progress_bar_config"):
sd_model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining}', ncols=80, colour='#327fba')
if "Kandinsky" in sd_model.__class__.__name__: # need a special case
sd_model.scheduler.name = 'DDIM'
if model_type not in ['Stable Cascade']: # need a special-case
sd_unet.load_unet(sd_model)
timer.record("load")
if op == 'refiner':
model_data.sd_refiner = sd_model
else:
model_data.sd_model = sd_model
reload_text_encoder(initial=True) # must be before embeddings
timer.record("te")
if debug_load:
shared.log.trace(f'Model components: {list(get_signature(sd_model).values())}')
from modules.textual_inversion import textual_inversion
sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
timer.record("embeddings")
from modules import prompt_parser_diffusers
prompt_parser_diffusers.insert_parser_highjack(sd_model.__class__.__name__)
prompt_parser_diffusers.cache.clear()
set_diffuser_options(sd_model, vae, op, offload=False)
if shared.opts.nncf_compress_weights and not ('Model' in shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
sd_model = sd_models_compile.nncf_compress_weights(sd_model) # run this before move model so it can be compressed in CPU
if shared.opts.optimum_quanto_weights:
sd_model = sd_models_compile.optimum_quanto_weights(sd_model) # run this before move model so it can be compressed in CPU
timer.record("options")
set_diffuser_offload(sd_model, op)
if op == 'model' and not (os.path.isdir(checkpoint_info.path) or checkpoint_info.type == 'huggingface'):
if getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None and vae_file is not None:
sd_vae.apply_vae_config(shared.sd_model.sd_checkpoint_info.filename, vae_file, sd_model)
if op == 'refiner' and shared.opts.diffusers_move_refiner:
shared.log.debug('Moving refiner model to CPU')
move_model(sd_model, devices.cpu)
else:
move_model(sd_model, devices.device)
timer.record("move")
if shared.opts.ipex_optimize:
sd_model = sd_models_compile.ipex_optimize(sd_model)
if ('Model' in shared.opts.cuda_compile and shared.opts.cuda_compile_backend != 'none'):
sd_model = sd_models_compile.compile_diffusers(sd_model)
timer.record("compile")
if shared.opts.enable_linfusion:
from modules import linfusion
linfusion.apply(sd_model)
timer.record("linfusion")
except Exception as e:
shared.log.error(f"Load {op}: {e}")
errors.display(e, "Model")
devices.torch_gc(force=True)
if sd_model is not None:
script_callbacks.model_loaded_callback(sd_model)
if debug_load:
from modules import modelstats
modelstats.analyze()
shared.log.info(f"Load {op}: time={timer.summary()} native={get_native(sd_model)} memory={memory_stats()}")
class DiffusersTaskType(Enum):
TEXT_2_IMAGE = 1
IMAGE_2_IMAGE = 2
INPAINTING = 3
INSTRUCT = 4
def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType:
if pipe.__class__.__name__ in ["StableVideoDiffusionPipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", "OmniGenPipeline"]:
return DiffusersTaskType.IMAGE_2_IMAGE
elif pipe.__class__.__name__ == "StableDiffusionXLInstructPix2PixPipeline":
return DiffusersTaskType.INSTRUCT
elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.values():
return DiffusersTaskType.IMAGE_2_IMAGE
elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING.values():
return DiffusersTaskType.INPAINTING
else:
return DiffusersTaskType.TEXT_2_IMAGE
def get_signature(cls):
signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True)
return signature.parameters
def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}):
"""
args:
- cls: can be pipeline class or a string from custom pipelines
for example: diffusers.StableDiffusionPipeline or 'mixture_tiling'
- pipeline: source model to be used; if not provided, the currently loaded model is used
- args: any additional components to load into the pipeline
for example: { 'vae': None }
"""
try:
if isinstance(cls, str):
shared.log.debug(f'Pipeline switch: custom={cls}')
cls = diffusers.utils.get_class_from_dynamic_module(cls, module_file='pipeline.py')
if pipeline is None:
pipeline = shared.sd_model
new_pipe = None
signature = get_signature(cls)
possible = signature.keys()
if not force and isinstance(pipeline, cls) and args == {}:
return pipeline
pipe_dict = {}
components_used = []
components_skipped = []
components_missing = []
switch_mode = 'none'
if hasattr(pipeline, '_internal_dict'):
for item in pipeline._internal_dict.keys(): # pylint: disable=protected-access
if item in possible:
pipe_dict[item] = getattr(pipeline, item, None)
components_used.append(item)
else:
components_skipped.append(item)
for item in possible:
if item in ['self', 'args', 'kwargs']: # skip
continue
if signature[item].default != inspect._empty: # has default value so we don't have to worry about it # pylint: disable=protected-access
continue
if item not in components_used:
shared.log.warning(f'Pipeline switch: missing component={item} type={signature[item].annotation}')
pipe_dict[item] = None # try but not likely to work
components_missing.append(item)
new_pipe = cls(**pipe_dict)
switch_mode = 'auto'
elif 'tokenizer_2' in possible and hasattr(pipeline, 'tokenizer_2'):
new_pipe = cls(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
text_encoder_2=pipeline.text_encoder_2,
tokenizer=pipeline.tokenizer,
tokenizer_2=pipeline.tokenizer_2,
unet=pipeline.unet,
scheduler=pipeline.scheduler,
feature_extractor=getattr(pipeline, 'feature_extractor', None),
)
move_model(new_pipe, pipeline.device)
switch_mode = 'sdxl'
elif 'tokenizer' in possible and hasattr(pipeline, 'tokenizer'):
new_pipe = cls(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
unet=pipeline.unet,
scheduler=pipeline.scheduler,
feature_extractor=getattr(pipeline, 'feature_extractor', None),
requires_safety_checker=False,
safety_checker=None,
)
move_model(new_pipe, pipeline.device)
switch_mode = 'sd'
else:
shared.log.error(f'Pipeline switch error: {pipeline.__class__.__name__} unrecognized')
return pipeline
if new_pipe is not None:
for k, v in args.items():
if k in possible:
setattr(new_pipe, k, v)
components_used.append(k)
else:
shared.log.warning(f'Pipeline switch skipping unknown: component={k}')
components_skipped.append(k)
if new_pipe is not None:
copy_diffuser_options(new_pipe, pipeline)
if hasattr(new_pipe, "watermark"):
new_pipe.watermark = NoWatermark()
if switch_mode == 'auto':
shared.log.debug(f'Pipeline switch: from={pipeline.__class__.__name__} to={new_pipe.__class__.__name__} components={components_used} skipped={components_skipped} missing={components_missing}')
else:
shared.log.debug(f'Pipeline switch: from={pipeline.__class__.__name__} to={new_pipe.__class__.__name__} mode={switch_mode}')
return new_pipe
else:
shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} empty pipeline')
except Exception as e:
shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} {e}')
errors.display(e, 'Pipeline switch')
return pipeline
def clean_diffuser_pipe(pipe):
if pipe is not None and shared.sd_model_type == 'sdxl' and hasattr(pipe, 'config') and 'requires_aesthetics_score' in pipe.config and hasattr(pipe, '_internal_dict'):
debug_process(f'Pipeline clean: {pipe.__class__.__name__}')
# diffusers adds requires_aesthetics_score with img2img and complains if requires_aesthetics_score exist in txt2img
internal_dict = dict(pipe._internal_dict) # pylint: disable=protected-access
internal_dict.pop('requires_aesthetics_score', None)
del pipe._internal_dict
pipe.register_to_config(**internal_dict)
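# Switch a pipeline between txt2img/img2img/inpaint task classes while preserving bookkeeping attributes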
def set_diffuser_pipe(pipe, new_pipe_type):
exclude = [
'StableDiffusionReferencePipeline',
'StableDiffusionAdapterPipeline',
'AnimateDiffPipeline',
'AnimateDiffSDXLPipeline',
'OmniGenPipeline',
'StableDiffusion3ControlNetPipeline',
'InstantIRPipeline',
'FluxFillPipeline',
'FluxControlPipeline',
]
n = getattr(pipe.__class__, '__name__', '')
if new_pipe_type == DiffusersTaskType.TEXT_2_IMAGE:
clean_diffuser_pipe(pipe)
if get_diffusers_task(pipe) == new_pipe_type:
return pipe
# skip specific pipelines
cls = pipe.__class__.__name__
if n in exclude:
return pipe
if 'Onnx' in cls:
return pipe
new_pipe = None
# in some cases we want to reset the pipeline to parent as they don't have their own variants
if new_pipe_type == DiffusersTaskType.IMAGE_2_IMAGE or new_pipe_type == DiffusersTaskType.INPAINTING:
if n == 'StableDiffusionPAGPipeline':
pipe = switch_pipe(diffusers.StableDiffusionPipeline, pipe)
if n == 'StableDiffusionXLPAGPipeline':
pipe = switch_pipe(diffusers.StableDiffusionXLPipeline, pipe)
sd_checkpoint_info = getattr(pipe, "sd_checkpoint_info", None)
sd_model_checkpoint = getattr(pipe, "sd_model_checkpoint", None)
embedding_db = getattr(pipe, "embedding_db", None)
sd_model_hash = getattr(pipe, "sd_model_hash", None)
has_accelerate = getattr(pipe, "has_accelerate", None)
current_attn_name = getattr(pipe, "current_attn_name", None)
default_scheduler = getattr(pipe, "default_scheduler", None)
image_encoder = getattr(pipe, "image_encoder", None)
feature_extractor = getattr(pipe, "feature_extractor", None)
if new_pipe is None:
if hasattr(pipe, 'config'): # real pipeline which can be auto-switched
try:
if new_pipe_type == DiffusersTaskType.TEXT_2_IMAGE:
new_pipe = diffusers.AutoPipelineForText2Image.from_pipe(pipe)
elif new_pipe_type == DiffusersTaskType.IMAGE_2_IMAGE:
new_pipe = diffusers.AutoPipelineForImage2Image.from_pipe(pipe)
elif new_pipe_type == DiffusersTaskType.INPAINTING:
new_pipe = diffusers.AutoPipelineForInpainting.from_pipe(pipe)
else:
shared.log.error(f'Pipeline class change failed: type={new_pipe_type} pipeline={cls}')
return pipe
except Exception as e: # pylint: disable=unused-variable
shared.log.warning(f'Pipeline class change failed: type={new_pipe_type} pipeline={cls} {e}')
return pipe
else:
try: # maybe a wrapper pipeline so just change the class
if new_pipe_type == DiffusersTaskType.TEXT_2_IMAGE:
pipe.__class__ = diffusers.pipelines.auto_pipeline._get_task_class(diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING, cls) # pylint: disable=protected-access
new_pipe = pipe
elif new_pipe_type == DiffusersTaskType.IMAGE_2_IMAGE:
pipe.__class__ = diffusers.pipelines.auto_pipeline._get_task_class(diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, cls) # pylint: disable=protected-access
new_pipe = pipe
elif new_pipe_type == DiffusersTaskType.INPAINTING:
pipe.__class__ = diffusers.pipelines.auto_pipeline._get_task_class(diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING, cls) # pylint: disable=protected-access
new_pipe = pipe
else:
shared.log.error(f'Pipeline class change failed: type={new_pipe_type} pipeline={cls}')
return pipe
except Exception as e: # pylint: disable=unused-variable
shared.log.warning(f'Pipeline class set failed: type={new_pipe_type} pipeline={cls} {e}')
return pipe
# if pipe.__class__ == new_pipe.__class__:
# return pipe
new_pipe.sd_checkpoint_info = sd_checkpoint_info
new_pipe.sd_model_checkpoint = sd_model_checkpoint
new_pipe.embedding_db = embedding_db
new_pipe.sd_model_hash = sd_model_hash
new_pipe.has_accelerate = has_accelerate
new_pipe.current_attn_name = current_attn_name
new_pipe.default_scheduler = default_scheduler
new_pipe.image_encoder = image_encoder
new_pipe.feature_extractor = feature_extractor
new_pipe.is_sdxl = getattr(pipe, 'is_sdxl', False) # a1111 compatibility item
new_pipe.is_sd2 = getattr(pipe, 'is_sd2', False)
new_pipe.is_sd1 = getattr(pipe, 'is_sd1', True)
if hasattr(new_pipe, 'watermark'):
new_pipe.watermark = NoWatermark()
if hasattr(new_pipe, 'pipe'): # also handle nested pipelines
new_pipe.pipe = set_diffuser_pipe(new_pipe.pipe, new_pipe_type)
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
shared.log.debug(f"Pipeline class change: original={cls} target={new_pipe.__class__.__name__} device={pipe.device} fn={fn}") # pylint: disable=protected-access
pipe = new_pipe
return pipe
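# Apply the selected cross-attention optimization to all pipeline modules that expose set_attn_processor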
def set_diffusers_attention(pipe):
import diffusers.models.attention_processor as p
def set_attn(pipe, attention):
if attention is None:
return
if not hasattr(pipe, "_internal_dict"):
return
modules = [getattr(pipe, n, None) for n in pipe._internal_dict.keys()] # pylint: disable=protected-access
modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attn_processor")]
for module in modules:
if module.__class__.__name__ in ['SD3Transformer2DModel']:
module.set_attn_processor(p.JointAttnProcessor2_0())
elif module.__class__.__name__ in ['FluxTransformer2DModel']:
module.set_attn_processor(p.FluxAttnProcessor2_0())
elif module.__class__.__name__ in ['HunyuanDiT2DModel']:
module.set_attn_processor(p.HunyuanAttnProcessor2_0())
elif module.__class__.__name__ in ['AuraFlowTransformer2DModel']:
module.set_attn_processor(p.AuraFlowAttnProcessor2_0())
elif 'Transformer' in module.__class__.__name__:
pass # unknown transformer so probably don't want to force attention processor
else:
module.set_attn_processor(attention)
# if hasattr(pipe, 'pipe'):
# set_diffusers_attention(pipe.pipe)
if 'ControlNet' in pipe.__class__.__name__: # do not replace attention in ControlNet pipelines
return
shared.log.debug(f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
if shared.opts.cross_attention_optimization == "Disabled":
pass # do nothing
elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
set_attn(pipe, p.AttnProcessor2_0())
elif shared.opts.cross_attention_optimization == "xFormers" and hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
pipe.enable_xformers_memory_efficient_attention()
elif shared.opts.cross_attention_optimization == "Split attention" and hasattr(pipe, "enable_attention_slicing"):
pipe.enable_attention_slicing()
elif shared.opts.cross_attention_optimization == "Batch matrix-matrix":
set_attn(pipe, p.AttnProcessor())
elif shared.opts.cross_attention_optimization == "Dynamic Attention BMM":
from modules.sd_hijack_dynamic_atten import DynamicAttnProcessorBMM
set_attn(pipe, DynamicAttnProcessorBMM())
pipe.current_attn_name = shared.opts.cross_attention_optimization
def get_native(pipe: diffusers.DiffusionPipeline):
if hasattr(pipe, "vae") and hasattr(pipe.vae.config, "sample_size"):
# Stable Diffusion
size = pipe.vae.config.sample_size
elif hasattr(pipe, "movq") and hasattr(pipe.movq.config, "sample_size"):
# Kandinsky
size = pipe.movq.config.sample_size
elif hasattr(pipe, "unet") and hasattr(pipe.unet.config, "sample_size"):
size = pipe.unet.config.sample_size
else:
size = 0
return size
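# Original-backend (ldm) model loader: build the model from its OmegaConf config, load weights, hijack and register it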
def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint(op=op)
if checkpoint_info is None:
return
if op == 'model' or op == 'dict':
if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
return
else:
if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
return
shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}')
if timer is None:
timer = Timer()
current_checkpoint_info = None
if op == 'model' or op == 'dict':
if model_data.sd_model is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
else:
if model_data.sd_refiner is not None:
sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None)
unload_model_weights(op=op)
if not shared.native:
from modules import sd_hijack_inpainting
sd_hijack_inpainting.do_inpainting_hijack()
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
if state_dict is None or checkpoint_config is None:
shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"')
if current_checkpoint_info is not None:
shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore')
load_model(current_checkpoint_info, None)
return
shared.log.debug(f'Model dict loaded: {memory_stats()}')
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
timer.record("config")
shared.log.debug(f'Model config loaded: {memory_stats()}')
sd_model = None
stdout = io.StringIO()
if os.environ.get('SD_LDM_DEBUG', None) is not None:
sd_model = instantiate_from_config(sd_config.model)
else:
with contextlib.redirect_stdout(stdout):
sd_model = instantiate_from_config(sd_config.model)
for line in stdout.getvalue().splitlines():
if len(line) > 0:
shared.log.info(f'LDM: {line.strip()}')
shared.log.debug(f"Model created from config: {checkpoint_config}")
sd_model.used_config = checkpoint_config
sd_model.has_accelerate = False
timer.record("create")
ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if not ok:
model_data.sd_model = sd_model
current_checkpoint_info = None
unload_model_weights(op=op)
shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}')
if op == 'refiner':
# shared.opts.data['sd_model_refiner'] = 'None'
shared.opts.sd_model_refiner = 'None'
return
else:
shared.log.debug(f'Model weights loaded: {memory_stats()}')
timer.record("load")
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
move_model(sd_model, devices.device)
timer.record("move")
shared.log.debug(f'Model weights moved: {memory_stats()}')
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval()
if op == 'refiner':
model_data.sd_refiner = sd_model
else:
model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
timer.record("embeddings")
script_callbacks.model_loaded_callback(sd_model)
timer.record("callbacks")
shared.log.info(f"Model loaded in {timer.summary()}")
current_checkpoint_info = None
devices.torch_gc(force=True)
shared.log.info(f'Model load finished: {memory_stats()}')


def reload_text_encoder(initial=False):
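    """Swap the text encoder of the currently loaded pipeline according to shared.opts.sd_text_encoder:
    a ViT/CLIP selection replaces the CLIP encoder, otherwise any T5 encoder module found in the
    pipeline signature (or text_encoder_3) is replaced."""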
if initial and (shared.opts.sd_text_encoder is None or shared.opts.sd_text_encoder == 'None'):
        return # don't unload
signature = get_signature(shared.sd_model)
t5 = [k for k, v in signature.items() if 'T5EncoderModel' in str(v)]
if hasattr(shared.sd_model, 'text_encoder') and 'vit' in shared.opts.sd_text_encoder.lower():
from modules.model_te import set_clip
set_clip(pipe=shared.sd_model)
elif len(t5) > 0:
from modules.model_te import set_t5
shared.log.debug(f'Load module: type=t5 path="{shared.opts.sd_text_encoder}" module="{t5[0]}"')
set_t5(pipe=shared.sd_model, module=t5[0], t5=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
elif hasattr(shared.sd_model, 'text_encoder_3'):
from modules.model_te import set_t5
shared.log.debug(f'Load module: type=t5 path="{shared.opts.sd_text_encoder}" module="text_encoder_3"')
set_t5(pipe=shared.sd_model, module='text_encoder_3', t5=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)


def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', force=False, revision=None):
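    """Top-level entry point for (re)loading model weights.

    Decides whether the requested checkpoint is already loaded, whether a separate model dict
    has to be loaded first, and whether the existing dict can be reused; dispatches to
    load_model() for the legacy backend or load_diffuser() for the diffusers backend and
    returns the resulting model, or None when nothing had to be loaded.
    """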
load_dict = shared.opts.sd_model_dict != model_data.sd_dict
from modules import lowvram, sd_hijack
checkpoint_info = info or select_checkpoint(op=op) # are we selecting model or dictionary
    next_checkpoint_info = (info or select_checkpoint(op='dict' if load_dict else 'model')) if load_dict else None
if checkpoint_info is None:
unload_model_weights(op=op)
return None
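    # run the load under a fresh State object and restore the caller's state afterwards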
orig_state = copy.deepcopy(shared.state)
shared.state = shared_state.State()
shared.state.begin('Load')
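    # when a separate model dict is selected it is loaded first and the actual model is then reloaded on top of it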
if load_dict:
shared.log.debug(f'Load {op} dict: target="{checkpoint_info.filename}" existing={sd_model is not None} info={info}')
else:
model_data.sd_dict = 'None'
# shared.log.debug(f'Load {op}: target="{checkpoint_info.filename}" existing={sd_model is not None} info={info}')
if sd_model is None:
sd_model = model_data.sd_model if op == 'model' or op == 'dict' else model_data.sd_refiner
if sd_model is None: # previous model load failed
current_checkpoint_info = None
else:
current_checkpoint_info = getattr(sd_model, 'sd_checkpoint_info', None)
if current_checkpoint_info is not None and checkpoint_info is not None and current_checkpoint_info.filename == checkpoint_info.filename and not force:
return None
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.send_everything_to_cpu()
else:
move_model(sd_model, devices.cpu)
if (reuse_dict or shared.opts.model_reuse_dict) and not getattr(sd_model, 'has_accelerate', False):
shared.log.info(f'Load {op}: reusing dictionary')
sd_hijack.model_hijack.undo_hijack(sd_model)
else:
unload_model_weights(op=op)
sd_model = None
timer = Timer()
    # TODO implement caching after diffusers implements state_dict loading
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if not shared.native else None
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("config")
if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None) or force:
sd_model = None
if not shared.native:
load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
model_data.sd_dict = shared.opts.sd_model_dict
else:
load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op, revision=revision)
if load_dict and next_checkpoint_info is not None:
model_data.sd_dict = shared.opts.sd_model_dict
shared.opts.data["sd_model_checkpoint"] = next_checkpoint_info.title
            reload_model_weights(reuse_dict=True) # dict is loaded, now reload and apply the model on top of it
shared.state.end()
shared.state = orig_state
# data['sd_model_checkpoint']
if op == 'model' or op == 'dict':
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
return model_data.sd_model
else:
shared.opts.data["sd_model_refiner"] = checkpoint_info.title
return model_data.sd_refiner
# fallback
shared.log.info(f"Load {op} using fallback: model={checkpoint_info.title}")
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception:
shared.log.error("Load model failed: restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
finally:
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
script_callbacks.model_loaded_callback(sd_model)
timer.record("callbacks")
if sd_model is not None and not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
move_model(sd_model, devices.device)
timer.record("device")
shared.state.end()
shared.state = orig_state
shared.log.info(f"Load {op}: time={timer.summary()}")
return sd_model


def convert_to_faketensors(tensor):
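    """Replace a module's weight with a FakeTensor so shape/dtype metadata is kept without holding real data; returns the input unchanged on failure."""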
try:
fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
if hasattr(tensor, "weight"):
tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
return tensor
except Exception:
pass
return tensor


def disable_offload(sd_model):
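    """Remove accelerate offload hooks from all pipeline components so the model can be moved or unloaded directly."""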
from accelerate.hooks import remove_hook_from_module
if not getattr(sd_model, 'has_accelerate', False):
return
if hasattr(sd_model, 'components'):
for _name, model in sd_model.components.items():
if isinstance(model, torch.nn.Module):
remove_hook_from_module(model, recurse=True)
sd_model.has_accelerate = False


def unload_model_weights(op='model'):
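    """Release the current model or refiner: clear compile caches, undo the hijack (legacy backend)
    or remove offload hooks and move the model to the meta device (diffusers backend), then force garbage collection."""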
if shared.compiled_model_state is not None:
shared.compiled_model_state.compiled_cache.clear()
shared.compiled_model_state.req_cache.clear()
shared.compiled_model_state.partitioned_modules.clear()
if op == 'model' or op == 'dict':
if model_data.sd_model:
if not shared.native:
from modules import sd_hijack
move_model(model_data.sd_model, devices.cpu)
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
elif not ('Model' in shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
disable_offload(model_data.sd_model)
move_model(model_data.sd_model, 'meta')
model_data.sd_model = None
devices.torch_gc(force=True)
shared.log.debug(f'Unload weights {op}: {memory_stats()}')
elif op == 'refiner':
if model_data.sd_refiner:
if not shared.native:
from modules import sd_hijack
move_model(model_data.sd_refiner, devices.cpu)
sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
else:
disable_offload(model_data.sd_refiner)
move_model(model_data.sd_refiner, 'meta')
model_data.sd_refiner = None
devices.torch_gc(force=True)
shared.log.debug(f'Unload weights {op}: {memory_stats()}')


def path_to_repo(fn: str = ''):
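    """Derive a huggingface repo id from a checkpoint name or local hub cache path, e.g. '.../models--org--model/snapshots/...' -> 'org/model'."""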
if isinstance(fn, CheckpointInfo):
fn = fn.name
repo_id = fn.replace('\\', '/')
if 'models--' in repo_id:
repo_id = repo_id.split('models--')[-1]
repo_id = repo_id.split('/')[0]
repo_id = repo_id.split('/')
repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id)
repo_id = repo_id.replace('models--', '').replace('--', '/')
return repo_id