cleanup flux loader

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4115/head
Vladimir Mandic 2025-08-11 13:38:06 -04:00
parent c92e329234
commit 87bd347116
19 changed files with 739 additions and 520 deletions

View File

@@ -2,13 +2,11 @@
"""
Warnings:
- fal/AuraFlow-v0.3: layer_class_name=Linear layer_weight_shape=torch.Size([3072, 2, 1024]) weights_dtype=int8 unsupported
- Kwai-Kolors/Kolors-diffusers: `set_input_embeddings` not autohandled for ChatGLMModel
- kandinsky-community/kandinsky-2-1: `get_input_embeddings` not autohandled for MultilingualCLIP
- Kwai-Kolors/Kolors-diffusers: set_input_embeddings not autohandled for ChatGLMModel
- kandinsky-community/kandinsky-2-1: get_input_embeddings not autohandled for MultilingualCLIP
Errors:
- kandinsky-community/kandinsky-3: corrupt output
- nvidia/Cosmos-Predict2-2B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x2048)
- nvidia/Cosmos-Predict2-14B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x5120)
- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error: device-side assert triggered
- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error device-side assert triggered
Other:
- HiDream-ai/HiDream-I1-Full: very slow at 30+s/it
"""
@@ -64,7 +62,11 @@ models = {
"stabilityai/stable-cascade": {},
"nvidia/Cosmos-Predict2-2B-Text2Image": {},
"nvidia/Cosmos-Predict2-14B-Text2Image": {},
# "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
"black-forest-labs/FLUX.1-dev": {},
"black-forest-labs/FLUX.1-Kontext-dev": {},
"black-forest-labs/FLUX.1-Krea-dev": {},
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers": {},
# "kandinsky-community/kandinsky-3": {},
# "HiDream-ai/HiDream-I1-Full": {},
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {},
@@ -76,11 +78,6 @@ models = {
"vladmandic/chroma-unlocked-v48": {},
"vladmandic/chroma-unlocked-v48-detail-calibrated": {},
}
models_tbd = [
"black-forest-labs/FLUX.1-dev",
"black-forest-labs/FLUX.1-Kontext-dev",
"black-forest-labs/FLUX.1-Krea-dev",
]
styles = [
'Fixed Astronaut',
]
@@ -115,7 +112,7 @@ def read_history():
log.info(f'history: file="{fn}" records={len(history)}')
def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration:float=0, info:str=''):
def write_history(model:str, style:str, image:str='', size:tuple=(0,0), generate:float=0, load:float=0, info:str=''):
fn = os.path.join(output_folder, 'history.json')
history.append({
'model': model,
@@ -123,7 +120,8 @@ def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration
'style': style,
'image': image,
'size': size,
'time': duration,
'time': generate,
'load': load,
'info': info,
})
with open(fn, "w", encoding='utf8') as file:
@@ -147,12 +145,17 @@ def request(endpoint: str, dct: dict = None, method: str = 'POST'):
return req.json()
def generate(): # pylint: disable=redefined-outer-name
idx = 0
def main(): # pylint: disable=redefined-outer-name
idx_model = 0
idx_images = 0
t_generate0 = time.time()
log.info(f'generate: models={len(models)} styles={len(styles)}')
for model, args in models.items():
idx += 1
t_model0 = time.time()
idx_model += 1
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
log.info(f'model: n={idx+1}/{len(models)} name="{model}"')
log.info(f'model: n={idx_model}/{len(models)} name="{model}"')
idx_style = 0
for s, style in enumerate(styles):
try:
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
@@ -160,39 +163,46 @@ def generate(): # pylint: disable=redefined-outer-name
fn = os.path.join(output_folder, f'{model_name}__{style_name}.jpg')
if os.path.exists(fn):
continue
t_load0 = time.time()
request(f'/sdapi/v1/checkpoint?sd_model_checkpoint={model}', method='POST')
loaded = request('/sdapi/v1/checkpoint', method='GET')
t_load1 = time.time()
if not loaded or not (model in loaded.get('checkpoint') or model in loaded.get('title') or model in loaded.get('name')):
log.error(f' model: error="{model}"')
continue
t0 = time.time()
t_style0 = time.time()
params = { 'styles': [style] }
for k, v in args.items():
params[k] = v
log.info(f' style: n={s+1}/{len(styles)} name="{style}" args={params} fn="{fn}"')
data = request('/sdapi/v1/txt2img', params)
t1 = time.time()
t_style1 = time.time()
if 'images' in data and len(data['images']) > 0:
idx_style += 1
idx_images += 1
b64 = data['images'][0].split(',',1)[-1]
image = Image.open(io.BytesIO(base64.b64decode(b64)))
info = data['info']
log.info(f' image: size={image.width}x{image.height} time={t1-t0:.2f} info={len(info)}')
log.info(f' image: size={image.width}x{image.height} time={t_style1-t_style0:.2f} info={len(info)}')
image.save(fn)
write_history(model=model, style=style, image=fn, size=image.size, duration=round(t1-t0, 3), info=info)
write_history(model=model, style=style, image=fn, size=image.size, generate=round(t_style1-t_style0, 3), load=round(t_load1-t_load0, 3), info=info)
else:
# write_history(model=model, style=style, generate=round(t_style1-t_style0, 3), info='no image')
log.error(f' model: error="{model}" style="{style}" no image')
except Exception as e:
if 'Connection refused' in str(e) or 'RemoteDisconnected' in str(e):
log.error('server offline')
os._exit(1)
# write_history(model=model, style=style, generate=round(t_style1-t_style0, 3), info=str(e))
log.error(f' model: error="{model}" style="{style}" exception="{e}"')
t_model1 = time.time()
if idx_style > 0:
log.info(f'model: name="{model}" images={idx_style} time={t_model1-t_model0:.2f}')
t_generate1 = time.time()
if idx_images > 0:
log.info(f'generate: models={idx_model} images={idx_images} time={t_generate1-t_generate0:.2f}')
if __name__ == "__main__":
log.info('test-all-models')
log.info(f'output="{output_folder}" models={len(models)} styles={len(styles)}')
log.info(f'output="{output_folder}"')
read_history()
generate()
log.info('done...')
main()
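
For reference, the updated write_history produces history.json records shaped roughly like this (values and path are illustrative; any fields elided by the hunk above are omitted):

record = {
    'model': 'black-forest-labs/FLUX.1-dev',
    'style': 'Fixed Astronaut',
    'image': 'outputs/black-forest-labs__FLUX.1-dev__Fixed_Astronaut.jpg',
    'size': (1024, 1024),
    'time': 12.345,  # per-image generate duration in seconds
    'load': 8.901,   # checkpoint load duration in seconds
    'info': '{...}',
}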

View File

@@ -429,6 +429,12 @@
"preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg",
"extras": "sampler: Default, cfg_scale: 2.0"
},
"Tencent HunyuanDiT 1.1": {
"path": "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers",
"desc": "Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding.",
"preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg",
"extras": "sampler: Default, cfg_scale: 2.0"
},
"AlphaVLLM Lumina Next SFT": {
"path": "Alpha-VLLM/Lumina-Next-SFT-diffusers",

View File

@@ -642,11 +642,15 @@ def get_dit_args(load_config:dict={}, module:str=None, device_map:bool=False, al
def do_post_load_quant(sd_model, allow=True):
from modules import shared
if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
shared.log.debug('Load model: post_quant=sdnq')
sd_model = sdnq_quantize_weights(sd_model)
if len(shared.opts.optimum_quanto_weights) > 0:
shared.log.debug('Load model: post_quant=quanto')
sd_model = optimum_quanto_weights(sd_model)
if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
shared.log.debug('Load model: post_quant=torchao')
sd_model = torchao_quantization(sd_model)
if shared.opts.layerwise_quantization:
shared.log.debug('Load model: post_quant=layerwise')
apply_layerwise(sd_model)
return sd_model
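
The same mode gate is reused for the sdnq and torchao backends above; a minimal sketch of the predicate, with a hypothetical helper name:

def should_post_quant(mode: str, allow: bool) -> bool:
    # quantize after load when explicitly requested ('post'), or in 'auto'
    # mode when the loader did not already quantize (allow=True)
    return mode == 'post' or (allow and mode == 'auto')

assert should_post_quant('post', allow=False)
assert should_post_quant('auto', allow=True)
assert not should_post_quant('auto', allow=False)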

View File

@@ -29,6 +29,7 @@ debug_load = os.environ.get('SD_LOAD_DEBUG', None)
debug_process = log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
diffusers_version = int(diffusers.__version__.split('.')[1])
checkpoint_tiles = checkpoint_titles # legacy compatibility
allow_post_quant = None
pipe_switch_task_exclude = [
'AnimateDiffPipeline', 'AnimateDiffSDXLPipeline',
'FluxControlPipeline',
@@ -275,7 +276,7 @@ def load_diffuser_initial(diffusers_load_config, op='model'):
def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='model'):
sd_model = None
allow_post_quant = True
global allow_post_quant # pylint: disable=global-statement
unload_model_weights(op=op)
shared.sd_model = None
try:
@@ -316,7 +317,8 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
allow_post_quant = True
elif model_type in ['FLUX']:
from pipelines.model_flux import load_flux
sd_model, allow_post_quant = load_flux(checkpoint_info, diffusers_load_config)
sd_model = load_flux(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['FLEX']:
from pipelines.model_flex import load_flex
sd_model = load_flex(checkpoint_info, diffusers_load_config)
@@ -398,7 +400,7 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
if debug_load:
errors.display(e, 'Load')
return None, True
return sd_model, allow_post_quant
return sd_model
def load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_config, op='model'):
@@ -561,6 +563,8 @@ def set_defaults(sd_model, checkpoint_info):
def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: disable=unused-argument
global allow_post_quant # pylint: disable=global-statement
allow_post_quant = True # assume default
logging.getLogger("diffusers").setLevel(logging.ERROR)
timer.load.record("diffusers")
diffusers_load_config = {
@@ -589,7 +593,6 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
return
sd_model = None
allow_post_quant = True
try:
# initial load only
if sd_model is None:
@@ -621,7 +624,7 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
# load with custom loader
if sd_model is None:
sd_model, allow_post_quant = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
sd_model = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
if sd_model is not None and not sd_model:
shared.log.error(f'Load {op}: type="{model_type}" pipeline="{pipeline}" not loaded')
return
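
Taken together, these hunks replace the (sd_model, allow_post_quant) tuple return with module-level state; a condensed, self-contained sketch of the intended flow, with all helpers as hypothetical stand-ins:

allow_post_quant = None  # module-level state, as in sd_models

def load_flux_sketch():
    return object()  # stands in for a pipeline whose loader already quantized

def load_generic_sketch():
    return object()  # stands in for a loader that leaves quantization pending

def do_post_load_quant_sketch(sd_model, allow):
    return sd_model  # post-quant would run here only when allow is True

def load_diffuser_sketch(model_type):
    global allow_post_quant
    allow_post_quant = True              # assume default
    if model_type == 'FLUX':
        sd_model = load_flux_sketch()
        allow_post_quant = False         # FLUX loader handled quantization itself
    else:
        sd_model = load_generic_sketch()
    return do_post_load_quant_sketch(sd_model, allow=allow_post_quant)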

View File

@@ -0,0 +1,25 @@
import diffusers
import transformers
from modules import devices, model_quant
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer = None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
if quant == 'fp8':
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'fp4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'nf4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
else:
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
return transformer
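
The three quant branches differ only in the BitsAndBytesConfig they construct; the same mapping as a compact sketch (helper name hypothetical, assumes bitsandbytes is installed):

import torch
import transformers

def bnb_config(quant: str, compute_dtype=torch.bfloat16):
    # 'fp8' selects plain 8-bit loading; 'fp4'/'nf4' pick the 4-bit quant type
    if quant == 'fp8':
        return transformers.BitsAndBytesConfig(load_in_8bit=True)
    if quant in ('fp4', 'nf4'):
        return transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=quant,
        )
    return None  # unquantized fallback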

View File

@@ -0,0 +1,360 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
transformer, text_encoder_2 = None, None
quanto = model_quant.load_quanto('Load model: type=FLUX')
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
try:
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
dtype = state_dict['context_embedder.bias'].dtype
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
transformer_dtype = transformer.dtype
if transformer_dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
try:
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
text_encoder_2_dtype = text_encoder_2.dtype
if text_encoder_2_dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
return transformer, text_encoder_2
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer, text_encoder_2 = None, None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
try:
if quant == 'fp8':
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'fp4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'nf4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
else:
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
transformer, text_encoder_2 = None, None
if debug:
errors.display(e, 'FLUX:')
return transformer, text_encoder_2
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
try:
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": cache_dir,
}
if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = None
if 'flux.1-kontext' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
elif 'flux.1-dev' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
elif 'flux.1-schnell' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-fill' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-depth' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'shuttle' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
else:
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
kwargs['transformer'].quantization_method = 'SVDQuant'
if shared.opts.nunchaku_attention:
kwargs['transformer'].set_attention_impl("nunchaku-fp16")
if 'transformer' not in kwargs and model_quant.check_quant('Model'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
except Exception as e:
shared.log.error(f'Quantization: {e}')
errors.display(e, 'Quantization:')
return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
if file_path is None or not os.path.exists(file_path):
return None
transformer = None
quant = model_quant.get_quant(file_path)
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": shared.opts.hfcache_dir,
}
if quant is not None and quant != 'none':
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
from modules import ggml
_transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
if _transformer is not None:
transformer = _transformer
elif quant == "fp8":
_transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif quant in {'qint8', 'qint4'}:
_transformer, _text_encoder_2 = load_flux_quanto(file_path)
if _transformer is not None:
transformer = _transformer
elif quant in {'fp8', 'fp4', 'nf4'}:
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif 'nf4' in quant:
from pipelines.flux.flux_nf4 import load_flux_nf4
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
if _transformer is not None:
transformer = _transformer
else:
quant_args = model_quant.create_bnb_config({})
if quant_args:
shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
from pipelines.flux.flux_nf4 import load_flux_nf4
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
if transformer is not None:
return transformer
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
if transformer is None:
shared.log.error('Failed to load UNet model')
shared.opts.sd_unet = 'Default'
return transformer
def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
allow_post_quant = False
prequantized = model_quant.get_quant(checkpoint_info.path)
shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
debug(f'Load model: type=FLUX config={diffusers_load_config}')
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
# unload current model
sd_models.unload_model_weights()
shared.sd_model = None
devices.torch_gc(force=True, reason='load')
if shared.opts.teacache_enabled:
from modules import teacache
shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
# load overrides if any
if shared.opts.sd_unet != 'Default':
try:
debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
if transformer is None:
shared.opts.sd_unet = 'Default'
sd_unet.failed_unet.append(shared.opts.sd_unet)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
shared.opts.sd_unet = 'Default'
if debug:
errors.display(e, 'FLUX UNet:')
if shared.opts.sd_text_encoder != 'Default':
try:
debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
from modules.model_te import load_t5, load_vit_l
if 'vit-l' in shared.opts.sd_text_encoder.lower():
text_encoder_1 = load_vit_l()
else:
text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
shared.opts.sd_text_encoder = 'Default'
if debug:
errors.display(e, 'FLUX T5:')
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
try:
debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
from modules import sd_vae
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
if os.path.exists(vae_file):
vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
shared.opts.sd_vae = 'Default'
if debug:
errors.display(e, 'FLUX VAE:')
# load quantized components if any
if prequantized == 'nf4':
try:
from pipelines.flux.flux_nf4 import load_flux_nf4
_transformer, _text_encoder = load_flux_nf4(checkpoint_info)
if _transformer is not None:
transformer = _transformer
if _text_encoder is not None:
text_encoder_2 = _text_encoder
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
if debug:
errors.display(e, 'FLUX NF4:')
if prequantized == 'qint8' or prequantized == 'qint4':
try:
_transformer, _text_encoder = load_flux_quanto(checkpoint_info)
if _transformer is not None:
transformer = _transformer
if _text_encoder is not None:
text_encoder_2 = _text_encoder
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
# initialize pipeline with pre-loaded components
kwargs = {}
if transformer is not None:
kwargs['transformer'] = transformer
sd_unet.loaded_unet = shared.opts.sd_unet
if text_encoder_1 is not None:
kwargs['text_encoder'] = text_encoder_1
model_te.loaded_te = shared.opts.sd_text_encoder
if text_encoder_2 is not None:
kwargs['text_encoder_2'] = text_encoder_2
model_te.loaded_te = shared.opts.sd_text_encoder
if vae is not None:
kwargs['vae'] = vae
if repo_id == 'sayakpaul/flux.1-dev-nf4':
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
if 'Fill' in repo_id:
cls = diffusers.FluxFillPipeline
elif 'Canny' in repo_id:
cls = diffusers.FluxControlPipeline
elif 'Depth' in repo_id:
cls = diffusers.FluxControlPipeline
elif 'Kontext' in repo_id:
cls = diffusers.FluxKontextPipeline
from diffusers import pipelines
pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
else:
cls = diffusers.FluxPipeline
shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
for c in kwargs:
if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
try:
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
except Exception:
pass
allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
fn = checkpoint_info.path
if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
if fn is not None and fn.endswith('.safetensors') and os.path.isfile(fn):
pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
allow_post_quant = True
else:
pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
# release memory
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
for k in kwargs.keys():
kwargs[k] = None
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason='load')
return pipe, allow_post_quant
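
The pipeline class is chosen purely from marker substrings in the repo id; a compact sketch of that ladder (assumes a diffusers version that ships the Flux pipeline classes; helper name hypothetical):

import diffusers

def flux_pipeline_cls(repo_id: str):
    # specialty pipelines are selected by substring, FluxPipeline is the default
    if 'Fill' in repo_id:
        return diffusers.FluxFillPipeline
    if 'Canny' in repo_id or 'Depth' in repo_id:
        return diffusers.FluxControlPipeline
    if 'Kontext' in repo_id:
        return diffusers.FluxKontextPipeline
    return diffusers.FluxPipeline

assert flux_pipeline_cls('black-forest-labs/FLUX.1-Kontext-dev') is diffusers.FluxKontextPipeline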

View File

@@ -0,0 +1,29 @@
from modules import shared, devices
def load_flux_nunchaku(repo_id):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = None
transformer = None
if 'flux.1-kontext' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
elif 'flux.1-dev' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
elif 'flux.1-schnell' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-fill' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-depth' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'shuttle' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
else:
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
transformer = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
transformer.quantization_method = 'SVDQuant'
if shared.opts.nunchaku_attention:
transformer.set_attention_impl("nunchaku-fp16")
return transformer
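
Restated as data, the elif ladder above is a substring-to-repo lookup where '{p}' is filled from nunchaku.utils.get_precision(); an illustrative restructuring, not the module's code (entries abbreviated):

NUNCHAKU_REPOS = {
    'flux.1-kontext': 'mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{p}_r32-flux.1-kontext-dev.safetensors',
    'flux.1-dev': 'mit-han-lab/nunchaku-flux.1-dev/svdq-{p}_r32-flux.1-dev.safetensors',
    'flux.1-schnell': 'mit-han-lab/nunchaku-flux.1-schnell/svdq-{p}_r32-flux.1-schnell.safetensors',
    # fill, depth and shuttle entries elided
}

def resolve_nunchaku_repo(repo_id: str, precision: str):
    # first matching marker wins, mirroring the if/elif order above
    for marker, template in NUNCHAKU_REPOS.items():
        if marker in repo_id.lower():
            return template.format(p=precision)
    return None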

View File

@@ -0,0 +1,73 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, model_quant
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
transformer, text_encoder_2 = None, None
quanto = model_quant.load_quanto('Load model: type=FLUX')
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
try:
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
dtype = state_dict['context_embedder.bias'].dtype
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
transformer_dtype = transformer.dtype
if transformer_dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
try:
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
text_encoder_2_dtype = text_encoder_2.dtype
if text_encoder_2_dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
return transformer, text_encoder_2
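
The core pattern in both halves of this file is optimum-quanto requantization: build the module skeleton on the meta device, then materialize it from the quantized state dict plus its quantization map. A minimal sketch assuming optimum-quanto is installed and a local folder laid out as the code above expects (paths hypothetical):

import json
import torch
import transformers
from safetensors.torch import load_file
from optimum import quanto

state_dict = load_file('text_encoder_2/model.safetensors')
with open('text_encoder_2/quantization_map.json', encoding='utf8') as f:
    quantization_map = json.load(f)
with open('text_encoder_2/config.json', encoding='utf8') as f:
    t5_config = transformers.T5Config(**json.load(f))

with torch.device('meta'):
    text_encoder_2 = transformers.T5EncoderModel(t5_config)  # skeleton only, no weights allocated
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device('cpu'))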

View File

@@ -2,120 +2,145 @@ import os
import json
import diffusers
import transformers
from modules import shared, devices, sd_models, model_quant
from modules import shared, devices, errors, sd_models, model_quant
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
def load_transformer(repo_id, cls_name, load_config={}, subfolder="transformer", allow_quant=True, variant=None, dtype=None):
load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant)
quant_type = model_quant.get_quant_type(quant_args)
dtype = dtype or devices.dtype
transformer = None
try:
load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant)
quant_type = model_quant.get_quant_type(quant_args)
dtype = dtype or devices.dtype
local_file = None
if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default':
from modules import sd_unet
if shared.opts.sd_unet not in list(sd_unet.unet_dict):
shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found')
elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]):
local_file = sd_unet.unet_dict[shared.opts.sd_unet]
local_file = None
if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default':
from modules import sd_unet
if shared.opts.sd_unet not in list(sd_unet.unet_dict):
shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found')
elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]):
local_file = sd_unet.unet_dict[shared.opts.sd_unet]
if local_file is not None and local_file.lower().endswith('.gguf'):
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
from modules import ggml
ggml.install_gguf()
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
transformer = loader(
local_file,
quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
elif local_file is not None and local_file.lower().endswith('.safetensors'):
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
transformer = loader(
local_file,
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
else:
shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
if dtype is not None:
load_args['torch_dtype'] = dtype
if subfolder is not None:
load_args['subfolder'] = subfolder
if variant is not None:
load_args['variant'] = variant
transformer = cls_name.from_pretrained(
repo_id,
cache_dir=shared.opts.hfcache_dir,
**load_args,
**quant_args,
)
if shared.opts.diffusers_offload_mode != 'none' and transformer is not None:
sd_models.move_model(transformer, devices.cpu)
if local_file is not None and local_file.lower().endswith('.gguf'):
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
from modules import ggml
ggml.install_gguf()
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
transformer = loader(
local_file,
quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
elif local_file is not None and local_file.lower().endswith('.safetensors'):
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
transformer = loader(
local_file,
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
else:
shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
if dtype is not None:
load_args['torch_dtype'] = dtype
if subfolder is not None:
load_args['subfolder'] = subfolder
if variant is not None:
load_args['variant'] = variant
transformer = cls_name.from_pretrained(
repo_id,
cache_dir=shared.opts.hfcache_dir,
**load_args,
**quant_args,
)
sd_models.allow_post_quant = False # we already handled it
if shared.opts.diffusers_offload_mode != 'none' and transformer is not None:
sd_models.move_model(transformer, devices.cpu)
except Exception as e:
shared.log.error(f'Load model: type=transformer {e}')
if debug:
errors.display(e, 'Load:')
raise
return transformer
def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder", allow_quant=True, allow_shared=True, variant=None, dtype=None):
load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant)
quant_type = model_quant.get_quant_type(quant_args)
text_encoder = None
dtype = dtype or devices.dtype
try:
load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant)
quant_type = model_quant.get_quant_type(quant_args)
dtype = dtype or devices.dtype
# load from local file if specified
local_file = None
if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default':
from modules import model_te
if shared.opts.sd_text_encoder not in list(model_te.te_dict):
shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found')
elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]):
local_file = model_te.te_dict[shared.opts.sd_text_encoder]
# load from local file if specified
local_file = None
if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default':
from modules import model_te
if shared.opts.sd_text_encoder not in list(model_te.te_dict):
shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found')
elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]):
local_file = model_te.te_dict[shared.opts.sd_text_encoder]
# load from local file gguf
if local_file is not None and local_file.lower().endswith('.gguf'):
shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
from modules import ggml
ggml.install_gguf()
text_encoder = cls_name.from_pretrained(
gguf_file=local_file,
quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
# load from local file safetensors
elif local_file is not None and local_file.lower().endswith('.safetensors'):
shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
text_encoder = cls_name.from_pretrained(
local_file,
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
# use shared t5 if possible
elif cls_name == transformers.T5EncoderModel and allow_shared:
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
load_args['config'] = transformers.T5Config(**json.load(f))
if model_quant.check_nunchaku('TE'):
import nunchaku
repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
cls_name = nunchaku.NunchakuT5EncoderModel
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"')
text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(
repo_id,
torch_dtype=dtype,
# load from local file gguf
if local_file is not None and local_file.lower().endswith('.gguf'):
shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
from modules import ggml
ggml.install_gguf()
text_encoder = cls_name.from_pretrained(
gguf_file=local_file,
quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
text_encoder.quantization_method = 'SVDQuant'
elif shared.opts.te_shared_t5:
repo_id = 'Disty0/t5-xxl'
text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
# load from local file safetensors
elif local_file is not None and local_file.lower().endswith('.safetensors'):
shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
text_encoder = cls_name.from_pretrained(
local_file,
cache_dir=shared.opts.hfcache_dir,
**load_args,
)
text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
# use shared t5 if possible
elif cls_name == transformers.T5EncoderModel and allow_shared:
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
load_args['config'] = transformers.T5Config(**json.load(f))
if model_quant.check_nunchaku('TE'):
import nunchaku
repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
cls_name = nunchaku.NunchakuT5EncoderModel
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"')
text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(
repo_id,
torch_dtype=dtype,
)
text_encoder.quantization_method = 'SVDQuant'
elif shared.opts.te_shared_t5:
repo_id = 'Disty0/t5-xxl'
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
if dtype is not None:
load_args['torch_dtype'] = dtype
text_encoder = cls_name.from_pretrained(
repo_id,
cache_dir=shared.opts.hfcache_dir,
**load_args,
**quant_args,
)
# load from repo
if text_encoder is None:
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
if dtype is not None:
load_args['torch_dtype'] = dtype
if subfolder is not None:
load_args['subfolder'] = subfolder
if variant is not None:
load_args['variant'] = variant
text_encoder = cls_name.from_pretrained(
repo_id,
cache_dir=shared.opts.hfcache_dir,
@@ -123,22 +148,12 @@ def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder
**quant_args,
)
# load from repo
if text_encoder is None:
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
if dtype is not None:
load_args['torch_dtype'] = dtype
if subfolder is not None:
load_args['subfolder'] = subfolder
if variant is not None:
load_args['variant'] = variant
text_encoder = cls_name.from_pretrained(
repo_id,
cache_dir=shared.opts.hfcache_dir,
**load_args,
**quant_args,
)
if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None:
sd_models.move_model(text_encoder, devices.cpu)
sd_models.allow_post_quant = False # we already handled it
if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None:
sd_models.move_model(text_encoder, devices.cpu)
except Exception as e:
shared.log.error(f'Load model: type=te {e}')
if debug:
errors.display(e, 'Load:')
raise
return text_encoder
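
A hedged usage sketch of the two generic loaders, matching the signatures above (repo id illustrative; the pipelines.generic module path is assumed from the repo layout):

import diffusers
import transformers
from pipelines import generic

transformer = generic.load_transformer(
    'black-forest-labs/FLUX.1-dev', cls_name=diffusers.FluxTransformer2DModel)
text_encoder_2 = generic.load_text_encoder(
    'black-forest-labs/FLUX.1-dev', cls_name=transformers.T5EncoderModel,
    subfolder='text_encoder_2')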

View File

@@ -4,7 +4,7 @@ sys.path.append("./")
# import torch
# from torchvision import transforms
from meissonic.transformer import Transformer2DModel as TransformerMeissonic
from meissonic.pipeline import Pipeline as PipelineMeissonic
from meissonic.pipeline import MeissonicPipeline
from meissonic.scheduler import Scheduler as MeissonicScheduler
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel
@@ -21,7 +21,7 @@ vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", cache_dir=cach
text_encoder = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir)
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
scheduler = MeissonicScheduler.from_pretrained(model_path, subfolder="scheduler")
pipe = PipelineMeissonic(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler)
pipe = MeissonicPipeline(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler)
pipe = pipe.to(device)
steps = 64

View File

@@ -1,13 +1,9 @@
import os
import diffusers
import transformers
from modules import shared, devices, sd_models, model_quant
from pipelines import generic
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_chroma(checkpoint_info, diffusers_load_config={}):
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)

View File

@@ -1,360 +1,76 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
from modules import shared, devices, sd_models, model_quant
from pipelines import generic
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
transformer, text_encoder_2 = None, None
quanto = model_quant.load_quanto('Load model: type=FLUX')
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
try:
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
dtype = state_dict['context_embedder.bias'].dtype
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
transformer_dtype = transformer.dtype
if transformer_dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
try:
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
text_encoder_2_dtype = text_encoder_2.dtype
if text_encoder_2_dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
return transformer, text_encoder_2
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer, text_encoder_2 = None, None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
try:
if quant == 'fp8':
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'fp4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'nf4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
else:
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
transformer, text_encoder_2 = None, None
if debug:
errors.display(e, 'FLUX:')
return transformer, text_encoder_2
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
try:
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": cache_dir,
}
if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = None
if 'flux.1-kontext' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
elif 'flux.1-dev' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
elif 'flux.1-schnell' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-fill' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-depth' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'shuttle' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
else:
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
kwargs['transformer'].quantization_method = 'SVDQuant'
if shared.opts.nunchaku_attention:
kwargs['transformer'].set_attention_impl("nunchaku-fp16")
if 'transformer' not in kwargs and model_quant.check_quant('Model'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
except Exception as e:
shared.log.error(f'Quantization: {e}')
errors.display(e, 'Quantization:')
return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
if file_path is None or not os.path.exists(file_path):
return None
transformer = None
quant = model_quant.get_quant(file_path)
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": shared.opts.hfcache_dir,
}
if quant is not None and quant != 'none':
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
from modules import ggml
_transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
if _transformer is not None:
transformer = _transformer
elif quant == "fp8":
_transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif quant in {'qint8', 'qint4'}:
_transformer, _text_encoder_2 = load_flux_quanto(file_path)
if _transformer is not None:
transformer = _transformer
elif quant in {'fp8', 'fp4', 'nf4'}:
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif 'nf4' in quant:
from pipelines.model_flux_nf4 import load_flux_nf4
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
if _transformer is not None:
transformer = _transformer
else:
quant_args = model_quant.create_bnb_config({})
if quant_args:
shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
from pipelines.model_flux_nf4 import load_flux_nf4
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
if transformer is not None:
return transformer
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
if transformer is None:
shared.log.error('Failed to load UNet model')
shared.opts.sd_unet = 'Default'
return transformer
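# load a complete flux pipeline: unload the current model, apply unet/te/vae overrides, resolve quantized components, then assemble the diffusers pipeline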
def load_flux(checkpoint_info, diffusers_load_config={}): # triggered by opts.sd_checkpoint change
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
prequantized = model_quant.get_quant(checkpoint_info.path)
shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
debug(f'Load model: type=FLUX config={diffusers_load_config}')
if 'Fill' in repo_id:
cls_name = diffusers.FluxFillPipeline
elif 'Canny' in repo_id:
cls_name = diffusers.FluxControlPipeline
elif 'Depth' in repo_id:
cls_name = diffusers.FluxControlPipeline
elif 'Kontext' in repo_id:
cls_name = diffusers.FluxKontextPipeline
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
else:
cls_name = diffusers.FluxPipeline
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
shared.log.debug(f'Load model: type=Flux repo="{repo_id}" cls={cls_name.__name__} config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
# unload current model
sd_models.unload_model_weights()
shared.sd_model = None
devices.torch_gc(force=True, reason='load')
    # optional teacache patch
    if shared.opts.teacache_enabled and not model_quant.check_nunchaku('Model'):
from modules import teacache
shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
# load overrides if any
if shared.opts.sd_unet != 'Default':
try:
debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                sd_unet.failed_unet.append(shared.opts.sd_unet) # record the failing unet before resetting the option
                shared.opts.sd_unet = 'Default'
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
shared.opts.sd_unet = 'Default'
if debug:
errors.display(e, 'FLUX UNet:')
if shared.opts.sd_text_encoder != 'Default':
try:
debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
from modules.model_te import load_t5, load_vit_l
if 'vit-l' in shared.opts.sd_text_encoder.lower():
text_encoder_1 = load_vit_l()
else:
text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
shared.opts.sd_text_encoder = 'Default'
if debug:
errors.display(e, 'FLUX T5:')
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
try:
debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
from modules import sd_vae
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
if os.path.exists(vae_file):
vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
shared.opts.sd_vae = 'Default'
if debug:
errors.display(e, 'FLUX VAE:')
# handle transformer svdquant if available, t5 is handled inside load_text_encoder
if model_quant.check_nunchaku('Model'):
from pipelines.flux.flux_nunchaku import load_flux_nunchaku
transformer = load_flux_nunchaku(repo_id)
# handle prequantized models
elif prequantized == 'nf4':
from pipelines.flux.flux_nf4 import load_flux_nf4
transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
elif prequantized == 'qint8' or prequantized == 'qint4':
from pipelines.flux.flux_quanto import load_flux_quanto
transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
elif prequantized == 'fp4' or prequantized == 'fp8':
from pipelines.flux.flux_bnb import load_flux_bnb
transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)
    # track which overrides ended up loaded
    if transformer is not None:
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder_1 is not None or text_encoder_2 is not None:
        model_te.loaded_te = shared.opts.sd_text_encoder
if repo_id == 'sayakpaul/flux.1-dev-nf4':
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
    # finally load transformer and text encoder if not already loaded
    if transformer is None:
        transformer = generic.load_transformer(repo_id, cls_name=diffusers.FluxTransformer2DModel, load_config=diffusers_load_config)
    if text_encoder_2 is None:
        text_encoder_2 = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config)
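    # assemble the pipeline from preloaded components; any remaining modules are pulled from the repo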
pipe = cls_name.from_pretrained(
repo_id,
transformer=transformer,
text_encoder_2=text_encoder_2,
cache_dir=shared.opts.diffusers_dir,
**load_args,
)
del text_encoder_2
del transformer
# optional first-block patch
if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason='load')
    return pipe
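# minimal usage sketch (hypothetical; assumes a checkpoint_info already resolved by sd_models):
#   pipe = load_flux(checkpoint_info)
#   image = pipe(prompt='astronaut', num_inference_steps=20).images[0]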

View File

@ -12,7 +12,8 @@ def load_hunyuandit(checkpoint_info, diffusers_load_config={}):
shared.log.debug(f'Load model: type=HunyuanDiT repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
transformer = generic.load_transformer(repo_id, cls_name=diffusers.HunyuanDiT2DModel, load_config=diffusers_load_config)
    repo_te = 'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers' if 'HunyuanDiT-v1' in repo_id else repo_id
    text_encoder_2 = generic.load_text_encoder(repo_te, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, subfolder="text_encoder_2", allow_shared=False) # hunyuandit uses a non-standard t5, so it cannot be shared across pipelines
pipe = diffusers.HunyuanDiTPipeline.from_pretrained(
repo_id,

View File

@ -2,21 +2,14 @@ import torch
import diffusers
repo_id = 'Kwai-Kolors/Kolors-diffusers'
def load_kolors(_checkpoint_info, diffusers_load_config={}):
from modules import shared, devices
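    # kolors weights are published as an fp16 variant; force it and default the dtype to fp16 if unset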
diffusers_load_config['variant'] = "fp16"
if 'torch_dtype' not in diffusers_load_config:
diffusers_load_config['torch_dtype'] = torch.float16
    shared.log.debug(f'Load model: type=Kolors repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
pipe = diffusers.KolorsPipeline.from_pretrained(
repo_id,
        cache_dir=shared.opts.diffusers_dir,

View File

@ -1,9 +1,9 @@
import os
import transformers
import diffusers
from modules import shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_lumina(_checkpoint_info, diffusers_load_config={}):
@ -30,7 +30,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
if shared.opts.sd_unet != 'Default':
try:
debug(f'Load model: type=Lumina2 unet="{shared.opts.sd_unet}"')
transformer = diffusers.Lumina2Transformer2DModel.from_single_file(
sd_unet.unet_dict[shared.opts.sd_unet],
cache_dir=shared.opts.diffusers_dir,
@ -43,12 +42,9 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
except Exception as e:
shared.log.error(f"Load model: type=Lumina2 failed to load UNet: {e}")
shared.opts.sd_unet = 'Default'
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
try:
debug(f'Load model: type=Lumina2 vae="{shared.opts.sd_vae}"')
from modules import sd_vae
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
@ -58,8 +54,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
except Exception as e:
shared.log.error(f"Load model: type=Lumina2 failed to load VAE: {e}")
shared.opts.sd_vae = 'Default'
if transformer is None:
transformer = diffusers.Lumina2Transformer2DModel.from_pretrained(

View File

@ -6,10 +6,10 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
from modules import shared, devices, modelloader, sd_models, shared_items
from pipelines.meissonic.transformer import Transformer2DModel as TransformerMeissonic
from pipelines.meissonic.scheduler import Scheduler as MeissonicScheduler
from pipelines.meissonic.pipeline import MeissonicPipeline
from pipelines.meissonic.pipeline_img2img import MeissonicImg2ImgPipeline
from pipelines.meissonic.pipeline_inpaint import MeissonicInpaintPipeline
shared_items.pipelines['Meissonic'] = MeissonicPipeline
modelloader.hf_login()
fn = sd_models.path_to_repo(checkpoint_info)
@ -41,7 +41,7 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
cache_dir=cache_dir,
)
scheduler = MeissonicScheduler.from_pretrained(fn, subfolder="scheduler", cache_dir=cache_dir)
pipe = MeissonicPipeline(
vqvae=vqvae.to(devices.dtype),
text_encoder=text_encoder.to(devices.dtype),
transformer=model.to(devices.dtype),
@ -49,8 +49,8 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
scheduler=scheduler,
)
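    # register with the diffusers auto-pipeline maps so img2img and inpaint variants resolve to the meissonic classes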
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicImg2ImgPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = MeissonicInpaintPipeline
devices.torch_gc(force=True, reason='load')
return pipe

View File

@ -1,9 +1,6 @@
import diffusers
from modules import shared, devices, sd_models, model_quant
def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
repo_id = sd_models.path_to_repo(checkpoint_info)

View File

@ -1,8 +1,5 @@
from modules import shared, devices, sd_models, model_quant
def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
repo_id = sd_models.path_to_repo(checkpoint_info)