import os
import transformers
import diffusers
from huggingface_hub import repo_exists
from modules import errors, shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant


debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


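# Lumina-Next (v1): always loads the reference 'Alpha-VLLM/Lumina-Next-SFT-diffusers' repo;
# the checkpoint argument is unused and quantization is not applied for this pipeline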
def load_lumina(_checkpoint_info, diffusers_load_config={}):
    modelloader.hf_login()
    load_config, _quant_config = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    pipe = diffusers.LuminaText2ImgPipeline.from_pretrained(
        'Alpha-VLLM/Lumina-Next-SFT-diffusers',
        cache_dir=shared.opts.diffusers_dir,
        **load_config,
    )
    devices.torch_gc(force=True)
    return pipe


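# Lumina 2: resolves the checkpoint to a repo id or local folder, loads the transformer,
# text encoder and optional VAE separately, then assembles the final pipeline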
def load_lumina2(checkpoint_info, diffusers_load_config={}):
    transformer, text_encoder, vae = None, None, None
    repo_id = sd_models.path_to_repo(checkpoint_info.name)
    if os.path.isdir(checkpoint_info.filename) and not repo_exists(repo_id):
        repo_id = checkpoint_info.filename

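    # optional TeaCache acceleration via a class-wide forward patch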
    if shared.opts.teacache_enabled:
        from modules import teacache
        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.Lumina2Transformer2DModel.__name__}')
        diffusers.Lumina2Transformer2DModel.forward = teacache.teacache_lumina2_forward  # patch must be done before transformer is loaded

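    # optional transformer override: load user-selected weights from a standalone file;
    # on failure the override is reset and the repo transformer is used instead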
    load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=Lumina2 unet="{shared.opts.sd_unet}"')
            transformer = diffusers.Lumina2Transformer2DModel.from_single_file(
                sd_unet.unet_dict[shared.opts.sd_unet],
                cache_dir=shared.opts.diffusers_dir,
                **load_config,
                **quant_config,
            )
            if transformer is None:
                sd_unet.failed_unet.append(shared.opts.sd_unet)  # record the failed selection before resetting it
                shared.opts.sd_unet = 'Default'
        except Exception as e:
            shared.log.error(f"Load model: type=Lumina2 failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'Lumina2 UNet:')

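    # optional VAE override: load a user-selected VAE from file, reusing the bundled Flux VAE config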
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=Lumina2 vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            shared.log.error(f"Load model: type=Lumina2 failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'Lumina2 VAE:')

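    # no override loaded: fetch the transformer from the resolved repo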
    if transformer is None:
        transformer = diffusers.Lumina2Transformer2DModel.from_pretrained(
            repo_id,
            subfolder="transformer",
            cache_dir=shared.opts.diffusers_dir,
            **load_config,
            **quant_config,
        )

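    # the text encoder is loaded separately so TE-specific load/quantization args and device_map can apply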
    load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
    text_encoder = transformers.AutoModel.from_pretrained(
        repo_id,
        subfolder="text_encoder",
        cache_dir=shared.opts.diffusers_dir,
        **load_config,
        **quant_config,
    )

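    # assemble the pipeline around the preloaded components; remaining modules (scheduler, tokenizer,
    # and the VAE unless overridden above) come from the repo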
    load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    if vae is not None:
        load_config['vae'] = vae
    pipe = diffusers.Lumina2Pipeline.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
        text_encoder=text_encoder,
        transformer=transformer,
        **load_config,
    )

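    # apply the shared text-encoder hijack and release load-time memory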
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True)
    return pipe
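

# Minimal usage sketch (illustrative only; assumes a checkpoint_info object from the app's model registry):
#   pipe = load_lumina2(checkpoint_info)
#   image = pipe(prompt='a cat in a spacesuit', num_inference_steps=30).images[0]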