remove legacy quant loaders

Signed-off-by: vladmandic <mandic00@live.com>
pull/4706/head^2
vladmandic 2026-03-24 14:48:44 +01:00
parent acf475ee45
commit 53839e464c
16 changed files with 19 additions and 1120 deletions

View File

@ -1,8 +1,8 @@
# Change Log for SD.Next # Change Log for SD.Next
## Update for 2026-03-23 ## Update for 2026-03-24
### Highlights for 2026-03-23 ### Highlights for 2026-03-24
This release brings massive code refactoring to modernize codebase and removal of some obsolete features. Leaner & Faster! This release brings massive code refactoring to modernize codebase and removal of some obsolete features. Leaner & Faster!
And since it's a bit quieter period when it comes to new models, notable additions would be: *FireRed-Image-Edit* *SkyWorks-UniPic-3* and new *Anima-Preview* And since it's a bit quieter period when it comes to new models, notable additions would be: *FireRed-Image-Edit* *SkyWorks-UniPic-3* and new *Anima-Preview*
@ -18,7 +18,7 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
### Details for 2026-03-23 ### Details for 2026-03-24
- **Models** - **Models**
- [Google Flash 3.1 Image](https://ai.google.dev/gemini-api/docs/models/gemini-3-flash-preview) a.k.a. *Nano Banana 2* - [Google Flash 3.1 Image](https://ai.google.dev/gemini-api/docs/models/gemini-3-flash-preview) a.k.a. *Nano Banana 2*
@ -70,6 +70,8 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
use following before first startup to force installation of `torch==2.9.1` with `cuda==12.6`: use following before first startup to force installation of `torch==2.9.1` with `cuda==12.6`:
> `set TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'` > `set TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'`
- **UI** - **UI**
- legacy panels **T2I** and **I2I** are disabled by default
you can re-enable them in *settings -> ui -> hide legacy tabs*
- new panel: **Server Info** with detailed runtime information - new panel: **Server Info** with detailed runtime information
- **Networks** add **UNet/DiT** - **Networks** add **UNet/DiT**
- **Localization** improved translation quality and new translations locales: - **Localization** improved translation quality and new translations locales:
@ -92,6 +94,9 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- new `/sdapi/v1/rembg` endpoint for background removal - new `/sdapi/v1/rembg` endpoint for background removal
- new `/sdapi/v1/unet` endpoint to list available unets/dits - new `/sdapi/v1/unet` endpoint to list available unets/dits
- use rate limiting for api logging - use rate limiting for api logging
- **Obsoleted**
- removed support for additional quantization engines: *BitsAndBytes, TorchAO, Optimum-Quanto, NNCF*
*note*: SDNQ is quantization engine of choice for SD.Next
- **Internal** - **Internal**
- `python==3.13` full support - `python==3.13` full support
- `python==3.14` initial support - `python==3.14` initial support

@ -1 +1 @@
Subproject commit 0861ae00f2ad057a914ca82e45fe6635dde7417e Subproject commit 9d584a1bdc0c2aca614aa0e1e34e4374c3aa779d

View File

@ -1658,4 +1658,4 @@
{"id":"","label":"Z values","localized":"","hint":"Separate values for Z axis using commas","ui":"script_xyz_grid_script"}, {"id":"","label":"Z values","localized":"","hint":"Separate values for Z axis using commas","ui":"script_xyz_grid_script"},
{"id":"","label":"Zoe Depth","localized":"","hint":"","ui":"control"} {"id":"","label":"Zoe Depth","localized":"","hint":"","ui":"control"}
] ]
} }

View File

@ -706,7 +706,6 @@ def install_openvino():
if not (args.skip_all or args.skip_requirements): if not (args.skip_all or args.skip_requirements):
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino') install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
ts('openvino', t_start) ts('openvino', t_start)
return torch_command return torch_command
@ -730,10 +729,6 @@ def install_torch_addons():
install('DeepCache') install('DeepCache')
if opts.get('cuda_compile_backend', '') == 'olive-ai': if opts.get('cuda_compile_backend', '') == 'olive-ai':
install('olive-ai') install('olive-ai')
if len(opts.get('optimum_quanto_weights', [])):
install('optimum-quanto==0.2.7', 'optimum-quanto')
if len(opts.get('torchao_quantization', [])):
install('torchao==0.10.0', 'torchao')
if opts.get('samples_format', 'jpg') == 'jxl' or opts.get('grid_format', 'jpg') == 'jxl': if opts.get('samples_format', 'jpg') == 'jxl' or opts.get('grid_format', 'jpg') == 'jxl':
install('pillow-jxl-plugin==1.3.7', 'pillow-jxl-plugin') install('pillow-jxl-plugin==1.3.7', 'pillow-jxl-plugin')
if not args.experimental: if not args.experimental:
@ -1189,9 +1184,6 @@ def install_optional():
install('open-clip-torch', no_deps=True, quiet=True) install('open-clip-torch', no_deps=True, quiet=True)
install('git+https://github.com/tencent-ailab/IP-Adapter.git', 'ip_adapter', ignore=True, quiet=True) install('git+https://github.com/tencent-ailab/IP-Adapter.git', 'ip_adapter', ignore=True, quiet=True)
# install('git+https://github.com/openai/CLIP.git', 'clip', quiet=True, no_build_isolation=True) # install('git+https://github.com/openai/CLIP.git', 'clip', quiet=True, no_build_isolation=True)
# install('torchao==0.10.0', ignore=True, quiet=True)
# install('bitsandbytes==0.47.0', ignore=True, quiet=True)
# install('optimum-quanto==0.2.7', ignore=True, quiet=True)
ts('optional', t_start) ts('optional', t_start)

View File

@ -382,22 +382,6 @@ class ControlNet():
self.model = sdnq_quantize_model(self.model) self.model = sdnq_quantize_model(self.model)
except Exception as e: except Exception as e:
log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}') log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}')
elif "Control" in opts.optimum_quanto_weights:
try:
log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
model_quant.load_quanto('Load model: type=Control')
from modules.model_quant import optimum_quanto_model
self.model = optimum_quanto_model(self.model)
except Exception as e:
log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}')
elif "Control" in opts.torchao_quantization:
try:
log.debug(f'Control {what} model Torch AO: id="{model_id}"')
model_quant.load_torchao('Load model: type=Control')
from modules.model_quant import torchao_quantization
self.model = torchao_quantization(self.model)
except Exception as e:
log.error(f'Control {what} model Torch AO: id="{model_id}" {e}')
if self.device is not None: if self.device is not None:
sd_models.move_model(self.model, self.device) sd_models.move_model(self.model, self.device)
if "Control" in opts.cuda_compile: if "Control" in opts.cuda_compile:

View File

@ -1,13 +1,11 @@
import os import os
import sys
import torch import torch
import nncf
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
from openvino.frontend import FrontEndManager from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module
from openvino import Core, Type, PartialShape, serialize from openvino import Core, Type, PartialShape, serialize # pylint: disable=no-name-in-module
from openvino.properties import hint as ov_hints from openvino.properties import hint as ov_hints # pylint: disable=no-name-in-module
from torch._dynamo.backends.common import fake_tensor_unsupported from torch._dynamo.backends.common import fake_tensor_unsupported
from torch._dynamo.backends.registry import register_backend from torch._dynamo.backends.registry import register_backend
@ -38,6 +36,7 @@ except Exception:
try: try:
# silence the pytorch version warning # silence the pytorch version warning
import nncf
nncf.common.logging.logger.warn_bkc_version_mismatch = lambda *args, **kwargs: None nncf.common.logging.logger.warn_bkc_version_mismatch = lambda *args, **kwargs: None
except Exception: except Exception:
pass pass
@ -215,8 +214,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
core = Core() core = Core()
device = get_device() device = get_device()
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant global dont_use_quant
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"): if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
@ -259,26 +256,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
om.inputs[idx-idx_minus].get_node().set_partial_shape(PartialShape(list(input_data.shape))) om.inputs[idx-idx_minus].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types() om.validate_nodes_and_infer_types()
if shared.opts.nncf_quantize and not dont_use_quant:
new_inputs = []
for idx, _ in enumerate(example_inputs):
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
new_inputs = [new_inputs]
if shared.opts.nncf_quantize_mode == "INT8":
om = nncf.quantize(om, nncf.Dataset(new_inputs))
else:
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
if shared.opts.nncf_compress_weights and not dont_use_nncf:
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
om = nncf.compress_weights(om)
else:
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
hints = {} hints = {}
if shared.opts.openvino_accuracy == "performance": if shared.opts.openvino_accuracy == "performance":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
@ -287,9 +264,7 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
if model_hash_str is not None: if model_hash_str is not None:
hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob' hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob'
core.set_property(hints) core.set_property(hints)
dont_use_nncf = False
dont_use_quant = False dont_use_quant = False
dont_use_4bit_nncf = False
compiled_model = core.compile_model(om, device) compiled_model = core.compile_model(om, device)
return compiled_model return compiled_model
@ -299,8 +274,6 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
core = Core() core = Core()
om = core.read_model(cached_model_path + ".xml") om = core.read_model(cached_model_path + ".xml")
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant global dont_use_quant
for idx, input_data in enumerate(example_inputs): for idx, input_data in enumerate(example_inputs):
@ -308,35 +281,13 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape))) om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types() om.validate_nodes_and_infer_types()
if shared.opts.nncf_quantize and not dont_use_quant:
new_inputs = []
for idx, _ in enumerate(example_inputs):
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
new_inputs = [new_inputs]
if shared.opts.nncf_quantize_mode == "INT8":
om = nncf.quantize(om, nncf.Dataset(new_inputs))
else:
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
if shared.opts.nncf_compress_weights and not dont_use_nncf:
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
om = nncf.compress_weights(om)
else:
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'} hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}
if shared.opts.openvino_accuracy == "performance": if shared.opts.openvino_accuracy == "performance":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
elif shared.opts.openvino_accuracy == "accuracy": elif shared.opts.openvino_accuracy == "accuracy":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
core.set_property(hints) core.set_property(hints)
dont_use_nncf = False
dont_use_quant = False dont_use_quant = False
dont_use_4bit_nncf = False
compiled_model = core.compile_model(om, get_device()) compiled_model = core.compile_model(om, get_device())
return compiled_model return compiled_model
@ -462,13 +413,9 @@ def get_subgraph_type(tensor):
@fake_tensor_unsupported @fake_tensor_unsupported
def openvino_fx(subgraph, example_inputs, options=None): def openvino_fx(subgraph, example_inputs, options=None):
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant global dont_use_quant
global subgraph_type global subgraph_type
dont_use_4bit_nncf = False
dont_use_nncf = False
dont_use_quant = False dont_use_quant = False
dont_use_faketensors = False dont_use_faketensors = False
executor_parameters = None executor_parameters = None
@ -484,9 +431,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
subgraph_type[2] is torch.nn.modules.normalization.GroupNorm and subgraph_type[2] is torch.nn.modules.normalization.GroupNorm and
subgraph_type[3] is torch.nn.modules.activation.SiLU): subgraph_type[3] is torch.nn.modules.activation.SiLU):
dont_use_4bit_nncf = True pass
dont_use_nncf = bool("VAE" not in shared.opts.nncf_compress_weights)
dont_use_quant = bool("VAE" not in shared.opts.nncf_quantize)
# SD 1.5 / SDXL Text Encoder # SD 1.5 / SDXL Text Encoder
elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and
@ -495,8 +440,6 @@ def openvino_fx(subgraph, example_inputs, options=None):
subgraph_type[3] is torch.nn.modules.linear.Linear): subgraph_type[3] is torch.nn.modules.linear.Linear):
dont_use_faketensors = True dont_use_faketensors = True
dont_use_nncf = bool("TE" not in shared.opts.nncf_compress_weights)
dont_use_quant = bool("TE" not in shared.opts.nncf_quantize)
# Create a hash to be used for caching # Create a hash to be used for caching
shared.compiled_model_state.model_hash_str = "" shared.compiled_model_state.model_hash_str = ""

View File

@ -1,12 +1,10 @@
import os import os
import re import re
import sys import sys
import copy
import json import json
import time import time
import diffusers import diffusers
import transformers from installer import install
from installer import installed, install, setup_logging
from modules.logger import log from modules.logger import log
@ -51,70 +49,6 @@ def dont_quant():
return False return False
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
from modules import shared, devices
if allow and (module == 'any' or module in shared.opts.bnb_quantization):
load_bnb()
if bnb is None:
return kwargs
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
bnb_4bit_compute_dtype=devices.dtype,
llm_int8_skip_modules=modules_to_not_convert,
)
log.debug(f'Quantization: module={module} type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
if kwargs is None:
return bnb_config
else:
kwargs['quantization_config'] = bnb_config
return kwargs
return kwargs
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
from modules import shared
if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
torchao = load_torchao()
if torchao is None:
return kwargs
if module in {'TE', 'LLM'}:
ao_config = transformers.TorchAoConfig(quant_type=shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
else:
ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
log.debug(f'Quantization: module={module} type=torchao dtype={shared.opts.torchao_quantization_type}')
if kwargs is None:
return ao_config
else:
kwargs['quantization_config'] = ao_config
return kwargs
return kwargs
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
from modules import shared
if allow and (module == 'any' or module in shared.opts.quanto_quantization):
load_quanto(silent=True)
if optimum_quanto is None:
return kwargs
if module in {'TE', 'LLM'}:
quanto_config = transformers.QuantoConfig(weights=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
quanto_config.weights_dtype = quanto_config.weights
else:
quanto_config = diffusers.QuantoConfig(weights_dtype=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
quanto_config.activations = None # patch so it works with transformers
quanto_config.weights = quanto_config.weights_dtype
log.debug(f'Quantization: module={module} type=quanto dtype={shared.opts.quanto_quantization_type}')
if kwargs is None:
return quanto_config
else:
kwargs['quantization_config'] = quanto_config
return kwargs
return kwargs
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None): def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
from modules import shared from modules import shared
if allow and (module == 'any' or module in shared.opts.trt_quantization): if allow and (module == 'any' or module in shared.opts.trt_quantization):
@ -249,7 +183,7 @@ def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model',
def check_quant(module: str = ''): def check_quant(module: str = ''):
from modules import shared from modules import shared
if module in shared.opts.sdnq_quantize_weights or module in shared.opts.bnb_quantization or module in shared.opts.torchao_quantization or module in shared.opts.quanto_quantization: if module in shared.opts.sdnq_quantize_weights:
return True return True
return False return False
@ -286,21 +220,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
if debug: if debug:
log.trace(f'Quantization: type=sdnq config={kwargs.get("quantization_config", None)}') log.trace(f'Quantization: type=sdnq config={kwargs.get("quantization_config", None)}')
return kwargs return kwargs
kwargs = create_bnb_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=bnb config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_quanto_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=quanto config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_ao_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=torchao config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_trt_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert) kwargs = create_trt_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs: if kwargs is not None and 'quantization_config' in kwargs:
if debug: if debug:
@ -309,88 +228,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
return kwargs return kwargs
def load_torchao(msg='', silent=False):
global ao # pylint: disable=global-statement
if ao is not None:
return ao
if not installed('torchao'):
install('torchao==0.10.0', quiet=True)
log.warning('Quantization: torchao installed please restart')
try:
import torchao
ao = torchao
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access
from diffusers.utils import import_utils
import_utils.is_torchao_available = lambda: True
import_utils._torchao_available = True # pylint: disable=protected-access
return ao
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import torchao: {e}")
ao = None
if not silent:
raise
return None
def load_bnb(msg='', silent=False):
from modules import devices
global bnb # pylint: disable=global-statement
if bnb is not None:
return bnb
if not installed('bitsandbytes'):
if devices.backend == 'cuda':
# forcing a version will uninstall the multi-backend-refactor branch of bnb
install('bitsandbytes==0.47.0', quiet=True)
log.warning('Quantization: bitsandbytes installed please restart')
try:
import bitsandbytes
bnb = bitsandbytes
from diffusers.utils import import_utils
import_utils._bitsandbytes_available = True # pylint: disable=protected-access
import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=bitsandbytes version={bnb.__version__} fn={fn}') # pylint: disable=protected-access
return bnb
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import bitsandbytes: {e}")
bnb = None
if not silent:
raise
return None
def load_quanto(msg='', silent=False):
global optimum_quanto # pylint: disable=global-statement
if optimum_quanto is not None:
return optimum_quanto
if not installed('optimum-quanto'):
install('optimum-quanto==0.2.7', quiet=True)
log.warning('Quantization: optimum-quanto installed please restart')
try:
from optimum import quanto # pylint: disable=no-name-in-module
# disable device specific tensors because the model can't be moved between cpu and gpu with them
quanto.tensor.weights.qbits.WeightQBitsTensor.create = lambda *args, **kwargs: quanto.tensor.weights.qbits.WeightQBitsTensor(*args, **kwargs)
optimum_quanto = quanto
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access
from diffusers.utils import import_utils
import_utils.is_optimum_quanto_available = lambda: True
import_utils._optimum_quanto_available = True # pylint: disable=protected-access
import_utils._optimum_quanto_version = quanto.__version__ # pylint: disable=protected-access
import_utils._replace_with_quanto_layers = diffusers.quantizers.quanto.utils._replace_with_quanto_layers # pylint: disable=protected-access
return optimum_quanto
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import optimum.quanto: {e}")
optimum_quanto = None
if not silent:
raise
return None
def load_trt(msg='', silent=False): def load_trt(msg='', silent=False):
global trt # pylint: disable=global-statement global trt # pylint: disable=global-statement
if trt is not None: if trt is not None:
@ -642,138 +479,6 @@ def sdnq_quantize_weights(sd_model):
return sd_model return sd_model
def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activations=None):
from modules import devices, shared
quanto = load_quanto('Quantize model: type=Optimum Quanto')
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
if model.__class__.__name__ in {"FluxTransformer2DModel", "ChromaTransformer2DModel"}: # LayerNorm is not supported
exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
if model.__class__.__name__ == "ChromaTransformer2DModel":
# we ignore the distilled guidance layer because it degrades quality too much
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
exclude_list.append("distilled_guidance_layer.*")
elif model.__class__.__name__ == "QwenImageTransformer2DModel":
exclude_list = ["transformer_blocks.0.img_mod.1.weight", "time_text_embed", "img_in", "txt_in", "proj_out", "norm_out", "pos_embed"]
else:
exclude_list = None
weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)
if activations is not None:
activations = getattr(quanto, activations) if activations != 'none' else None
elif shared.opts.optimum_quanto_activations_type != 'none':
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
else:
activations = None
model.eval()
backup_embeddings = None
if hasattr(model, "get_input_embeddings"):
backup_embeddings = copy.deepcopy(model.get_input_embeddings())
quanto.quantize(model, weights=weights, activations=activations, exclude=exclude_list)
quanto.freeze(model)
if hasattr(model, "set_input_embeddings") and backup_embeddings is not None:
model.set_input_embeddings(backup_embeddings)
if op is not None and shared.opts.optimum_quanto_shuffle_weights:
if quant_last_model_name is not None:
if "." in quant_last_model_name:
last_model_names = quant_last_model_name.split(".")
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
else:
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
devices.torch_gc(force=True, reason='quanto')
if shared.cmd_opts.medvram or shared.cmd_opts.lowvram or shared.opts.diffusers_offload_mode != "none":
quant_last_model_name = op
quant_last_model_device = model.device
else:
quant_last_model_name = None
quant_last_model_device = None
model.to(devices.device)
devices.torch_gc(force=True, reason='quanto')
return model
def optimum_quanto_weights(sd_model):
try:
t0 = time.time()
from modules import shared, devices, sd_models
if shared.opts.diffusers_offload_mode in {"balanced", "sequential"}:
log.warning(f"Quantization: type=Optimum.quanto offload={shared.opts.diffusers_offload_mode} not compatible")
return sd_model
log.info(f"Quantization: type=Optimum.quanto: modules={shared.opts.optimum_quanto_weights}")
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
quanto = load_quanto()
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_model, shared.opts.optimum_quanto_weights, op="optimum-quanto")
if quant_last_model_name is not None:
if "." in quant_last_model_name:
last_model_names = quant_last_model_name.split(".")
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
else:
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
devices.torch_gc(force=True, reason='quanto')
quant_last_model_name = None
quant_last_model_device = None
if shared.opts.optimum_quanto_activations_type != 'none':
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
else:
activations = None
if activations is not None:
def optimum_quanto_freeze(model, op=None, sd_model=None): # pylint: disable=unused-argument
quanto.freeze(model)
return model
if shared.opts.diffusers_offload_mode == "model":
sd_model.enable_model_cpu_offload(device=devices.device)
if hasattr(sd_model, "encode_prompt"):
original_encode_prompt = sd_model.encode_prompt
def encode_prompt(*args, **kwargs):
embeds = original_encode_prompt(*args, **kwargs)
sd_model.maybe_free_model_hooks() # Diffusers keeps the TE on VRAM
return embeds
sd_model.encode_prompt = encode_prompt
else:
sd_models.move_model(sd_model, devices.device)
with quanto.Calibration(momentum=0.9):
sd_model(prompt="dummy prompt", num_inference_steps=10)
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_freeze, shared.opts.optimum_quanto_weights, op="optimum-quanto-freeze")
if shared.opts.diffusers_offload_mode == "model":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, devices.cpu)
if hasattr(sd_model, "encode_prompt"):
sd_model.encode_prompt = original_encode_prompt
devices.torch_gc(force=True, reason='quanto')
t1 = time.time()
log.info(f"Quantization: type=Optimum.quanto time={t1-t0:.2f}")
except Exception as e:
log.warning(f"Quantization: type=Optimum.quanto {e}")
return sd_model
def torchao_quantization(sd_model):
from modules import shared, devices, sd_models
torchao = load_torchao()
q = torchao.quantization
fn = getattr(q, shared.opts.torchao_quantization_type, None)
if fn is None:
log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported")
return sd_model
def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
q.quantize_(model, fn(), device=devices.device)
return model
log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}")
try:
t0 = time.time()
sd_models.apply_function_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao")
t1 = time.time()
log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}")
except Exception as e:
log.error(f"Quantization: type=TorchAO {e}")
setup_logging() # torchao uses dynamo which messes with logging so reset is needed
return sd_model
def get_dit_args(load_config:dict=None, module:str=None, device_map:bool=False, allow_quant:bool=True, modules_to_not_convert: list = None, modules_dtype_dict: dict = None): def get_dit_args(load_config:dict=None, module:str=None, device_map:bool=False, allow_quant:bool=True, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
from modules import shared, devices from modules import shared, devices
config = {} if load_config is None else load_config.copy() config = {} if load_config is None else load_config.copy()
@ -810,12 +515,6 @@ def do_post_load_quant(sd_model, allow=True):
if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')): if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
log.debug('Load model: post_quant=sdnq') log.debug('Load model: post_quant=sdnq')
sd_model = sdnq_quantize_weights(sd_model) sd_model = sdnq_quantize_weights(sd_model)
if len(shared.opts.optimum_quanto_weights) > 0:
log.debug('Load model: post_quant=quanto')
sd_model = optimum_quanto_weights(sd_model)
if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
log.debug('Load model: post_quant=torchao')
sd_model = torchao_quantization(sd_model)
if shared.opts.layerwise_quantization: if shared.opts.layerwise_quantization:
log.debug('Load model: post_quant=layerwise') log.debug('Load model: post_quant=layerwise')
apply_layerwise(sd_model) apply_layerwise(sd_model)

View File

@ -62,16 +62,6 @@ def load_t5(name=None, cache_dir=None):
elif 'fp16' in name.lower(): elif 'fp16' in name.lower():
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype) t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp4' in name.lower():
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp8' in name.lower():
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'int8' in name.lower(): elif 'int8' in name.lower():
from modules.model_quant import create_sdnq_config from modules.model_quant import create_sdnq_config
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8') quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
@ -84,18 +74,6 @@ def load_t5(name=None, cache_dir=None):
if quantization_config is not None: if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype) t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint4' in name.lower():
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int4')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint8' in name.lower():
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int8')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif '/' in name: elif '/' in name:
log.debug(f'Load model: type=T5 repo={name}') log.debug(f'Load model: type=T5 repo={name}')
quant_config = model_quant.create_config(module='TE') quant_config = model_quant.create_config(module='TE')

View File

@ -166,26 +166,6 @@ def create_settings(cmd_opts):
"nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox), "nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
"nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox), "nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),
"bnb_quantization_sep": OptionInfo("<h2>BitsAndBytes</h2>", "", gr.HTML),
"bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "VAE"]}),
"bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ["nf4", "fp8", "fp4"]}),
"bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]}),
"quanto_quantization_sep": OptionInfo("<h2>Optimum Quanto</h2>", "", gr.HTML),
"quanto_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM"]}),
"quanto_quantization_type": OptionInfo("int8", "Quantization weights type", gr.Dropdown, {"choices": ["float8", "int8", "int4", "int2"]}),
"optimum_quanto_sep": OptionInfo("<h2>Optimum Quanto: post-load</h2>", "", gr.HTML),
"optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "Control", "VAE"]}),
"optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ["qint8", "qfloat8_e4m3fn", "qfloat8_e5m2", "qint4", "qint2"]}),
"optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ["none", "qint8", "qfloat8_e4m3fn", "qfloat8_e5m2"]}),
"optimum_quanto_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
"torchao_sep": OptionInfo("<h2>TorchAO</h2>", "", gr.HTML),
"torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "Control", "VAE"]}),
"torchao_quantization_mode": OptionInfo("auto", "Quantization mode", gr.Dropdown, {"choices": ["auto", "pre", "post"]}),
"torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ["int4_weight_only", "int8_dynamic_activation_int4_weight", "int8_weight_only", "int8_dynamic_activation_int8_weight", "float8_weight_only", "float8_dynamic_activation_float8_weight", "float8_static_activation_float8_weight"]}),
"layerwise_quantization_sep": OptionInfo("<h2>Layerwise Casting</h2>", "", gr.HTML), "layerwise_quantization_sep": OptionInfo("<h2>Layerwise Casting</h2>", "", gr.HTML),
"layerwise_quantization": OptionInfo([], "Layerwise casting enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}), "layerwise_quantization": OptionInfo([], "Layerwise casting enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
"layerwise_quantization_storage": OptionInfo("float8_e4m3fn", "Layerwise casting storage", gr.Dropdown, {"choices": ["float8_e4m3fn", "float8_e5m2"]}), "layerwise_quantization_storage": OptionInfo("float8_e4m3fn", "Layerwise casting storage", gr.Dropdown, {"choices": ["float8_e4m3fn", "float8_e5m2"]}),
@ -194,14 +174,6 @@ def create_settings(cmd_opts):
"trt_quantization_sep": OptionInfo("<h2>TensorRT</h2>", "", gr.HTML), "trt_quantization_sep": OptionInfo("<h2>TensorRT</h2>", "", gr.HTML),
"trt_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model"]}), "trt_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model"]}),
"trt_quantization_type": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": ["int8", "int4", "fp8", "nf4", "nvfp4"]}), "trt_quantization_type": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": ["int8", "int4", "fp8", "nf4", "nvfp4"]}),
"nncf_compress_sep": OptionInfo("<h2>NNCF: Neural Network Compression Framework</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_mode": OptionInfo("INT8_SYM", "Quantization type", gr.Dropdown, {"choices": ["INT8", "INT8_SYM", "FP8", "MXFP8", "INT4_ASYM", "INT4_SYM", "FP4", "MXFP4", "NF4"], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_raito": OptionInfo(0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": cmd_opts.use_openvino}),
"nncf_quantize": OptionInfo([], "Static Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
"nncf_quantize_mode": OptionInfo("INT8", "OpenVINO activations mode", gr.Dropdown, {"choices": ["INT8", "FP8_E4M3", "FP8_E5M2"], "visible": cmd_opts.use_openvino}),
})) }))
# --- VAE & Text Encoder --- # --- VAE & Text Encoder ---
options_templates.update(options_section(('vae_encoder', "Variational Auto Encoder"), { options_templates.update(options_section(('vae_encoder', "Variational Auto Encoder"), {

View File

@ -1,25 +0,0 @@
import diffusers
import transformers
from modules import devices, model_quant
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    """Load a FLUX transformer from a single file, applying a BitsAndBytes config when the
    checkpoint name indicates a prequantized variant (fp8/fp4/nf4)."""
    repo_path = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    # pick the matching BnB config; None means load without quantization
    if quant == 'fp8':
        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
    elif quant == 'fp4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
    elif quant == 'nf4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
    else:
        quantization_config = None
    if quantization_config is None:
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    else:
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    return transformer

View File

@ -1,361 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
from modules.logger import log
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
    """Load a Quanto-prequantized FLUX transformer and T5 text encoder.

    Args:
        checkpoint_info: either a path string or a checkpoint-info object with `.path`.
    Returns:
        (transformer, text_encoder_2) -- either element may be None if its load failed;
        failures are logged, never raised.
    """
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    try:
        # locate the quantization map locally first, fall back to downloading it from the hub
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        # infer the checkpoint's dtype from a known parameter
        dtype = state_dict['context_embedder.bias'].dtype
        # build on the meta device so no real memory is allocated before requantize fills it in
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        # NOTE(review): `debug` is a callable (log.trace or a no-op lambda), so `if debug:` is always truthy -- confirm intent
        if debug:
            errors.display(e, 'FLUX Quanto:')
    try:
        # same procedure for the T5 text encoder component
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')
    return transformer, text_encoder_2
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    """Load a FLUX transformer from a single file with a BitsAndBytes config matching the
    detected prequantization (fp8/fp4/nf4); load failures are logged and return (None, None)."""
    transformer, text_encoder_2 = None, None
    repo_path = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    try:
        # select the BnB config for the detected variant; None means plain load
        if quant == 'fp8':
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
        elif quant == 'fp4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
        elif quant == 'nf4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
        else:
            quantization_config = None
        if quantization_config is not None:
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        else:
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
        transformer, text_encoder_2 = None, None
        if debug:
            errors.display(e, 'FLUX:')
    return transformer, text_encoder_2
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
    """Populate missing 'transformer'/'text_encoder_2' entries of *kwargs* with quantized variants.

    Tries Nunchaku SVDQuant prequantized checkpoints first (when enabled), then the generic
    quantizer configured via model_quant. Entries already present in *kwargs* are left alone.
    Returns *kwargs* (also mutated in place); errors are logged, never raised.
    """
    try:
        diffusers_load_config = {
            "torch_dtype": devices.dtype,
            "cache_dir": cache_dir,
        }
        if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = None
            # map the base model repo to its matching prequantized nunchaku checkpoint
            if 'flux.1-kontext' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
            elif 'flux.1-dev' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
            elif 'flux.1-schnell' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-fill' in repo_id.lower():
                # NOTE(review): filename says flux.1-schnell inside the fill-dev repo -- looks like a copy-paste slip, confirm against the hub repo
                nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-depth' in repo_id.lower():
                # NOTE(review): same suspicion as the fill branch above
                nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'shuttle' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
            else:
                log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
            if nunchaku_repo is not None:
                log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
                kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype, cache_dir=cache_dir)
                kwargs['transformer'].quantization_method = 'SVDQuant'
                if shared.opts.nunchaku_attention:
                    kwargs['transformer'].set_attention_impl("nunchaku-fp16")
        if 'transformer' not in kwargs and model_quant.check_quant('Model'):
            # generic quantized load path when nunchaku did not provide a transformer
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
            kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
        if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
            log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
            kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
            kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
        if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
            kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
    except Exception as e:
        log.error(f'Quantization: {e}')
        errors.display(e, 'Quantization:')
    return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
    """Load a standalone UNet/Transformer override for FLUX from *file_path*.

    Dispatches on the prequantization detected in the file name (gguf/fp8/qint/nf4);
    otherwise loads with the configured quantization args. Returns the transformer
    or None; on failure resets `shared.opts.sd_unet` to 'Default'.
    """
    if file_path is None or not os.path.exists(file_path):
        return None
    transformer = None
    quant = model_quant.get_quant(file_path)
    diffusers_load_config = {
        "torch_dtype": devices.dtype,
        "cache_dir": shared.opts.hfcache_dir,
    }
    if quant is not None and quant != 'none':
        log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
        if 'gguf' in file_path.lower():
            from modules import ggml
            _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
            if _transformer is not None:
                transformer = _transformer
        elif quant == "fp8":
            _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'qint8', 'qint4'}:
            _transformer, _text_encoder_2 = load_flux_quanto(file_path)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'fp8', 'fp4', 'nf4'}:
            # NOTE(review): 'fp8' is dead here (caught by the branch above) and this branch
            # consumes 'nf4', making the 'nf4' branch below unreachable -- confirm intended loader
            _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif 'nf4' in quant:
            # NOTE(review): unreachable for quant == 'nf4' (see note above)
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
            if _transformer is not None:
                transformer = _transformer
    else:
        # no prequantization detected: optionally quantize on load via bnb, else generic load
        quant_args = model_quant.create_bnb_config({})
        if quant_args:
            log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
            from pipelines.flux.flux_nf4 import load_flux_nf4
            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
            if transformer is not None:
                return transformer
        load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
        log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
    if transformer is None:
        log.error('Failed to load UNet model')
        shared.opts.sd_unet = 'Default'
    return transformer
def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
    """Build a FLUX pipeline from *checkpoint_info*, honoring unet/te/vae overrides and
    prequantized components.

    Returns:
        (pipe, allow_post_quant) -- the assembled pipeline and whether post-load
        quantization is permitted (True only for single-file .safetensors loads).
    """
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
    allow_post_quant = False
    prequantized = model_quant.get_quant(checkpoint_info.path)
    log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
    debug(f'Load model: type=FLUX config={diffusers_load_config}')
    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True, reason='load')
    if shared.opts.teacache_enabled:
        from modules import teacache
        log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
    # load overrides if any
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                shared.opts.sd_unet = 'Default'
                # NOTE(review): this appends 'Default' (the option was reset on the previous line),
                # not the name of the unet that failed -- the two lines should probably be swapped
                sd_unet.failed_unet.append(shared.opts.sd_unet)
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'FLUX UNet:')
    if shared.opts.sd_text_encoder != 'Default':
        try:
            debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
            from modules.model_te import load_t5, load_vit_l
            if 'vit-l' in shared.opts.sd_text_encoder.lower():
                text_encoder_1 = load_vit_l()
            else:
                text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load T5: {e}")
            shared.opts.sd_text_encoder = 'Default'
            if debug:
                errors.display(e, 'FLUX T5:')
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'FLUX VAE:')
    # load quantized components if any
    if prequantized == 'nf4':
        try:
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder = load_flux_nf4(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
            if debug:
                errors.display(e, 'FLUX NF4:')
    if prequantized == 'qint8' or prequantized == 'qint4':
        try:
            _transformer, _text_encoder = load_flux_quanto(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
            if debug:
                errors.display(e, 'FLUX Quanto:')
    # initialize pipeline with pre-loaded components
    kwargs = {}
    if transformer is not None:
        kwargs['transformer'] = transformer
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder_1 is not None:
        kwargs['text_encoder'] = text_encoder_1
        model_te.loaded_te = shared.opts.sd_text_encoder
    if text_encoder_2 is not None:
        kwargs['text_encoder_2'] = text_encoder_2
        model_te.loaded_te = shared.opts.sd_text_encoder
    if vae is not None:
        kwargs['vae'] = vae
    if repo_id == 'sayakpaul/flux.1-dev-nf4':
        repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
    # pick the pipeline class from the repo name
    if 'Fill' in repo_id:
        cls = diffusers.FluxFillPipeline
    elif 'Canny' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Depth' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Kontext' in repo_id:
        cls = diffusers.FluxKontextPipeline
        from diffusers import pipelines
        pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
    else:
        cls = diffusers.FluxPipeline
    log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
    for c in kwargs:
        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
            log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
        # recast any float32 component down to the working dtype
        if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
            try:
                kwargs[c] = kwargs[c].to(dtype=devices.dtype)
                log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
            except Exception:
                pass
    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
    fn = checkpoint_info.path
    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
        kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
    # NOTE(review): fn can still be None here (guard above allows it through to load_quants),
    # so fn.endswith would raise AttributeError -- confirm checkpoint_info.path is never None
    if fn.endswith('.safetensors') and os.path.isfile(fn):
        pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
        allow_post_quant = True
    else:
        pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
    if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
        from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
        apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
    # release memory
    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    for k in kwargs.keys():
        kwargs[k] = None
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True, reason='load')
    return pipe, allow_post_quant

View File

@ -1,201 +0,0 @@
"""
Copied from: https://github.com/huggingface/diffusers/issues/9165
"""
import os
import torch
import torch.nn as nn
from transformers.quantizers.quantizers_utils import get_module_from_name
from huggingface_hub import hf_hub_download
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
import safetensors.torch
from modules import shared, devices, model_quant
from modules.logger import log
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Replaces every nn.Linear in *model* (recursively) with a bitsandbytes quantized
    linear: Linear8bitLt when method == "llm_int8", Linear4bit (nf4) otherwise.
    Returns the converted model and a boolean that indicates if the conversion has
    been successful or not.
    """
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # build the replacement on the meta device; weights are filled in later
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features
                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt( # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                    has_been_replaced = True
                else:
                    model._modules[name] = bnb.nn.Linear4bit( # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=devices.dtype,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                    has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module) # pylint: disable=protected-access
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False) # pylint: disable=protected-access
        if len(list(module.children())) > 0:
            # recurse into container modules; note the nested call does not forward `method`,
            # so children always use the default "nf4" path
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                has_been_replaced=has_been_replaced,
            )
    # Remove the last key for recursion
    return model, has_been_replaced
def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    """Return True when *param_name* addresses a tensor that must go through the
    bnb quantized loading path instead of plain tensor assignment."""
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    module, tensor_name = get_module_from_name(model, param_name)
    existing = module._parameters.get(tensor_name, None)  # pylint: disable=protected-access
    if isinstance(existing, bnb.nn.Params4bit):
        # already a quantized parameter; dtype checks belong here once serialization lands
        return True
    # bias of a quantized Linear4bit: accelerate's set_module_tensor_to_device could load it,
    # but it would wrongly use the uninitialized weight there
    return isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias"
def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False
):
    """Install *param_value* into *model* at *param_name* as a bnb 4-bit parameter.

    Handles three cases: plain bias assignment, restoring an already-quantized
    parameter (pre_quantized=True, requires its `quant_state` entries in *state_dict*),
    and quantizing a full-precision tensor on the fly. Raises ValueError on
    missing parameters, wrong parameter class, or missing quantization stats.
    """
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    module, tensor_name = get_module_from_name(model, param_name)
    if tensor_name not in module._parameters: # pylint: disable=protected-access
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    old_value = getattr(module, tensor_name)
    if tensor_name == "bias":
        # bias stays a regular parameter; just move it (or keep the old one) on target_device
        if param_value is None:
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)
        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value # pylint: disable=protected-access
        return
    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): # pylint: disable=protected-access
        raise ValueError("this function only loads `Linear4bit components`")
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
    if pre_quantized:
        # checkpoint already stores quantized data; collect its quant_state side-tensors
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
            raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` to counter for edge cases where `param_name`
            # substring can be present in multiple places in the `state_dict`
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)
        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )
    else:
        # full-precision tensor: quantization happens inside Params4bit on .to(target_device)
        new_value = param_value.to("cpu")
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
    module._parameters[tensor_name] = new_value # pylint: disable=protected-access
def load_flux_nf4(checkpoint_info, prequantized: bool = True):
    """Load a bitsandbytes-NF4 quantized FLUX transformer from a checkpoint.

    `checkpoint_info` is either a plain path/repo string or an object exposing
    a `.path` attribute. Returns a `(transformer, text_encoder_2)` pair;
    `text_encoder_2` is always None here — presumably it exists only so the
    signature matches the other FLUX loaders (confirm against callers).
    """
    transformer, text_encoder_2 = None, None
    source_path = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
    # resolve the safetensors weights: direct file, local repo folder, or hub download
    if os.path.isfile(source_path):
        weights_file = source_path
    elif os.path.isdir(source_path) and os.path.exists(os.path.join(source_path, "diffusion_pytorch_model.safetensors")):
        weights_file = os.path.join(source_path, "diffusion_pytorch_model.safetensors")
    else:
        weights_file = hf_hub_download(source_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
    raw_sd = safetensors.torch.load_file(weights_file)
    if 'sayakpaul' in source_path:
        diffusers_sd = raw_sd  # that repo already ships diffusers-format keys
    else:
        try:
            diffusers_sd = convert_flux_transformer_checkpoint_to_diffusers(raw_sd)
        except Exception as e:
            log.error(f"Load model: type=FLUX Failed to convert UNET: {e}")
            if debug:
                from modules import errors
                errors.display(e, 'FLUX convert:')
            diffusers_sd = raw_sd  # fall back to the unconverted dict and hope keys match
    # build the model skeleton on the meta device so no real memory is allocated yet
    with init_empty_weights():
        from diffusers import FluxTransformer2DModel
        config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
        transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype)
        known_keys = list(transformer.state_dict().keys())
    _replace_with_bnb_linear(transformer, "nf4")
    try:
        for key, tensor in diffusers_sd.items():
            if key not in known_keys:
                continue
            # float8_e4m3fn tensors are left untouched; other floats are cast to the runtime dtype
            is_fp8 = hasattr(torch, "float8_e4m3fn") and tensor.dtype == torch.float8_e4m3fn
            if torch.is_floating_point(tensor) and not is_fp8:
                tensor = tensor.to(devices.dtype)
            if check_quantized_param(transformer, key):
                # quantized params also need their quant-state entries from the raw dict
                create_quantized_param(transformer, tensor, key, target_device=0, state_dict=raw_sd, pre_quantized=prequantized)
            else:
                set_module_tensor_to_device(transformer, key, device=0, value=tensor)
    except Exception as e:
        transformer, text_encoder_2 = None, None
        log.error(f"Load model: type=FLUX failed to load UNET: {e}")
        if debug:
            from modules import errors
            errors.display(e, 'FLUX:')
    del raw_sd
    devices.torch_gc(force=True, reason='load')
    return transformer, text_encoder_2

View File

@ -1,74 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, model_quant
from modules.logger import log
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
    """Load a Quanto-quantized FLUX transformer and T5 text encoder.

    `checkpoint_info` is either a plain path/repo string or an object with
    `.path` / `.name` attributes. Each component is loaded best-effort and
    independently: on failure it stays None while the other may still load.
    Returns a `(transformer, text_encoder_2)` tuple.
    """
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    try:
        # transformer: find the quantization map locally, else fetch it from the hub
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)  # name reused: path string becomes the parsed map
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        # infer the checkpoint dtype from a known always-present tensor
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            # skeleton on the meta device: no real weight memory allocated here
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                # best-effort cast to the runtime dtype; keep original dtype on failure
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')
    try:
        # text_encoder_2 (T5): same flow as the transformer above
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        # infer dtype from a tensor known to exist in T5 checkpoints
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')
    return transformer, text_encoder_2

View File

@ -41,18 +41,6 @@ def load_flux(checkpoint_info, diffusers_load_config=None):
transformer = None transformer = None
text_encoder_2 = None text_encoder_2 = None
# handle prequantized models
prequantized = model_quant.get_quant(checkpoint_info.path)
if prequantized == 'nf4':
from pipelines.flux.flux_nf4 import load_flux_nf4
transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
elif prequantized == 'qint8' or prequantized == 'qint4':
from pipelines.flux.flux_quanto import load_flux_quanto
transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
elif prequantized == 'fp4' or prequantized == 'fp8':
from pipelines.flux.flux_bnb import load_flux_bnb
transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)
# handle transformer svdquant if available, t5 is handled inside load_text_encoder # handle transformer svdquant if available, t5 is handled inside load_text_encoder
if transformer is None and model_quant.check_nunchaku('Model'): if transformer is None and model_quant.check_nunchaku('Model'):
from pipelines.flux.flux_nunchaku import load_flux_nunchaku from pipelines.flux.flux_nunchaku import load_flux_nunchaku

View File

@ -98,7 +98,7 @@ RDNA2: Dict[str, str] = {
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1", "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
# 5X10U2V2: fixed geometry (5×10 stride-2), no SD conv matches — disabled # 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches — disabled
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0", "MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled # 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0", "MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
@ -117,7 +117,7 @@ RDNA2: Dict[str, str] = {
# FWD / FWD1X1: FP32/FP16 forward — enabled # FWD / FWD1X1: FP32/FP16 forward — enabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1", "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1", "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
# FWD11X11: requires 11×11 kernel — no SD match — disabled # FWD11X11: requires 11*11 kernel — no SD match — disabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0", "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
# FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16; # FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
# can produce dtype=float32 output for FP16 inputs — disabled # can produce dtype=float32 output for FP16 inputs — disabled

View File

@ -230,4 +230,3 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4",
]), ]),
] ]