import os

import diffusers
import transformers

from modules import shared, devices, sd_models, sd_unet, model_quant, model_tools


def load_overrides(kwargs, cache_dir):
    # apply user-selected overrides for the transformer/unet, text encoders and vae
    if shared.opts.sd_unet != 'None':
        try:
            fn = sd_unet.unet_dict[shared.opts.sd_unet]
            if fn.endswith('.safetensors'):
                kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_single_file(fn, cache_dir=cache_dir, torch_dtype=devices.dtype)
                sd_unet.loaded_unet = shared.opts.sd_unet
                shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=safetensors')
            elif fn.endswith('.gguf'):
                from modules import ggml
                # kwargs = load_gguf(kwargs, fn)
                kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
                sd_unet.loaded_unet = shared.opts.sd_unet
                shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=gguf')
        except Exception as e:
            shared.log.error(f"Load model: type=SD3 failed to load UNet: {e}")
            sd_unet.failed_unet.append(shared.opts.sd_unet) # record the failing unet name before resetting the option
            shared.opts.sd_unet = 'None'
    if shared.opts.sd_text_encoder != 'None':
        try:
            from modules.model_te import load_t5, load_vit_l, load_vit_g
            if 'vit-l' in shared.opts.sd_text_encoder.lower():
                kwargs['text_encoder'] = load_vit_l()
                shared.log.debug(f'Load model: type=SD3 variant="vit-l" te="{shared.opts.sd_text_encoder}"')
            elif 'vit-g' in shared.opts.sd_text_encoder.lower():
                kwargs['text_encoder_2'] = load_vit_g()
                shared.log.debug(f'Load model: type=SD3 variant="vit-g" te="{shared.opts.sd_text_encoder}"')
            else:
                kwargs['text_encoder_3'] = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
                shared.log.debug(f'Load model: type=SD3 variant="t5" te="{shared.opts.sd_text_encoder}"')
        except Exception as e:
            shared.log.error(f"Load model: type=SD3 failed to load text encoder: {e}")
            shared.opts.sd_text_encoder = 'None'
    if shared.opts.sd_vae != 'None' and shared.opts.sd_vae != 'Automatic':
        try:
            from modules import sd_vae
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'sd3', 'vae', 'config.json')
                kwargs['vae'] = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
                shared.log.debug(f'Load model: type=SD3 vae="{shared.opts.sd_vae}"')
        except Exception as e:
            shared.log.error(f"Load model: type=SD3 failed to load VAE: {e}")
            shared.opts.sd_vae = 'None'
    return kwargs


def load_quants(kwargs, repo_id, cache_dir):
    # build quantization arguments: prefer bitsandbytes, fall back to torchao
    quant_args = {}
    quant_args = model_quant.create_bnb_config(quant_args)
    if quant_args:
        model_quant.load_bnb(f'Load model: type=SD3 quant={quant_args}')
    if not quant_args:
        quant_args = model_quant.create_ao_config(quant_args)
        if quant_args:
            model_quant.load_torchao(f'Load model: type=SD3 quant={quant_args}')
    if not quant_args:
        return kwargs
    if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
        kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
        shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
    if 'text_encoder_3' not in kwargs and ('Text Encoder' in shared.opts.bnb_quantization or 'Text Encoder' in shared.opts.torchao_quantization):
        kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
        shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
    return kwargs
def load_missing(kwargs, fn, cache_dir):
    keys = model_tools.get_safetensor_keys(fn)
    size = os.stat(fn).st_size // 1024 // 1024
    if size > 15000:
        repo_id = 'stabilityai/stable-diffusion-3.5-large'
    else:
        repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
    if 'text_encoder' not in kwargs and 'text_encoder' not in keys:
        kwargs['text_encoder'] = transformers.CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder='text_encoder', cache_dir=cache_dir, torch_dtype=devices.dtype)
        shared.log.debug(f'Load model: type=SD3 missing=te1 repo="{repo_id}"')
    if 'text_encoder_2' not in kwargs and 'text_encoder_2' not in keys:
        kwargs['text_encoder_2'] = transformers.CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder='text_encoder_2', cache_dir=cache_dir, torch_dtype=devices.dtype)
        shared.log.debug(f'Load model: type=SD3 missing=te2 repo="{repo_id}"')
    if 'text_encoder_3' not in kwargs and 'text_encoder_3' not in keys:
        kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype)
        shared.log.debug(f'Load model: type=SD3 missing=te3 repo="{repo_id}"')
    if 'vae' not in kwargs and 'vae' not in keys:
        kwargs['vae'] = diffusers.AutoencoderKL.from_pretrained(repo_id, subfolder='vae', cache_dir=cache_dir, torch_dtype=devices.dtype)
        shared.log.debug(f'Load model: type=SD3 missing=vae repo="{repo_id}"')
    # if 'transformer' not in kwargs and 'transformer' not in keys:
    #     kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(default_repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype)
    return kwargs


"""
def load_gguf(kwargs, fn):
    ggml.install_gguf()
    from accelerate import init_empty_weights
    from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
    from modules import ggml, sd_hijack_accelerate
    with init_empty_weights():
        config = diffusers.SD3Transformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
        transformer = diffusers.SD3Transformer2DModel.from_config(config).to(devices.dtype)
        expected_state_dict_keys = list(transformer.state_dict().keys())
    state_dict, stats = ggml.load_gguf_state_dict(fn, devices.dtype)
    state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
    applied, skipped = 0, 0
    for param_name, param in state_dict.items():
        if param_name not in expected_state_dict_keys:
            skipped += 1
            continue
        applied += 1
        sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
        transformer.gguf = 'gguf'
        state_dict[param_name] = None
    shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats} compute={devices.dtype}')
    kwargs['transformer'] = transformer
    return kwargs
"""
def load_sd3(checkpoint_info, cache_dir=None, config=None):
    repo_id = sd_models.path_to_repo(checkpoint_info.name)
    fn = checkpoint_info.path

    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True)

    kwargs = {}
    kwargs = load_overrides(kwargs, cache_dir)
    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
        kwargs = load_quants(kwargs, repo_id, cache_dir)

    loader = diffusers.StableDiffusion3Pipeline.from_pretrained
    if fn is not None and os.path.exists(fn) and os.path.isfile(fn):
        if fn.endswith('.safetensors'):
            loader = diffusers.StableDiffusion3Pipeline.from_single_file
            # required_modules = model_tools.get_modules(diffusers.StableDiffusion3Pipeline)
            # have_modules = model_tools.get_safetensor_keys(fn)
            # loaded_modules = model_tools.load_modules('stabilityai/stable-diffusion-3.5-medium', required_modules)
            # kwargs = {**kwargs, **loaded_modules}
            # kwargs = load_missing(kwargs, fn, cache_dir)
            repo_id = fn
        elif fn.endswith('.gguf'):
            from modules import ggml
            kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
            # kwargs = load_gguf(kwargs, fn)
            kwargs = load_missing(kwargs, fn, cache_dir)
            kwargs['variant'] = 'fp16'
    else:
        kwargs['variant'] = 'fp16'

    shared.log.debug(f'Load model: type=SD3 kwargs={list(kwargs)} repo="{repo_id}"')
    kwargs = model_quant.create_bnb_config(kwargs)
    kwargs = model_quant.create_ao_config(kwargs)
    pipe = loader(
        repo_id,
        torch_dtype=devices.dtype,
        cache_dir=cache_dir,
        config=config,
        **kwargs,
    )
    devices.torch_gc()
    return pipe
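

# Usage sketch (not part of the module API): a minimal illustration of how
# load_sd3 might be driven. SimpleNamespace stands in for the checkpoint
# object, which in practice comes from modules.sd_models and must expose
# .name and .path as used above; the repo id and None path are assumptions
# that make the loader take the from_pretrained branch.
if __name__ == '__main__':
    from types import SimpleNamespace
    ckpt = SimpleNamespace(name='stabilityai/stable-diffusion-3-medium-diffusers', path=None)  # hypothetical checkpoint stand-in
    pipeline = load_sd3(ckpt, cache_dir=shared.opts.diffusers_dir)
    shared.log.info(f'Load model: type=SD3 class={pipeline.__class__.__name__}')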