automatic/modules/model_stablecascade.py

126 lines
6.3 KiB
Python

import os
from modules import shared, devices
def load_text_encoder(path):
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
from accelerate.utils.modeling import set_module_tensor_to_device
from accelerate import init_empty_weights
from safetensors.torch import load_file
try:
config = CLIPTextConfig(
architectures=["CLIPTextModelWithProjection"],
attention_dropout=0.0,
bos_token_id=49406,
dropout=0.0,
eos_token_id=49407,
hidden_act="gelu",
hidden_size=1280,
initializer_factor=1.0,
initializer_range=0.02,
intermediate_size=5120,
layer_norm_eps=1e-05,
max_position_embeddings=77,
model_type="clip_text_model",
num_attention_heads=20,
num_hidden_layers=32,
pad_token_id=1,
projection_dim=1280,
vocab_size=49408
)
shared.log.info(f'Loading Text Encoder: name="{os.path.basename(os.path.splitext(path)[0])}" file="{path}"')
with init_empty_weights():
text_encoder = CLIPTextModelWithProjection(config)
state_dict = load_file(path)
for key in list(state_dict.keys()):
set_module_tensor_to_device(text_encoder, key, devices.device, value=state_dict.pop(key), dtype=devices.dtype)
return text_encoder
except Exception as e:
text_encoder = None
shared.log.error(f'Failed to load Text Encoder model: {e}')
return None
def load_prior(path, config_file="default"):
from diffusers.models.unets import StableCascadeUNet
prior_text_encoder = None
if config_file == "default":
config_file = os.path.splitext(path)[0] + '.json'
if not os.path.exists(config_file):
if round(os.path.getsize(path) / 1024 / 1024 / 1024) < 5: # diffusers fails to find the configs from huggingface
config_file = "configs/stable-cascade/prior_lite/config.json"
else:
config_file = "configs/stable-cascade/prior/config.json"
shared.log.info(f'Loading UNet: name="{os.path.basename(os.path.splitext(path)[0])}" file="{path}" config="{config_file}"')
prior_unet = StableCascadeUNet.from_single_file(path, config=config_file, torch_dtype=devices.dtype_unet, cache_dir=shared.opts.diffusers_dir)
if os.path.isfile(os.path.splitext(path)[0] + "_text_encoder.safetensors"): # OneTrainer
prior_text_encoder = load_text_encoder(os.path.splitext(path)[0] + "_text_encoder.safetensors")
elif os.path.isfile(os.path.splitext(path)[0] + "_text_model.safetensors"): # KohyaSS
prior_text_encoder = load_text_encoder(os.path.splitext(path)[0] + "_text_model.safetensors")
return prior_unet, prior_text_encoder
def load_cascade_combined(checkpoint_info, diffusers_load_config):
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline, StableCascadeCombinedPipeline
from diffusers.models.unets import StableCascadeUNet
from modules.sd_unet import unet_dict
diffusers_load_config.pop("vae", None)
if 'cascade' in checkpoint_info.name.lower():
diffusers_load_config["variant"] = 'bf16'
if shared.opts.sd_unet != "None" or 'stabilityai' in checkpoint_info.name.lower():
if 'cascade' in checkpoint_info.name and ('lite' in checkpoint_info.name or (checkpoint_info.hash is not None and 'abc818bb0d' in checkpoint_info.hash)):
decoder_folder = 'decoder_lite'
prior_folder = 'prior_lite'
else:
decoder_folder = 'decoder'
prior_folder = 'prior'
if 'cascade' in checkpoint_info.name.lower():
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder=decoder_folder, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", cache_dir=shared.opts.diffusers_dir, decoder=decoder_unet, text_encoder=None, **diffusers_load_config)
else:
decoder = StableCascadeDecoderPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, text_encoder=None, **diffusers_load_config)
shared.log.debug(f'StableCascade {decoder_folder}: scale={decoder.latent_dim_scale}')
prior_text_encoder = None
if shared.opts.sd_unet != "None":
prior_unet, prior_text_encoder = load_prior(unet_dict[shared.opts.sd_unet])
else:
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder=prior_folder, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
if prior_text_encoder is not None:
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", cache_dir=shared.opts.diffusers_dir, prior=prior_unet, text_encoder=prior_text_encoder, image_encoder=None, feature_extractor=None, **diffusers_load_config)
else:
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", cache_dir=shared.opts.diffusers_dir, prior=prior_unet, image_encoder=None, feature_extractor=None, **diffusers_load_config)
shared.log.debug(f'StableCascade {prior_folder}: scale={prior.resolution_multiple}')
sd_model = StableCascadeCombinedPipeline(
tokenizer=decoder.tokenizer,
text_encoder=None,
decoder=decoder.decoder,
scheduler=decoder.scheduler,
vqgan=decoder.vqgan,
prior_prior=prior.prior,
prior_text_encoder=prior.text_encoder,
prior_tokenizer=prior.tokenizer,
prior_scheduler=prior.scheduler,
prior_feature_extractor=None,
prior_image_encoder=None)
else:
sd_model = StableCascadeCombinedPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.decoder_pipe.text_encoder = sd_model.text_encoder = None # Nothing uses the decoder's text encoder
sd_model.prior_pipe.image_encoder = sd_model.prior_image_encoder = None # No img2img is implemented yet
sd_model.prior_pipe.feature_extractor = sd_model.prior_feature_extractor = None # No img2img is implemented yet
shared.log.debug(f'StableCascade combined: {sd_model.__class__.__name__}')
return sd_model