diff --git a/cli/test-all-models.py b/cli/test-all-models.py index 288c0e0c7..728bc9930 100755 --- a/cli/test-all-models.py +++ b/cli/test-all-models.py @@ -2,13 +2,11 @@
"""
Warnings:
- fal/AuraFlow-v0.3: layer_class_name=Linear layer_weight_shape=torch.Size([3072, 2, 1024]) weights_dtype=int8 unsupported
-- Kwai-Kolors/Kolors-diffusers: `set_input_embeddings` not auto‑handled for ChatGLMModel
-- kandinsky-community/kandinsky-2-1: `get_input_embeddings` not auto‑handled for MultilingualCLIP
+- Kwai-Kolors/Kolors-diffusers: set_input_embeddings not auto-handled for ChatGLMModel
+- kandinsky-community/kandinsky-2-1: get_input_embeddings not auto-handled for MultilingualCLIP
Errors:
- kandinsky-community/kandinsky-3: corrupt output
-- nvidia/Cosmos-Predict2-2B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x2048)
-- nvidia/Cosmos-Predict2-14B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x5120)
-- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error: device-side assert triggered
+- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error device-side assert triggered
Other:
- HiDream-ai/HiDream-I1-Full: very slow at 30+s/it
"""
@@ -64,7 +62,11 @@ models = {
"stabilityai/stable-cascade": {},
"nvidia/Cosmos-Predict2-2B-Text2Image": {},
"nvidia/Cosmos-Predict2-14B-Text2Image": {},
- # "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
+ "black-forest-labs/FLUX.1-dev": {},
+ "black-forest-labs/FLUX.1-Kontext-dev": {},
+ "black-forest-labs/FLUX.1-Krea-dev": {},
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
+ "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers": {},
# "kandinsky-community/kandinsky-3": {},
# "HiDream-ai/HiDream-I1-Full": {},
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {},
@@ -76,11 +78,6 @@ models = {
"vladmandic/chroma-unlocked-v48": {},
"vladmandic/chroma-unlocked-v48-detail-calibrated": {},
}
-models_tbd = [
- "black-forest-labs/FLUX.1-dev",
- "black-forest-labs/FLUX.1-Kontext-dev",
- "black-forest-labs/FLUX.1-Krea-dev",
-]
styles = [
'Fixed Astronaut',
]
@@ -115,7 +112,7 @@ def read_history():
log.info(f'history: file="{fn}" records={len(history)}')
-def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration:float=0, info:str=''):
+def write_history(model:str, style:str, image:str='', size:tuple=(0,0), generate:float=0, load:float=0, info:str=''):
fn = os.path.join(output_folder, 'history.json')
history.append({
'model': model,
@@ -123,7 +120,8 @@ def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration
'style': style,
'image': image,
'size': size,
- 'time': duration,
+ 'time': generate,
+ 'load': load,
'info': info,
})
with open(fn, "w", encoding='utf8') as file:
@@ -147,12 +145,17 @@ def request(endpoint: str, dct: dict = None, method: str = 'POST'):
return req.json()
-def generate(): # pylint: disable=redefined-outer-name
- idx = 0
+def main(): # pylint: disable=redefined-outer-name
+ idx_model = 0
+ idx_images = 0
+ t_generate0 = time.time()
+ log.info(f'generate: models={len(models)} styles={len(styles)}')
for model, args in models.items():
- idx += 1
+ t_model0 = time.time()
+ idx_model += 1
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
- log.info(f'model: n={idx+1}/{len(models)} name="{model}"')
+ log.info(f'model: n={idx_model}/{len(models)} name="{model}"')
+ idx_style = 0
for s, style in enumerate(styles):
try:
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
style_name = pathvalidate.sanitize_filename(style, replacement_text='_')
fn
= os.path.join(output_folder, f'{model_name}__{style_name}.jpg') if os.path.exists(fn): continue + t_load0 = time.time() request(f'/sdapi/v1/checkpoint?sd_model_checkpoint={model}', method='POST') loaded = request('/sdapi/v1/checkpoint', method='GET') + t_load1 = time.time() if not loaded or not (model in loaded.get('checkpoint') or model in loaded.get('title') or model in loaded.get('name')): log.error(f' model: error="{model}"') continue - t0 = time.time() + t_style0 = time.time() params = { 'styles': [style] } for k, v in args.items(): params[k] = v log.info(f' style: n={s+1}/{len(styles)} name="{style}" args={params} fn="{fn}"') data = request('/sdapi/v1/txt2img', params) - t1 = time.time() + t_style1 = time.time() if 'images' in data and len(data['images']) > 0: + idx_style += 1 + idx_images += 1 b64 = data['images'][0].split(',',1)[0] image = Image.open(io.BytesIO(base64.b64decode(b64))) info = data['info'] - log.info(f' image: size={image.width}x{image.height} time={t1-t0:.2f} info={len(info)}') + log.info(f' image: size={image.width}x{image.height} time={t_style1-t_style0:.2f} info={len(info)}') image.save(fn) - write_history(model=model, style=style, image=fn, size=image.size, duration=round(t1-t0, 3), info=info) + write_history(model=model, style=style, image=fn, size=image.size, generate=round(t_style1-t_style0, 3), load=round(t_load1-t_load0, 3), info=info) else: - # write_history(model=model, style=style, duration=round(t1-t0, 3), info='no image') log.error(f' model: error="{model}" style="{style}" no image') except Exception as e: if 'Connection refused' in str(e) or 'RemoteDisconnected' in str(e): log.error('server offline') os._exit(1) - # write_history(model=model, style=style, duration=round(t1-t0, 3), info=str(e)) log.error(f' model: error="{model}" style="{style}" exception="{e}"') + t_model1 = time.time() + if idx_style > 0: + log.info(f'model: name="{model}" images={idx_style} time={t_model1-t_model0:.2f}') + t_generate1 = time.time() + if idx_images > 0: + log.info(f'generate: models={idx_model} images={idx_images} time={t_generate1-t_generate0:.2f}') if __name__ == "__main__": log.info('test-all-models') - log.info(f'output="{output_folder}" models={len(models)} styles={len(styles)}') + log.info(f'output="{output_folder}"') read_history() - generate() - log.info('done...') + main() diff --git a/html/reference.json b/html/reference.json index a8c3447bb..5fd79f36a 100644 --- a/html/reference.json +++ b/html/reference.json @@ -429,6 +429,12 @@ "preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg", "extras": "sampler: Default, cfg_scale: 2.0" }, + "Tencent HunyuanDiT 1.1": { + "path": "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", + "desc": "Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding.", + "preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg", + "extras": "sampler: Default, cfg_scale: 2.0" + }, "AlphaVLLM Lumina Next SFT": { "path": "Alpha-VLLM/Lumina-Next-SFT-diffusers", diff --git a/modules/model_quant.py b/modules/model_quant.py index a2cfbecbf..00f42e016 100644 --- a/modules/model_quant.py +++ b/modules/model_quant.py @@ -642,11 +642,15 @@ def get_dit_args(load_config:dict={}, module:str=None, device_map:bool=False, al def do_post_load_quant(sd_model, allow=True): from modules import shared if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')): + shared.log.debug('Load model: post_quant=sdnq') sd_model = 
sdnq_quantize_weights(sd_model)
if len(shared.opts.optimum_quanto_weights) > 0:
+ shared.log.debug('Load model: post_quant=quanto')
sd_model = optimum_quanto_weights(sd_model)
if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
+ shared.log.debug('Load model: post_quant=torchao')
sd_model = torchao_quantization(sd_model)
if shared.opts.layerwise_quantization:
+ shared.log.debug('Load model: post_quant=layerwise')
apply_layerwise(sd_model)
return sd_model
diff --git a/modules/sd_models.py b/modules/sd_models.py index ec11f2c1e..817367b3a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -29,6 +29,7 @@
debug_load = os.environ.get('SD_LOAD_DEBUG', None)
debug_process = log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
diffusers_version = int(diffusers.__version__.split('.')[1])
checkpoint_tiles = checkpoint_titles # legacy compatibility
+allow_post_quant = None
pipe_switch_task_exclude = [
'AnimateDiffPipeline', 'AnimateDiffSDXLPipeline', 'FluxControlPipeline',
@@ -275,7 +276,7 @@ def load_diffuser_initial(diffusers_load_config, op='model'):
def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='model'):
sd_model = None
- allow_post_quant = True
+ global allow_post_quant # pylint: disable=global-statement
unload_model_weights(op=op)
shared.sd_model = None
try:
@@ -316,7 +317,8 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
allow_post_quant = True
elif model_type in ['FLUX']:
from pipelines.model_flux import load_flux
- sd_model, allow_post_quant = load_flux(checkpoint_info, diffusers_load_config)
+ sd_model = load_flux(checkpoint_info, diffusers_load_config)
+ allow_post_quant = False
elif model_type in ['FLEX']:
from pipelines.model_flex import load_flex
sd_model = load_flex(checkpoint_info, diffusers_load_config)
@@ -398,7 +400,7 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
if debug_load:
errors.display(e, 'Load')
- return None, True
+ return None
- return sd_model, allow_post_quant
+ return sd_model
def load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_config, op='model'):
@@ -561,6 +563,8 @@ def set_defaults(sd_model, checkpoint_info):
def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: disable=unused-argument
+ global allow_post_quant # pylint: disable=global-statement
+ allow_post_quant = True # assume default
logging.getLogger("diffusers").setLevel(logging.ERROR)
timer.load.record("diffusers")
diffusers_load_config = {
@@ -589,7 +593,6 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
return
sd_model = None
- allow_post_quant = True
try:
# initial load only
if sd_model is None:
@@ -621,7 +624,7 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
# load with custom loader
if sd_model is None:
- sd_model, allow_post_quant = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
+ sd_model = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
if sd_model is not None and not sd_model:
shared.log.error(f'Load {op}: type="{model_type}" pipeline="{pipeline}" not loaded')
return
diff --git a/pipelines/flux/flux_bnb.py b/pipelines/flux/flux_bnb.py new file mode 100644 index 000000000..777678af1 --- /dev/null +++ b/pipelines/flux/flux_bnb.py @@ -0,0 +1,25 @@
+import diffusers
+import
transformers +from modules import devices, model_quant + + +def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument + transformer = None + if isinstance(checkpoint_info, str): + repo_path = checkpoint_info + else: + repo_path = checkpoint_info.path + model_quant.load_bnb('Load model: type=FLUX') + quant = model_quant.get_quant(repo_path) + if quant == 'fp8': + quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype) + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + elif quant == 'fp4': + quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4') + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + elif quant == 'nf4': + quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4') + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + else: + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config) + return transformer diff --git a/pipelines/flux/flux_legacy_loader.py b/pipelines/flux/flux_legacy_loader.py new file mode 100644 index 000000000..6b6f9d294 --- /dev/null +++ b/pipelines/flux/flux_legacy_loader.py @@ -0,0 +1,360 @@ +import os +import json +import torch +import diffusers +import transformers +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download +from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te + + +debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None + + +def load_flux_quanto(checkpoint_info): + transformer, text_encoder_2 = None, None + quanto = model_quant.load_quanto('Load model: type=FLUX') + + if isinstance(checkpoint_info, str): + repo_path = checkpoint_info + else: + repo_path = checkpoint_info.path + + try: + quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json") + debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"') + if not os.path.exists(quantization_map): + repo_id = sd_models.path_to_repo(checkpoint_info) + quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) + with open(quantization_map, "r", encoding='utf8') as f: + quantization_map = json.load(f) + state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors")) + dtype = state_dict['context_embedder.bias'].dtype + with torch.device("meta"): + transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype) + quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu")) + transformer_dtype = transformer.dtype + if transformer_dtype != devices.dtype: + try: + transformer = transformer.to(dtype=devices.dtype) + except Exception: + shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}") + except Exception as e: + shared.log.error(f"Load model: 
type=FLUX failed to load Quanto transformer: {e}") + if debug: + errors.display(e, 'FLUX Quanto:') + + try: + quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json") + debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"') + if not os.path.exists(quantization_map): + repo_id = sd_models.path_to_repo(checkpoint_info) + quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) + with open(quantization_map, "r", encoding='utf8') as f: + quantization_map = json.load(f) + with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f: + t5_config = transformers.T5Config(**json.load(f)) + state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors")) + dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype + with torch.device("meta"): + text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype) + quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu")) + text_encoder_2_dtype = text_encoder_2.dtype + if text_encoder_2_dtype != devices.dtype: + try: + text_encoder_2 = text_encoder_2.to(dtype=devices.dtype) + except Exception: + shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}") + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}") + if debug: + errors.display(e, 'FLUX Quanto:') + + return transformer, text_encoder_2 + + +def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument + transformer, text_encoder_2 = None, None + if isinstance(checkpoint_info, str): + repo_path = checkpoint_info + else: + repo_path = checkpoint_info.path + model_quant.load_bnb('Load model: type=FLUX') + quant = model_quant.get_quant(repo_path) + try: + if quant == 'fp8': + quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype) + debug(f'Quantization: {quantization_config}') + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + elif quant == 'fp4': + quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4') + debug(f'Quantization: {quantization_config}') + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + elif quant == 'nf4': + quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4') + debug(f'Quantization: {quantization_config}') + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config) + else: + transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config) + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}") + transformer, text_encoder_2 = None, None + if debug: + errors.display(e, 'FLUX:') + return transformer, text_encoder_2 + + +def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument + try: + diffusers_load_config = { + 
"torch_dtype": devices.dtype, + "cache_dir": cache_dir, + } + if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'): + import nunchaku + nunchaku_precision = nunchaku.utils.get_precision() + nunchaku_repo = None + if 'flux.1-kontext' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors" + elif 'flux.1-dev' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors" + elif 'flux.1-schnell' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'flux.1-fill' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'flux.1-depth' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'shuttle' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors" + else: + shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported') + if nunchaku_repo is not None: + shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}') + kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype) + kwargs['transformer'].quantization_method = 'SVDQuant' + if shared.opts.nunchaku_attention: + kwargs['transformer'].set_attention_impl("nunchaku-fp16") + if 'transformer' not in kwargs and model_quant.check_quant('Model'): + load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True) + kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args) + if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'): + import nunchaku + nunchaku_precision = nunchaku.utils.get_precision() + nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors' + shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}') + kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype) + kwargs['text_encoder_2'].quantization_method = 'SVDQuant' + if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'): + load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True) + kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args) + except Exception as e: + shared.log.error(f'Quantization: {e}') + errors.display(e, 'Quantization:') + return kwargs + + +def load_transformer(file_path): # triggered by opts.sd_unet change + if file_path is None or not os.path.exists(file_path): + return None + transformer = None + quant = model_quant.get_quant(file_path) + diffusers_load_config = { + "torch_dtype": devices.dtype, + "cache_dir": shared.opts.hfcache_dir, + } + if quant is not None and quant != 'none': + shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" 
offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}') + if 'gguf' in file_path.lower(): + from modules import ggml + _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype) + if _transformer is not None: + transformer = _transformer + elif quant == "fp8": + _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config) + if _transformer is not None: + transformer = _transformer + elif quant in {'qint8', 'qint4'}: + _transformer, _text_encoder_2 = load_flux_quanto(file_path) + if _transformer is not None: + transformer = _transformer + elif quant in {'fp8', 'fp4', 'nf4'}: + _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config) + if _transformer is not None: + transformer = _transformer + elif 'nf4' in quant: + from pipelines.flux.flux_nf4 import load_flux_nf4 + _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True) + if _transformer is not None: + transformer = _transformer + else: + quant_args = model_quant.create_bnb_config({}) + if quant_args: + shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}') + from pipelines.flux.flux_nf4 import load_flux_nf4 + transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False) + if transformer is not None: + return transformer + load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True) + shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}') + transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args) + if transformer is None: + shared.log.error('Failed to load UNet model') + shared.opts.sd_unet = 'Default' + return transformer + + +def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change + repo_id = sd_models.path_to_repo(checkpoint_info) + sd_models.hf_auth_check(checkpoint_info) + allow_post_quant = False + + prequantized = model_quant.get_quant(checkpoint_info.path) + shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}') + debug(f'Load model: type=FLUX config={diffusers_load_config}') + + transformer = None + text_encoder_1 = None + text_encoder_2 = None + vae = None + + # unload current model + sd_models.unload_model_weights() + shared.sd_model = None + devices.torch_gc(force=True, reason='load') + + if shared.opts.teacache_enabled: + from modules import teacache + shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}') + diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded + + # load overrides if any + if shared.opts.sd_unet != 'Default': + try: + debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"') + transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet]) + if transformer is None: + shared.opts.sd_unet = 'Default' + sd_unet.failed_unet.append(shared.opts.sd_unet) + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}") + 
shared.opts.sd_unet = 'Default' + if debug: + errors.display(e, 'FLUX UNet:') + if shared.opts.sd_text_encoder != 'Default': + try: + debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"') + from modules.model_te import load_t5, load_vit_l + if 'vit-l' in shared.opts.sd_text_encoder.lower(): + text_encoder_1 = load_vit_l() + else: + text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir) + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load T5: {e}") + shared.opts.sd_text_encoder = 'Default' + if debug: + errors.display(e, 'FLUX T5:') + if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic': + try: + debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"') + from modules import sd_vae + # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override') + vae_file = sd_vae.vae_dict[shared.opts.sd_vae] + if os.path.exists(vae_file): + vae_config = os.path.join('configs', 'flux', 'vae', 'config.json') + vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config) + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}") + shared.opts.sd_vae = 'Default' + if debug: + errors.display(e, 'FLUX VAE:') + + # load quantized components if any + if prequantized == 'nf4': + try: + from pipelines.flux.flux_nf4 import load_flux_nf4 + _transformer, _text_encoder = load_flux_nf4(checkpoint_info) + if _transformer is not None: + transformer = _transformer + if _text_encoder is not None: + text_encoder_2 = _text_encoder + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}") + if debug: + errors.display(e, 'FLUX NF4:') + if prequantized == 'qint8' or prequantized == 'qint4': + try: + _transformer, _text_encoder = load_flux_quanto(checkpoint_info) + if _transformer is not None: + transformer = _transformer + if _text_encoder is not None: + text_encoder_2 = _text_encoder + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}") + if debug: + errors.display(e, 'FLUX Quanto:') + + # initialize pipeline with pre-loaded components + kwargs = {} + if transformer is not None: + kwargs['transformer'] = transformer + sd_unet.loaded_unet = shared.opts.sd_unet + if text_encoder_1 is not None: + kwargs['text_encoder'] = text_encoder_1 + model_te.loaded_te = shared.opts.sd_text_encoder + if text_encoder_2 is not None: + kwargs['text_encoder_2'] = text_encoder_2 + model_te.loaded_te = shared.opts.sd_text_encoder + if vae is not None: + kwargs['vae'] = vae + if repo_id == 'sayakpaul/flux.1-dev-nf4': + repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json + if 'Fill' in repo_id: + cls = diffusers.FluxFillPipeline + elif 'Canny' in repo_id: + cls = diffusers.FluxControlPipeline + elif 'Depth' in repo_id: + cls = diffusers.FluxControlPipeline + elif 'Kontext' in repo_id: + cls = diffusers.FluxKontextPipeline + from diffusers import pipelines + pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline + pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline + pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline + + else: + cls = diffusers.FluxPipeline + shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} 
preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}') + for c in kwargs: + if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None: + shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}') + if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32: + try: + kwargs[c] = kwargs[c].to(dtype=devices.dtype) + shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast') + except Exception: + pass + + allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none') + fn = checkpoint_info.path + if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)): + kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant) + if fn.endswith('.safetensors') and os.path.isfile(fn): + pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) + allow_post_quant = True + else: + pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) + + if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'): + from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe + apply_cache_on_pipe(pipe, residual_diff_threshold=0.12) + + # release memory + transformer = None + text_encoder_1 = None + text_encoder_2 = None + vae = None + for k in kwargs.keys(): + kwargs[k] = None + sd_hijack_te.init_hijack(pipe) + devices.torch_gc(force=True, reason='load') + return pipe, allow_post_quant diff --git a/pipelines/model_flux_nf4.py b/pipelines/flux/flux_nf4.py similarity index 100% rename from pipelines/model_flux_nf4.py rename to pipelines/flux/flux_nf4.py diff --git a/pipelines/flux/flux_nunchaku.py b/pipelines/flux/flux_nunchaku.py new file mode 100644 index 000000000..e21b93a3b --- /dev/null +++ b/pipelines/flux/flux_nunchaku.py @@ -0,0 +1,29 @@ +from modules import shared, devices + + +def load_flux_nunchaku(repo_id): + import nunchaku + nunchaku_precision = nunchaku.utils.get_precision() + nunchaku_repo = None + transformer = None + if 'flux.1-kontext' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors" + elif 'flux.1-dev' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors" + elif 'flux.1-schnell' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'flux.1-fill' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'flux.1-depth' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + elif 'shuttle' in repo_id.lower(): + nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors" + else: + shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported') + if nunchaku_repo is not None: + shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} 
attention={shared.opts.nunchaku_attention}') + transformer = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype) + transformer.quantization_method = 'SVDQuant' + if shared.opts.nunchaku_attention: + transformer.set_attention_impl("nunchaku-fp16") + return transformer diff --git a/pipelines/flux/flux_quanto.py b/pipelines/flux/flux_quanto.py new file mode 100644 index 000000000..11e604b62 --- /dev/null +++ b/pipelines/flux/flux_quanto.py @@ -0,0 +1,73 @@ +import os +import json +import torch +import diffusers +import transformers +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download +from modules import shared, errors, devices, sd_models, model_quant + + +debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None + + +def load_flux_quanto(checkpoint_info): + transformer, text_encoder_2 = None, None + quanto = model_quant.load_quanto('Load model: type=FLUX') + + if isinstance(checkpoint_info, str): + repo_path = checkpoint_info + else: + repo_path = checkpoint_info.path + + try: + quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json") + debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"') + if not os.path.exists(quantization_map): + repo_id = sd_models.path_to_repo(checkpoint_info) + quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) + with open(quantization_map, "r", encoding='utf8') as f: + quantization_map = json.load(f) + state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors")) + dtype = state_dict['context_embedder.bias'].dtype + with torch.device("meta"): + transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype) + quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu")) + transformer_dtype = transformer.dtype + if transformer_dtype != devices.dtype: + try: + transformer = transformer.to(dtype=devices.dtype) + except Exception: + shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}") + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}") + if debug: + errors.display(e, 'FLUX Quanto:') + + try: + quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json") + debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"') + if not os.path.exists(quantization_map): + repo_id = sd_models.path_to_repo(checkpoint_info) + quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) + with open(quantization_map, "r", encoding='utf8') as f: + quantization_map = json.load(f) + with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f: + t5_config = transformers.T5Config(**json.load(f)) + state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors")) + dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype + with torch.device("meta"): + text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype) + 
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu")) + text_encoder_2_dtype = text_encoder_2.dtype + if text_encoder_2_dtype != devices.dtype: + try: + text_encoder_2 = text_encoder_2.to(dtype=devices.dtype) + except Exception: + shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}") + except Exception as e: + shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}") + if debug: + errors.display(e, 'FLUX Quanto:') + + return transformer, text_encoder_2 diff --git a/pipelines/generic.py b/pipelines/generic.py index 102b5e6b8..d4065cb50 100644 --- a/pipelines/generic.py +++ b/pipelines/generic.py @@ -2,120 +2,145 @@ import os import json import diffusers import transformers -from modules import shared, devices, sd_models, model_quant +from modules import shared, devices, errors, sd_models, model_quant -debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = os.environ.get('SD_LOAD_DEBUG', None) is not None def load_transformer(repo_id, cls_name, load_config={}, subfolder="transformer", allow_quant=True, variant=None, dtype=None): - load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant) - quant_type = model_quant.get_quant_type(quant_args) - dtype = dtype or devices.dtype + transformer = None + try: + load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant) + quant_type = model_quant.get_quant_type(quant_args) + dtype = dtype or devices.dtype - local_file = None - if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default': - from modules import sd_unet - if shared.opts.sd_unet not in list(sd_unet.unet_dict): - shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found') - elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]): - local_file = sd_unet.unet_dict[shared.opts.sd_unet] + local_file = None + if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default': + from modules import sd_unet + if shared.opts.sd_unet not in list(sd_unet.unet_dict): + shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found') + elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]): + local_file = sd_unet.unet_dict[shared.opts.sd_unet] - if local_file is not None and local_file.lower().endswith('.gguf'): - shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') - from modules import ggml - ggml.install_gguf() - loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained - transformer = loader( - local_file, - quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype), - cache_dir=shared.opts.hfcache_dir, - **load_args, - ) - transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None) - elif local_file is not None and local_file.lower().endswith('.safetensors'): - shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') - loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained - transformer = loader( - local_file, - cache_dir=shared.opts.hfcache_dir, - **load_args, - ) - transformer = model_quant.do_post_load_quant(transformer, 
allow=quant_type is not None) - else: - shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') - if dtype is not None: - load_args['torch_dtype'] = dtype - if subfolder is not None: - load_args['subfolder'] = subfolder - if variant is not None: - load_args['variant'] = variant - transformer = cls_name.from_pretrained( - repo_id, - cache_dir=shared.opts.hfcache_dir, - **load_args, - **quant_args, - ) - if shared.opts.diffusers_offload_mode != 'none' and transformer is not None: - sd_models.move_model(transformer, devices.cpu) + if local_file is not None and local_file.lower().endswith('.gguf'): + shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') + from modules import ggml + ggml.install_gguf() + loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained + transformer = loader( + local_file, + quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype), + cache_dir=shared.opts.hfcache_dir, + **load_args, + ) + transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None) + elif local_file is not None and local_file.lower().endswith('.safetensors'): + shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') + loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained + transformer = loader( + local_file, + cache_dir=shared.opts.hfcache_dir, + **load_args, + ) + transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None) + else: + shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}') + if dtype is not None: + load_args['torch_dtype'] = dtype + if subfolder is not None: + load_args['subfolder'] = subfolder + if variant is not None: + load_args['variant'] = variant + transformer = cls_name.from_pretrained( + repo_id, + cache_dir=shared.opts.hfcache_dir, + **load_args, + **quant_args, + ) + sd_models.allow_post_quant = False # we already handled it + if shared.opts.diffusers_offload_mode != 'none' and transformer is not None: + sd_models.move_model(transformer, devices.cpu) + except Exception as e: + shared.log.error(f'Load model: type=transformer {e}') + if debug: + errors.display(e, 'Load:') + raise return transformer def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder", allow_quant=True, allow_shared=True, variant=None, dtype=None): - load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant) - quant_type = model_quant.get_quant_type(quant_args) text_encoder = None - dtype = dtype or devices.dtype + try: + load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant) + quant_type = model_quant.get_quant_type(quant_args) + dtype = dtype or devices.dtype - # load from local file if specified - local_file = None - if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default': - from modules import model_te - if shared.opts.sd_text_encoder not in list(model_te.te_dict): - shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found') - elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]): - local_file = model_te.te_dict[shared.opts.sd_text_encoder] + # load from local file 
if specified + local_file = None + if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default': + from modules import model_te + if shared.opts.sd_text_encoder not in list(model_te.te_dict): + shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found') + elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]): + local_file = model_te.te_dict[shared.opts.sd_text_encoder] - # load from local file gguf - if local_file is not None and local_file.lower().endswith('.gguf'): - shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"') - from modules import ggml - ggml.install_gguf() - text_encoder = cls_name.from_pretrained( - gguf_file=local_file, - quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype), - cache_dir=shared.opts.hfcache_dir, - **load_args, - ) - text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None) - # load from local file safetensors - elif local_file is not None and local_file.lower().endswith('.safetensors'): - shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"') - text_encoder = cls_name.from_pretrained( - local_file, - cache_dir=shared.opts.hfcache_dir, - **load_args, - ) - text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None) - # use shared t5 if possible - elif cls_name == transformers.T5EncoderModel and allow_shared: - with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f: - load_args['config'] = transformers.T5Config(**json.load(f)) - if model_quant.check_nunchaku('TE'): - import nunchaku - repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors' - cls_name = nunchaku.NunchakuT5EncoderModel - shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"') - text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained( - repo_id, - torch_dtype=dtype, + # load from local file gguf + if local_file is not None and local_file.lower().endswith('.gguf'): + shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"') + from modules import ggml + ggml.install_gguf() + text_encoder = cls_name.from_pretrained( + gguf_file=local_file, + quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype), + cache_dir=shared.opts.hfcache_dir, + **load_args, ) - text_encoder.quantization_method = 'SVDQuant' - elif shared.opts.te_shared_t5: - repo_id = 'Disty0/t5-xxl' + text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None) + # load from local file safetensors + elif local_file is not None and local_file.lower().endswith('.safetensors'): + shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"') + text_encoder = cls_name.from_pretrained( + local_file, + cache_dir=shared.opts.hfcache_dir, + **load_args, + ) + text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None) + # use shared t5 if possible + elif cls_name == transformers.T5EncoderModel and allow_shared: + with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f: + load_args['config'] = transformers.T5Config(**json.load(f)) + if model_quant.check_nunchaku('TE'): + import nunchaku + repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors' + cls_name = 
nunchaku.NunchakuT5EncoderModel + shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"') + text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained( + repo_id, + torch_dtype=dtype, + ) + text_encoder.quantization_method = 'SVDQuant' + elif shared.opts.te_shared_t5: + repo_id = 'Disty0/t5-xxl' + shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}') + if dtype is not None: + load_args['torch_dtype'] = dtype + text_encoder = cls_name.from_pretrained( + repo_id, + cache_dir=shared.opts.hfcache_dir, + **load_args, + **quant_args, + ) + + # load from repo + if text_encoder is None: shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}') if dtype is not None: load_args['torch_dtype'] = dtype + if subfolder is not None: + load_args['subfolder'] = subfolder + if variant is not None: + load_args['variant'] = variant text_encoder = cls_name.from_pretrained( repo_id, cache_dir=shared.opts.hfcache_dir, @@ -123,22 +148,12 @@ def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder **quant_args, ) - # load from repo - if text_encoder is None: - shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}') - if dtype is not None: - load_args['torch_dtype'] = dtype - if subfolder is not None: - load_args['subfolder'] = subfolder - if variant is not None: - load_args['variant'] = variant - text_encoder = cls_name.from_pretrained( - repo_id, - cache_dir=shared.opts.hfcache_dir, - **load_args, - **quant_args, - ) - - if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None: - sd_models.move_model(text_encoder, devices.cpu) + sd_models.allow_post_quant = False # we already handled it + if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None: + sd_models.move_model(text_encoder, devices.cpu) + except Exception as e: + shared.log.error(f'Load model: type=te {e}') + if debug: + errors.display(e, 'Load:') + raise return text_encoder diff --git a/pipelines/meissonic/test.py b/pipelines/meissonic/test.py index 5687cbff0..777f40e22 100644 --- a/pipelines/meissonic/test.py +++ b/pipelines/meissonic/test.py @@ -4,7 +4,7 @@ sys.path.append("./") # import torch # from torchvision import transforms from meissonic.transformer import Transformer2DModel as TransformerMeissonic -from meissonic.pipeline import Pipeline as PipelineMeissonic +from meissonic.pipeline import MeissonicPipeline from meissonic.scheduler import Scheduler as MeissonicScheduler from transformers import CLIPTextModelWithProjection, CLIPTokenizer from diffusers import VQModel @@ -21,7 +21,7 @@ vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", cache_dir=cach text_encoder = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir) tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer") scheduler = MeissonicScheduler.from_pretrained(model_path, subfolder="scheduler") -pipe = PipelineMeissonic(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler) +pipe = MeissonicPipeline(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler) pipe = pipe.to(device) steps = 64 diff --git a/pipelines/model_chroma.py b/pipelines/model_chroma.py index 
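Reviewer note: the `load_text_encoder` refactor above establishes a fixed precedence when resolving the encoder: a user-selected local file (GGUF first, then safetensors) wins over the shared T5 (or its Nunchaku variant), which in turn wins over a plain repo load, and `sd_models.allow_post_quant` is cleared once quantization has been handled. A minimal sketch of that precedence; the helper callables here are hypothetical stand-ins for the SD.Next internals, not its real API:

```python
from typing import Callable, Optional

def text_encoder_precedence(
    repo_id: str,
    local_file: Optional[str],                      # opts.sd_text_encoder override, if any
    load_gguf: Callable[[str], object],             # hypothetical stand-ins for the loaders
    load_safetensors: Callable[[str], object],      # used in pipelines/generic.py
    load_shared_t5: Callable[[], Optional[object]],
    load_from_repo: Callable[[str], object],
) -> object:
    if local_file and local_file.lower().endswith('.gguf'):
        return load_gguf(local_file)                # pre-quantized GGUF wins first
    if local_file and local_file.lower().endswith('.safetensors'):
        return load_safetensors(local_file)         # then a plain local safetensors file
    shared_te = load_shared_t5()                    # then the shared T5 / Nunchaku T5
    if shared_te is not None:
        return shared_te
    return load_from_repo(repo_id)                  # finally the checkpoint's own repo
```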
b0a28fa69..eb94bfed4 100644 --- a/pipelines/model_chroma.py +++ b/pipelines/model_chroma.py @@ -1,13 +1,9 @@ -import os import diffusers import transformers from modules import shared, devices, sd_models, model_quant from pipelines import generic -debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None - - def load_chroma(checkpoint_info, diffusers_load_config={}): repo_id = sd_models.path_to_repo(checkpoint_info) sd_models.hf_auth_check(checkpoint_info) diff --git a/pipelines/model_flux.py b/pipelines/model_flux.py index da4bf70e9..d420ff80e 100644 --- a/pipelines/model_flux.py +++ b/pipelines/model_flux.py @@ -1,360 +1,76 @@ -import os -import json -import torch import diffusers import transformers -from safetensors.torch import load_file -from huggingface_hub import hf_hub_download -from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te +from modules import shared, devices, sd_models, model_quant +from pipelines import generic -debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None - - -def load_flux_quanto(checkpoint_info): - transformer, text_encoder_2 = None, None - quanto = model_quant.load_quanto('Load model: type=FLUX') - - if isinstance(checkpoint_info, str): - repo_path = checkpoint_info - else: - repo_path = checkpoint_info.path - - try: - quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json") - debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"') - if not os.path.exists(quantization_map): - repo_id = sd_models.path_to_repo(checkpoint_info) - quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) - with open(quantization_map, "r", encoding='utf8') as f: - quantization_map = json.load(f) - state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors")) - dtype = state_dict['context_embedder.bias'].dtype - with torch.device("meta"): - transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype) - quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu")) - transformer_dtype = transformer.dtype - if transformer_dtype != devices.dtype: - try: - transformer = transformer.to(dtype=devices.dtype) - except Exception: - shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}") - except Exception as e: - shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}") - if debug: - errors.display(e, 'FLUX Quanto:') - - try: - quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json") - debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"') - if not os.path.exists(quantization_map): - repo_id = sd_models.path_to_repo(checkpoint_info) - quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir) - with open(quantization_map, "r", encoding='utf8') as f: - quantization_map = json.load(f) - with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f: - t5_config = transformers.T5Config(**json.load(f)) - state_dict = 
load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
-            dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
-            with torch.device("meta"):
-                text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
-            quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
-            text_encoder_2_dtype = text_encoder_2.dtype
-            if text_encoder_2_dtype != devices.dtype:
-                try:
-                    text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
-                except Exception:
-                    shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
-            if debug:
-                errors.display(e, 'FLUX Quanto:')
-
-    return transformer, text_encoder_2
-
-
-def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
-    transformer, text_encoder_2 = None, None
-    if isinstance(checkpoint_info, str):
-        repo_path = checkpoint_info
-    else:
-        repo_path = checkpoint_info.path
-    model_quant.load_bnb('Load model: type=FLUX')
-    quant = model_quant.get_quant(repo_path)
-    try:
-        if quant == 'fp8':
-            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
-            debug(f'Quantization: {quantization_config}')
-            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
-        elif quant == 'fp4':
-            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
-            debug(f'Quantization: {quantization_config}')
-            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
-        elif quant == 'nf4':
-            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
-            debug(f'Quantization: {quantization_config}')
-            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
-        else:
-            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
-    except Exception as e:
-        shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
-        transformer, text_encoder_2 = None, None
-        if debug:
-            errors.display(e, 'FLUX:')
-    return transformer, text_encoder_2
-
-
-def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
-    try:
-        diffusers_load_config = {
-            "torch_dtype": devices.dtype,
-            "cache_dir": cache_dir,
-        }
-        if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
-            import nunchaku
-            nunchaku_precision = nunchaku.utils.get_precision()
-            nunchaku_repo = None
-            if 'flux.1-kontext' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
-            elif 'flux.1-dev' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
-            elif 'flux.1-schnell' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
-            elif 'flux.1-fill' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
-            elif 'flux.1-depth' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
-            elif 'shuttle' in repo_id.lower():
-                nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
-            else:
-                shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
-            if nunchaku_repo is not None:
-                shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
-                kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
-                kwargs['transformer'].quantization_method = 'SVDQuant'
-                if shared.opts.nunchaku_attention:
-                    kwargs['transformer'].set_attention_impl("nunchaku-fp16")
-        if 'transformer' not in kwargs and model_quant.check_quant('Model'):
-            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
-            kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
-        if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
-            import nunchaku
-            nunchaku_precision = nunchaku.utils.get_precision()
-            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
-            shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
-            kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
-            kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
-        if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
-            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
-            kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
-    except Exception as e:
-        shared.log.error(f'Quantization: {e}')
-        errors.display(e, 'Quantization:')
-    return kwargs
-
-
-def load_transformer(file_path): # triggered by opts.sd_unet change
-    if file_path is None or not os.path.exists(file_path):
-        return None
-    transformer = None
-    quant = model_quant.get_quant(file_path)
-    diffusers_load_config = {
-        "torch_dtype": devices.dtype,
-        "cache_dir": shared.opts.hfcache_dir,
-    }
-    if quant is not None and quant != 'none':
-        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
-    if 'gguf' in file_path.lower():
-        from modules import ggml
-        _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
-        if _transformer is not None:
-            transformer = _transformer
-    elif quant == "fp8":
-        _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
-        if _transformer is not None:
-            transformer = _transformer
-    elif quant in {'qint8', 'qint4'}:
-        _transformer, _text_encoder_2 = load_flux_quanto(file_path)
-        if _transformer is not None:
-            transformer = _transformer
-    elif quant in {'fp8', 'fp4', 'nf4'}:
-        _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
-        if _transformer is not None:
-            transformer = _transformer
-    elif 'nf4' in quant:
-        from pipelines.model_flux_nf4 import load_flux_nf4
-        _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
-        if _transformer is not None:
-            transformer = _transformer
-    else:
-        quant_args = model_quant.create_bnb_config({})
-        if quant_args:
-            shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
-            from pipelines.model_flux_nf4 import load_flux_nf4
-            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
-            if transformer is not None:
-                return transformer
-        load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
-        shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
-        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
-    if transformer is None:
-        shared.log.error('Failed to load UNet model')
-        shared.opts.sd_unet = 'Default'
-    return transformer
-
-
-def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
+def load_flux(checkpoint_info, diffusers_load_config={}):
     repo_id = sd_models.path_to_repo(checkpoint_info)
     sd_models.hf_auth_check(checkpoint_info)
-    allow_post_quant = False
-    prequantized = model_quant.get_quant(checkpoint_info.path)
-    shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
-    debug(f'Load model: type=FLUX config={diffusers_load_config}')
+    if 'Fill' in repo_id:
+        cls_name = diffusers.FluxFillPipeline
+    elif 'Canny' in repo_id:
+        cls_name = diffusers.FluxControlPipeline
+    elif 'Depth' in repo_id:
+        cls_name = diffusers.FluxControlPipeline
+    elif 'Kontext' in repo_id:
+        cls_name = diffusers.FluxKontextPipeline
+        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
+        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
+        diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
+    else:
+        cls_name = diffusers.FluxPipeline
 
-    transformer = None
-    text_encoder_1 = None
-    text_encoder_2 = None
-    vae = None
+    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
+    shared.log.debug(f'Load model: type=Flux repo="{repo_id}" cls={cls_name.__name__} config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
 
-    # unload current model
-    sd_models.unload_model_weights()
-    shared.sd_model = None
-    devices.torch_gc(force=True, reason='load')
-
-    if shared.opts.teacache_enabled:
+    # optional teacache patch
+    if shared.opts.teacache_enabled and not model_quant.check_nunchaku('Model'):
         from modules import teacache
         shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
         diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
 
-    # load overrides if any
-    if shared.opts.sd_unet != 'Default':
-        try:
-            debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
-            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
-            if transformer is None:
-                shared.opts.sd_unet = 'Default'
-                sd_unet.failed_unet.append(shared.opts.sd_unet)
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
-            shared.opts.sd_unet = 'Default'
-            if debug:
-                errors.display(e, 'FLUX UNet:')
-    if shared.opts.sd_text_encoder != 'Default':
-        try:
-            debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
-            from modules.model_te import load_t5, load_vit_l
-            if 'vit-l' in shared.opts.sd_text_encoder.lower():
-                text_encoder_1 = load_vit_l()
-            else:
-                text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
-            shared.opts.sd_text_encoder = 'Default'
-            if debug:
-                errors.display(e, 'FLUX T5:')
-    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
-        try:
-            debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
-            from modules import sd_vae
-            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
-            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
-            if os.path.exists(vae_file):
-                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
-                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
-            shared.opts.sd_vae = 'Default'
-            if debug:
-                errors.display(e, 'FLUX VAE:')
+    transformer = None
+    text_encoder_2 = None
 
-    # load quantized components if any
-    if prequantized == 'nf4':
-        try:
-            from pipelines.model_flux_nf4 import load_flux_nf4
-            _transformer, _text_encoder = load_flux_nf4(checkpoint_info)
-            if _transformer is not None:
-                transformer = _transformer
-            if _text_encoder is not None:
-                text_encoder_2 = _text_encoder
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
-            if debug:
-                errors.display(e, 'FLUX NF4:')
-    if prequantized == 'qint8' or prequantized == 'qint4':
-        try:
-            _transformer, _text_encoder = load_flux_quanto(checkpoint_info)
-            if _transformer is not None:
-                transformer = _transformer
-            if _text_encoder is not None:
-                text_encoder_2 = _text_encoder
-        except Exception as e:
-            shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
-            if debug:
-                errors.display(e, 'FLUX Quanto:')
+    # handle transformer svdquant if available, t5 is handled inside load_text_encoder
+    prequantized = model_quant.get_quant(checkpoint_info.path)
+    if model_quant.check_nunchaku('Model'):
+        from pipelines.flux.flux_nunchaku import load_flux_nunchaku
+        transformer = load_flux_nunchaku(repo_id)
+    # handle prequantized models
+    elif prequantized == 'nf4':
+        from pipelines.flux.flux_nf4 import load_flux_nf4
+        transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
+    elif prequantized == 'qint8' or prequantized == 'qint4':
+        from pipelines.flux.flux_quanto import load_flux_quanto
+        transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
+    elif prequantized == 'fp4' or prequantized == 'fp8':
+        from pipelines.flux.flux_bnb import load_flux_bnb
+        transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)
 
-    # initialize pipeline with pre-loaded components
-    kwargs = {}
-    if transformer is not None:
-        kwargs['transformer'] = transformer
-        sd_unet.loaded_unet = shared.opts.sd_unet
-    if text_encoder_1 is not None:
-        kwargs['text_encoder'] = text_encoder_1
-        model_te.loaded_te = shared.opts.sd_text_encoder
-    if text_encoder_2 is not None:
-        kwargs['text_encoder_2'] = text_encoder_2
-        model_te.loaded_te = shared.opts.sd_text_encoder
-    if vae is not None:
-        kwargs['vae'] = vae
-    if repo_id == 'sayakpaul/flux.1-dev-nf4':
-        repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
-    if 'Fill' in repo_id:
-        cls = diffusers.FluxFillPipeline
-    elif 'Canny' in repo_id:
-        cls = diffusers.FluxControlPipeline
-    elif 'Depth' in repo_id:
-        cls = diffusers.FluxControlPipeline
-    elif 'Kontext' in repo_id:
-        cls = diffusers.FluxKontextPipeline
-        from diffusers import pipelines
-        pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
-        pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
-        pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
+    # finally load transformer and text encoder if not already loaded
+    if transformer is None:
+        transformer = generic.load_transformer(repo_id, cls_name=diffusers.FluxTransformer2DModel, load_config=diffusers_load_config)
+    if text_encoder_2 is None:
+        text_encoder_2 = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config)
 
-    else:
-        cls = diffusers.FluxPipeline
-    shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
-    for c in kwargs:
-        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
-            shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
-        if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
-            try:
-                kwargs[c] = kwargs[c].to(dtype=devices.dtype)
-                shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
-            except Exception:
-                pass
+    pipe = cls_name.from_pretrained(
+        repo_id,
+        transformer=transformer,
+        text_encoder_2=text_encoder_2,
+        cache_dir=shared.opts.diffusers_dir,
+        **load_args,
+    )
 
-    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
-    fn = checkpoint_info.path
-    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
-        kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
-    if fn.endswith('.safetensors') and os.path.isfile(fn):
-        pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
-        allow_post_quant = True
-    else:
-        pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
+    del text_encoder_2
+    del transformer
 
+    # optional first-block patch
     if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
         from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
         apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
 
-    # release memory
-    transformer = None
-    text_encoder_1 = None
-    text_encoder_2 = None
-    vae = None
-    for k in kwargs.keys():
-        kwargs[k] = None
-
     sd_hijack_te.init_hijack(pipe)
     devices.torch_gc(force=True, reason='load')
-    return pipe, allow_post_quant
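+    # single return value; quantization is decided while components are loaded above, so the old allow_post_quant flag is no longer needed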
+    return pipe
diff --git a/pipelines/model_hunyuandit.py b/pipelines/model_hunyuandit.py
index 39d51a560..87b74eca8 100644
--- a/pipelines/model_hunyuandit.py
+++ b/pipelines/model_hunyuandit.py
@@ -12,7 +12,8 @@ def load_hunyuandit(checkpoint_info, diffusers_load_config={}):
     shared.log.debug(f'Load model: type=HunyuanDiT repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
 
     transformer = generic.load_transformer(repo_id, cls_name=diffusers.HunyuanDiT2DModel, load_config=diffusers_load_config)
-    text_encoder_2 = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, subfolder="text_encoder_2")
+    repo_te = 'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers' if 'HunyuanDiT-v1' in repo_id else repo_id
+    text_encoder_2 = generic.load_text_encoder(repo_te, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, subfolder="text_encoder_2", allow_shared=False) # this is not normal t5
 
     pipe = diffusers.HunyuanDiTPipeline.from_pretrained(
         repo_id,
diff --git a/pipelines/model_kolors.py b/pipelines/model_kolors.py
index 8add20664..26fcc8497 100644
--- a/pipelines/model_kolors.py
+++ b/pipelines/model_kolors.py
@@ -2,21 +2,14 @@
 import torch
 import diffusers
 
-repo_id = 'Kwai-Kolors/Kolors-diffusers'
-
-
 def load_kolors(_checkpoint_info, diffusers_load_config={}):
     from modules import shared, devices
     diffusers_load_config['variant'] = "fp16"
     if 'torch_dtype' not in diffusers_load_config:
         diffusers_load_config['torch_dtype'] = torch.float16
-    # import torch
-    # import transformers
-    # encoder_id = 'THUDM/chatglm3-6b'
-    # text_encoder = transformers.AutoModel.from_pretrained(encoder_id, torch_dtype=torch.float16, trust_remote_code=True, cache_dir=shared.opts.diffusers_dir)
-    # text_encoder = transformers.AutoModel.from_pretrained("THUDM/chatglm3-6b", torch_dtype=torch.float16, trust_remote_code=True).quantize(4).cuda()
-    # tokenizer = transformers.AutoTokenizer.from_pretrained(encoder_id, trust_remote_code=True, cache_dir=shared.opts.diffusers_dir)
+    repo_id = 'Kwai-Kolors/Kolors-diffusers'
+    shared.log.debug(f'Load model: type=Kolors repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
     pipe = diffusers.KolorsPipeline.from_pretrained(
         repo_id,
         cache_dir = shared.opts.diffusers_dir,
diff --git a/pipelines/model_lumina.py b/pipelines/model_lumina.py
index 60b681881..f29104fdb 100644
--- a/pipelines/model_lumina.py
+++ b/pipelines/model_lumina.py
@@ -1,9 +1,9 @@
 import os
 import transformers
 import diffusers
-from modules import errors, shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant
+from modules import shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant
+
 
-debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
 
 
 def load_lumina(_checkpoint_info, diffusers_load_config={}):
@@ -30,7 +30,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
     load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
     if shared.opts.sd_unet != 'Default':
         try:
-            debug(f'Load model: type=Lumina2 unet="{shared.opts.sd_unet}"')
             transformer = diffusers.Lumina2Transformer2DModel.from_single_file(
                 sd_unet.unet_dict[shared.opts.sd_unet],
                 cache_dir=shared.opts.diffusers_dir,
@@ -43,12 +42,9 @@
         except Exception as e:
             shared.log.error(f"Load model: type=Lumina2 failed to load UNet: {e}")
             shared.opts.sd_unet = 'Default'
-            if debug:
-                errors.display(e, 'Lumina2 UNet:')
     if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
         try:
-            debug(f'Load model: type=Lumina2 vae="{shared.opts.sd_vae}"')
             from modules import sd_vae
             # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
             vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
@@ -58,8 +54,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
         except Exception as e:
             shared.log.error(f"Load model: type=Lumina2 failed to load VAE: {e}")
             shared.opts.sd_vae = 'Default'
-            if debug:
-                errors.display(e, 'Lumina2 VAE:')
 
     if transformer is None:
         transformer = diffusers.Lumina2Transformer2DModel.from_pretrained(
diff --git a/pipelines/model_meissonic.py b/pipelines/model_meissonic.py
index 30671e350..b045f006c 100644
--- a/pipelines/model_meissonic.py
+++ b/pipelines/model_meissonic.py
@@ -6,10 +6,10 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
     from modules import shared, devices, modelloader, sd_models, shared_items
     from pipelines.meissonic.transformer import Transformer2DModel as TransformerMeissonic
     from pipelines.meissonic.scheduler import Scheduler as MeissonicScheduler
-    from pipelines.meissonic.pipeline import Pipeline as PipelineMeissonic
-    from pipelines.meissonic.pipeline_img2img import Img2ImgPipeline as PipelineMeissonicImg2Img
-    from pipelines.meissonic.pipeline_inpaint import InpaintPipeline as PipelineMeissonicInpaint
-    shared_items.pipelines['Meissonic'] = PipelineMeissonic
+    from pipelines.meissonic.pipeline import MeissonicPipeline
+    from pipelines.meissonic.pipeline_img2img import MeissonicImg2ImgPipeline
+    from pipelines.meissonic.pipeline_inpaint import MeissonicInpaintPipeline
+    shared_items.pipelines['Meissonic'] = MeissonicPipeline
     modelloader.hf_login()
 
     fn = sd_models.path_to_repo(checkpoint_info)
@@ -41,7 +41,7 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
         cache_dir=cache_dir,
     )
     scheduler = MeissonicScheduler.from_pretrained(fn, subfolder="scheduler", cache_dir=cache_dir)
-    pipe = PipelineMeissonic(
+    pipe = MeissonicPipeline(
         vqvae=vqvae.to(devices.dtype),
         text_encoder=text_encoder.to(devices.dtype),
         transformer=model.to(devices.dtype),
@@ -49,8 +49,8 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
         scheduler=scheduler,
     )
-    diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonic
-    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicImg2Img
-    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicInpaint
+    diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicPipeline
+    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicImg2ImgPipeline
+    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = MeissonicInpaintPipeline
 
     devices.torch_gc(force=True, reason='load')
     return pipe
diff --git a/pipelines/model_omnigen.py b/pipelines/model_omnigen.py
index 596fe4dbb..b8e8d7fd0 100644
--- a/pipelines/model_omnigen.py
+++ b/pipelines/model_omnigen.py
@@ -1,9 +1,6 @@
-import os
 import diffusers
 from modules import shared, devices, sd_models, model_quant
 
-debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
-
 
 def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
     repo_id = sd_models.path_to_repo(checkpoint_info)
diff --git a/pipelines/model_omnigen2.py b/pipelines/model_omnigen2.py
index 94488ae0a..6f2f48e7b 100644
--- a/pipelines/model_omnigen2.py
+++ b/pipelines/model_omnigen2.py
@@ -1,8 +1,5 @@
-import os
 from modules import shared, devices, sd_models, model_quant
 
-debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
-
 
 def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
     repo_id = sd_models.path_to_repo(checkpoint_info)
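Note: the repo-name to pipeline-class dispatch that the rewritten load_flux performs can be exercised standalone. The sketch below is illustrative only: pick_flux_pipeline is a hypothetical helper name, while the pipeline classes are the real diffusers classes referenced in the patch (FluxKontextPipeline requires a recent diffusers release).

# minimal sketch of the variant dispatch used by the new load_flux
# pick_flux_pipeline is a hypothetical name; it does not exist in the repo
import diffusers

def pick_flux_pipeline(repo_id: str):
    if 'Fill' in repo_id:
        return diffusers.FluxFillPipeline      # fill/inpaint checkpoints
    if 'Canny' in repo_id or 'Depth' in repo_id:
        return diffusers.FluxControlPipeline   # both control variants share one class
    if 'Kontext' in repo_id:
        return diffusers.FluxKontextPipeline   # image-editing variant
    return diffusers.FluxPipeline              # default text-to-image

# a plain dev checkpoint falls through to the default class
assert pick_flux_pipeline('black-forest-labs/FLUX.1-dev') is diffusers.FluxPipeline
assert pick_flux_pipeline('black-forest-labs/FLUX.1-Fill-dev') is diffusers.FluxFillPipeline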