mirror of https://github.com/vladmandic/automatic
remove legacy quant loaders
Signed-off-by: vladmandic <mandic00@live.com>pull/4706/head^2
parent
acf475ee45
commit
53839e464c
11
CHANGELOG.md
11
CHANGELOG.md
|
|
@ -1,8 +1,8 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2026-03-23
|
||||
## Update for 2026-03-24
|
||||
|
||||
### Highlights for 2026-03-23
|
||||
### Highlights for 2026-03-4
|
||||
|
||||
This release brings massive code refactoring to modernize codebase and removal of some obsolete features. Leaner & Faster!
|
||||
And since its a bit quieter period when it comes to new models, notable additions would be : *FireRed-Image-Edit* *SkyWorks-UniPic-3* and new *Anima-Preview*
|
||||
|
|
@ -18,7 +18,7 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
|
|||
|
||||
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
|
||||
|
||||
### Details for 2026-03-23
|
||||
### Details for 2026-03-24
|
||||
|
||||
- **Models**
|
||||
- [Google Flash 3.1 Image](https://ai.google.dev/gemini-api/docs/models/gemini-3-flash-preview) a.k.a. *Nano Banana 2*
|
||||
|
|
@ -70,6 +70,8 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
|
|||
use following before first startup to force installation of `torch==2.9.1` with `cuda==12.6`:
|
||||
> `set TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'`
|
||||
- **UI**
|
||||
- legacy panels **T2I** and **I2I** are disabled by default
|
||||
you can re-enable them in *settings -> ui -> hide legacy tabs*
|
||||
- new panel: **Server Info** with detailed runtime informaton
|
||||
- **Networks** add **UNet/DiT**
|
||||
- **Localization** improved translation quality and new translations locales:
|
||||
|
|
@ -92,6 +94,9 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
|
|||
- new `/sdapi/v1/rembg` endpoint for background removal
|
||||
- new `/sdadpi/v1/unet` endpoint to list available unets/dits
|
||||
- use rate limiting for api logging
|
||||
- **Obsoleted**
|
||||
- removed support for additional quantization engines: *BitsAndBytes, TorchAO, Optimum-Quanto, NNCF*
|
||||
*note*: SDNQ is quantization engine of choice for SD.Next
|
||||
- **Internal**
|
||||
- `python==3.13` full support
|
||||
- `python==3.14` initial support
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 0861ae00f2ad057a914ca82e45fe6635dde7417e
|
||||
Subproject commit 9d584a1bdc0c2aca614aa0e1e34e4374c3aa779d
|
||||
|
|
@ -706,7 +706,6 @@ def install_openvino():
|
|||
|
||||
if not (args.skip_all or args.skip_requirements):
|
||||
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
|
||||
install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
|
||||
ts('openvino', t_start)
|
||||
return torch_command
|
||||
|
||||
|
|
@ -730,10 +729,6 @@ def install_torch_addons():
|
|||
install('DeepCache')
|
||||
if opts.get('cuda_compile_backend', '') == 'olive-ai':
|
||||
install('olive-ai')
|
||||
if len(opts.get('optimum_quanto_weights', [])):
|
||||
install('optimum-quanto==0.2.7', 'optimum-quanto')
|
||||
if len(opts.get('torchao_quantization', [])):
|
||||
install('torchao==0.10.0', 'torchao')
|
||||
if opts.get('samples_format', 'jpg') == 'jxl' or opts.get('grid_format', 'jpg') == 'jxl':
|
||||
install('pillow-jxl-plugin==1.3.7', 'pillow-jxl-plugin')
|
||||
if not args.experimental:
|
||||
|
|
@ -1189,9 +1184,6 @@ def install_optional():
|
|||
install('open-clip-torch', no_deps=True, quiet=True)
|
||||
install('git+https://github.com/tencent-ailab/IP-Adapter.git', 'ip_adapter', ignore=True, quiet=True)
|
||||
# install('git+https://github.com/openai/CLIP.git', 'clip', quiet=True, no_build_isolation=True)
|
||||
# install('torchao==0.10.0', ignore=True, quiet=True)
|
||||
# install('bitsandbytes==0.47.0', ignore=True, quiet=True)
|
||||
# install('optimum-quanto==0.2.7', ignore=True, quiet=True)
|
||||
ts('optional', t_start)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -382,22 +382,6 @@ class ControlNet():
|
|||
self.model = sdnq_quantize_model(self.model)
|
||||
except Exception as e:
|
||||
log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}')
|
||||
elif "Control" in opts.optimum_quanto_weights:
|
||||
try:
|
||||
log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
|
||||
model_quant.load_quanto('Load model: type=Control')
|
||||
from modules.model_quant import optimum_quanto_model
|
||||
self.model = optimum_quanto_model(self.model)
|
||||
except Exception as e:
|
||||
log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}')
|
||||
elif "Control" in opts.torchao_quantization:
|
||||
try:
|
||||
log.debug(f'Control {what} model Torch AO: id="{model_id}"')
|
||||
model_quant.load_torchao('Load model: type=Control')
|
||||
from modules.model_quant import torchao_quantization
|
||||
self.model = torchao_quantization(self.model)
|
||||
except Exception as e:
|
||||
log.error(f'Control {what} model Torch AO: id="{model_id}" {e}')
|
||||
if self.device is not None:
|
||||
sd_models.move_model(self.model, self.device)
|
||||
if "Control" in opts.cuda_compile:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import nncf
|
||||
|
||||
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
|
||||
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
|
||||
from openvino.frontend import FrontEndManager
|
||||
from openvino import Core, Type, PartialShape, serialize
|
||||
from openvino.properties import hint as ov_hints
|
||||
from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module
|
||||
from openvino import Core, Type, PartialShape, serialize # pylint: disable=no-name-in-module
|
||||
from openvino.properties import hint as ov_hints # pylint: disable=no-name-in-module
|
||||
|
||||
from torch._dynamo.backends.common import fake_tensor_unsupported
|
||||
from torch._dynamo.backends.registry import register_backend
|
||||
|
|
@ -38,6 +36,7 @@ except Exception:
|
|||
|
||||
try:
|
||||
# silence the pytorch version warning
|
||||
import nncf
|
||||
nncf.common.logging.logger.warn_bkc_version_mismatch = lambda *args, **kwargs: None
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -215,8 +214,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
|
|||
core = Core()
|
||||
|
||||
device = get_device()
|
||||
global dont_use_4bit_nncf
|
||||
global dont_use_nncf
|
||||
global dont_use_quant
|
||||
|
||||
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
|
||||
|
|
@ -259,26 +256,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
|
|||
om.inputs[idx-idx_minus].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
|
||||
om.validate_nodes_and_infer_types()
|
||||
|
||||
if shared.opts.nncf_quantize and not dont_use_quant:
|
||||
new_inputs = []
|
||||
for idx, _ in enumerate(example_inputs):
|
||||
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
|
||||
new_inputs = [new_inputs]
|
||||
if shared.opts.nncf_quantize_mode == "INT8":
|
||||
om = nncf.quantize(om, nncf.Dataset(new_inputs))
|
||||
else:
|
||||
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
|
||||
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
|
||||
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
|
||||
|
||||
if shared.opts.nncf_compress_weights and not dont_use_nncf:
|
||||
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
|
||||
om = nncf.compress_weights(om)
|
||||
else:
|
||||
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
|
||||
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
|
||||
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
|
||||
|
||||
hints = {}
|
||||
if shared.opts.openvino_accuracy == "performance":
|
||||
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
|
||||
|
|
@ -287,9 +264,7 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
|
|||
if model_hash_str is not None:
|
||||
hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob'
|
||||
core.set_property(hints)
|
||||
dont_use_nncf = False
|
||||
dont_use_quant = False
|
||||
dont_use_4bit_nncf = False
|
||||
|
||||
compiled_model = core.compile_model(om, device)
|
||||
return compiled_model
|
||||
|
|
@ -299,8 +274,6 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
|
|||
core = Core()
|
||||
om = core.read_model(cached_model_path + ".xml")
|
||||
|
||||
global dont_use_4bit_nncf
|
||||
global dont_use_nncf
|
||||
global dont_use_quant
|
||||
|
||||
for idx, input_data in enumerate(example_inputs):
|
||||
|
|
@ -308,35 +281,13 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
|
|||
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
|
||||
om.validate_nodes_and_infer_types()
|
||||
|
||||
if shared.opts.nncf_quantize and not dont_use_quant:
|
||||
new_inputs = []
|
||||
for idx, _ in enumerate(example_inputs):
|
||||
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
|
||||
new_inputs = [new_inputs]
|
||||
if shared.opts.nncf_quantize_mode == "INT8":
|
||||
om = nncf.quantize(om, nncf.Dataset(new_inputs))
|
||||
else:
|
||||
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
|
||||
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
|
||||
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
|
||||
|
||||
if shared.opts.nncf_compress_weights and not dont_use_nncf:
|
||||
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
|
||||
om = nncf.compress_weights(om)
|
||||
else:
|
||||
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
|
||||
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
|
||||
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
|
||||
|
||||
hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}
|
||||
if shared.opts.openvino_accuracy == "performance":
|
||||
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
|
||||
elif shared.opts.openvino_accuracy == "accuracy":
|
||||
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
|
||||
core.set_property(hints)
|
||||
dont_use_nncf = False
|
||||
dont_use_quant = False
|
||||
dont_use_4bit_nncf = False
|
||||
|
||||
compiled_model = core.compile_model(om, get_device())
|
||||
return compiled_model
|
||||
|
|
@ -462,13 +413,9 @@ def get_subgraph_type(tensor):
|
|||
|
||||
@fake_tensor_unsupported
|
||||
def openvino_fx(subgraph, example_inputs, options=None):
|
||||
global dont_use_4bit_nncf
|
||||
global dont_use_nncf
|
||||
global dont_use_quant
|
||||
global subgraph_type
|
||||
|
||||
dont_use_4bit_nncf = False
|
||||
dont_use_nncf = False
|
||||
dont_use_quant = False
|
||||
dont_use_faketensors = False
|
||||
executor_parameters = None
|
||||
|
|
@ -484,9 +431,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
|
|||
subgraph_type[2] is torch.nn.modules.normalization.GroupNorm and
|
||||
subgraph_type[3] is torch.nn.modules.activation.SiLU):
|
||||
|
||||
dont_use_4bit_nncf = True
|
||||
dont_use_nncf = bool("VAE" not in shared.opts.nncf_compress_weights)
|
||||
dont_use_quant = bool("VAE" not in shared.opts.nncf_quantize)
|
||||
pass
|
||||
|
||||
# SD 1.5 / SDXL Text Encoder
|
||||
elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and
|
||||
|
|
@ -495,8 +440,6 @@ def openvino_fx(subgraph, example_inputs, options=None):
|
|||
subgraph_type[3] is torch.nn.modules.linear.Linear):
|
||||
|
||||
dont_use_faketensors = True
|
||||
dont_use_nncf = bool("TE" not in shared.opts.nncf_compress_weights)
|
||||
dont_use_quant = bool("TE" not in shared.opts.nncf_quantize)
|
||||
|
||||
# Create a hash to be used for caching
|
||||
shared.compiled_model_state.model_hash_str = ""
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
import diffusers
|
||||
import transformers
|
||||
from installer import installed, install, setup_logging
|
||||
from installer import install
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
|
|
@ -51,70 +49,6 @@ def dont_quant():
|
|||
return False
|
||||
|
||||
|
||||
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
from modules import shared, devices
|
||||
if allow and (module == 'any' or module in shared.opts.bnb_quantization):
|
||||
load_bnb()
|
||||
if bnb is None:
|
||||
return kwargs
|
||||
bnb_config = diffusers.BitsAndBytesConfig(
|
||||
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
|
||||
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
|
||||
bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
|
||||
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
|
||||
bnb_4bit_compute_dtype=devices.dtype,
|
||||
llm_int8_skip_modules=modules_to_not_convert,
|
||||
)
|
||||
log.debug(f'Quantization: module={module} type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
|
||||
if kwargs is None:
|
||||
return bnb_config
|
||||
else:
|
||||
kwargs['quantization_config'] = bnb_config
|
||||
return kwargs
|
||||
return kwargs
|
||||
|
||||
|
||||
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
from modules import shared
|
||||
if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
|
||||
torchao = load_torchao()
|
||||
if torchao is None:
|
||||
return kwargs
|
||||
if module in {'TE', 'LLM'}:
|
||||
ao_config = transformers.TorchAoConfig(quant_type=shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
|
||||
else:
|
||||
ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
|
||||
log.debug(f'Quantization: module={module} type=torchao dtype={shared.opts.torchao_quantization_type}')
|
||||
if kwargs is None:
|
||||
return ao_config
|
||||
else:
|
||||
kwargs['quantization_config'] = ao_config
|
||||
return kwargs
|
||||
return kwargs
|
||||
|
||||
|
||||
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
from modules import shared
|
||||
if allow and (module == 'any' or module in shared.opts.quanto_quantization):
|
||||
load_quanto(silent=True)
|
||||
if optimum_quanto is None:
|
||||
return kwargs
|
||||
if module in {'TE', 'LLM'}:
|
||||
quanto_config = transformers.QuantoConfig(weights=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
|
||||
quanto_config.weights_dtype = quanto_config.weights
|
||||
else:
|
||||
quanto_config = diffusers.QuantoConfig(weights_dtype=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
|
||||
quanto_config.activations = None # patch so it works with transformers
|
||||
quanto_config.weights = quanto_config.weights_dtype
|
||||
log.debug(f'Quantization: module={module} type=quanto dtype={shared.opts.quanto_quantization_type}')
|
||||
if kwargs is None:
|
||||
return quanto_config
|
||||
else:
|
||||
kwargs['quantization_config'] = quanto_config
|
||||
return kwargs
|
||||
return kwargs
|
||||
|
||||
|
||||
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
from modules import shared
|
||||
if allow and (module == 'any' or module in shared.opts.trt_quantization):
|
||||
|
|
@ -249,7 +183,7 @@ def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model',
|
|||
|
||||
def check_quant(module: str = ''):
|
||||
from modules import shared
|
||||
if module in shared.opts.sdnq_quantize_weights or module in shared.opts.bnb_quantization or module in shared.opts.torchao_quantization or module in shared.opts.quanto_quantization:
|
||||
if module in shared.opts.sdnq_quantize_weights:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
@ -286,21 +220,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
|
|||
if debug:
|
||||
log.trace(f'Quantization: type=sdnq config={kwargs.get("quantization_config", None)}')
|
||||
return kwargs
|
||||
kwargs = create_bnb_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
|
||||
if kwargs is not None and 'quantization_config' in kwargs:
|
||||
if debug:
|
||||
log.trace(f'Quantization: type=bnb config={kwargs.get("quantization_config", None)}')
|
||||
return kwargs
|
||||
kwargs = create_quanto_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
|
||||
if kwargs is not None and 'quantization_config' in kwargs:
|
||||
if debug:
|
||||
log.trace(f'Quantization: type=quanto config={kwargs.get("quantization_config", None)}')
|
||||
return kwargs
|
||||
kwargs = create_ao_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
|
||||
if kwargs is not None and 'quantization_config' in kwargs:
|
||||
if debug:
|
||||
log.trace(f'Quantization: type=torchao config={kwargs.get("quantization_config", None)}')
|
||||
return kwargs
|
||||
kwargs = create_trt_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
|
||||
if kwargs is not None and 'quantization_config' in kwargs:
|
||||
if debug:
|
||||
|
|
@ -309,88 +228,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
|
|||
return kwargs
|
||||
|
||||
|
||||
def load_torchao(msg='', silent=False):
|
||||
global ao # pylint: disable=global-statement
|
||||
if ao is not None:
|
||||
return ao
|
||||
if not installed('torchao'):
|
||||
install('torchao==0.10.0', quiet=True)
|
||||
log.warning('Quantization: torchao installed please restart')
|
||||
try:
|
||||
import torchao
|
||||
ao = torchao
|
||||
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||
log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access
|
||||
from diffusers.utils import import_utils
|
||||
import_utils.is_torchao_available = lambda: True
|
||||
import_utils._torchao_available = True # pylint: disable=protected-access
|
||||
return ao
|
||||
except Exception as e:
|
||||
if len(msg) > 0:
|
||||
log.error(f"{msg} failed to import torchao: {e}")
|
||||
ao = None
|
||||
if not silent:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def load_bnb(msg='', silent=False):
|
||||
from modules import devices
|
||||
global bnb # pylint: disable=global-statement
|
||||
if bnb is not None:
|
||||
return bnb
|
||||
if not installed('bitsandbytes'):
|
||||
if devices.backend == 'cuda':
|
||||
# forcing a version will uninstall the multi-backend-refactor branch of bnb
|
||||
install('bitsandbytes==0.47.0', quiet=True)
|
||||
log.warning('Quantization: bitsandbytes installed please restart')
|
||||
try:
|
||||
import bitsandbytes
|
||||
bnb = bitsandbytes
|
||||
from diffusers.utils import import_utils
|
||||
import_utils._bitsandbytes_available = True # pylint: disable=protected-access
|
||||
import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
|
||||
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||
log.debug(f'Quantization: type=bitsandbytes version={bnb.__version__} fn={fn}') # pylint: disable=protected-access
|
||||
return bnb
|
||||
except Exception as e:
|
||||
if len(msg) > 0:
|
||||
log.error(f"{msg} failed to import bitsandbytes: {e}")
|
||||
bnb = None
|
||||
if not silent:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def load_quanto(msg='', silent=False):
|
||||
global optimum_quanto # pylint: disable=global-statement
|
||||
if optimum_quanto is not None:
|
||||
return optimum_quanto
|
||||
if not installed('optimum-quanto'):
|
||||
install('optimum-quanto==0.2.7', quiet=True)
|
||||
log.warning('Quantization: optimum-quanto installed please restart')
|
||||
try:
|
||||
from optimum import quanto # pylint: disable=no-name-in-module
|
||||
# disable device specific tensors because the model can't be moved between cpu and gpu with them
|
||||
quanto.tensor.weights.qbits.WeightQBitsTensor.create = lambda *args, **kwargs: quanto.tensor.weights.qbits.WeightQBitsTensor(*args, **kwargs)
|
||||
optimum_quanto = quanto
|
||||
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||
log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access
|
||||
from diffusers.utils import import_utils
|
||||
import_utils.is_optimum_quanto_available = lambda: True
|
||||
import_utils._optimum_quanto_available = True # pylint: disable=protected-access
|
||||
import_utils._optimum_quanto_version = quanto.__version__ # pylint: disable=protected-access
|
||||
import_utils._replace_with_quanto_layers = diffusers.quantizers.quanto.utils._replace_with_quanto_layers # pylint: disable=protected-access
|
||||
return optimum_quanto
|
||||
except Exception as e:
|
||||
if len(msg) > 0:
|
||||
log.error(f"{msg} failed to import optimum.quanto: {e}")
|
||||
optimum_quanto = None
|
||||
if not silent:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def load_trt(msg='', silent=False):
|
||||
global trt # pylint: disable=global-statement
|
||||
if trt is not None:
|
||||
|
|
@ -642,138 +479,6 @@ def sdnq_quantize_weights(sd_model):
|
|||
return sd_model
|
||||
|
||||
|
||||
def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activations=None):
|
||||
from modules import devices, shared
|
||||
quanto = load_quanto('Quantize model: type=Optimum Quanto')
|
||||
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
|
||||
if model.__class__.__name__ in {"FluxTransformer2DModel", "ChromaTransformer2DModel"}: # LayerNorm is not supported
|
||||
exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
|
||||
if model.__class__.__name__ == "ChromaTransformer2DModel":
|
||||
# we ignore the distilled guidance layer because it degrades quality too much
|
||||
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
|
||||
exclude_list.append("distilled_guidance_layer.*")
|
||||
elif model.__class__.__name__ == "QwenImageTransformer2DModel":
|
||||
exclude_list = ["transformer_blocks.0.img_mod.1.weight", "time_text_embed", "img_in", "txt_in", "proj_out", "norm_out", "pos_embed"]
|
||||
else:
|
||||
exclude_list = None
|
||||
weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)
|
||||
if activations is not None:
|
||||
activations = getattr(quanto, activations) if activations != 'none' else None
|
||||
elif shared.opts.optimum_quanto_activations_type != 'none':
|
||||
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
|
||||
else:
|
||||
activations = None
|
||||
model.eval()
|
||||
backup_embeddings = None
|
||||
if hasattr(model, "get_input_embeddings"):
|
||||
backup_embeddings = copy.deepcopy(model.get_input_embeddings())
|
||||
quanto.quantize(model, weights=weights, activations=activations, exclude=exclude_list)
|
||||
quanto.freeze(model)
|
||||
if hasattr(model, "set_input_embeddings") and backup_embeddings is not None:
|
||||
model.set_input_embeddings(backup_embeddings)
|
||||
if op is not None and shared.opts.optimum_quanto_shuffle_weights:
|
||||
if quant_last_model_name is not None:
|
||||
if "." in quant_last_model_name:
|
||||
last_model_names = quant_last_model_name.split(".")
|
||||
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
|
||||
else:
|
||||
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
|
||||
devices.torch_gc(force=True, reason='quanto')
|
||||
if shared.cmd_opts.medvram or shared.cmd_opts.lowvram or shared.opts.diffusers_offload_mode != "none":
|
||||
quant_last_model_name = op
|
||||
quant_last_model_device = model.device
|
||||
else:
|
||||
quant_last_model_name = None
|
||||
quant_last_model_device = None
|
||||
model.to(devices.device)
|
||||
devices.torch_gc(force=True, reason='quanto')
|
||||
return model
|
||||
|
||||
|
||||
def optimum_quanto_weights(sd_model):
|
||||
try:
|
||||
t0 = time.time()
|
||||
from modules import shared, devices, sd_models
|
||||
if shared.opts.diffusers_offload_mode in {"balanced", "sequential"}:
|
||||
log.warning(f"Quantization: type=Optimum.quanto offload={shared.opts.diffusers_offload_mode} not compatible")
|
||||
return sd_model
|
||||
log.info(f"Quantization: type=Optimum.quanto: modules={shared.opts.optimum_quanto_weights}")
|
||||
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
|
||||
quanto = load_quanto()
|
||||
|
||||
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_model, shared.opts.optimum_quanto_weights, op="optimum-quanto")
|
||||
if quant_last_model_name is not None:
|
||||
if "." in quant_last_model_name:
|
||||
last_model_names = quant_last_model_name.split(".")
|
||||
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
|
||||
else:
|
||||
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
|
||||
devices.torch_gc(force=True, reason='quanto')
|
||||
quant_last_model_name = None
|
||||
quant_last_model_device = None
|
||||
|
||||
if shared.opts.optimum_quanto_activations_type != 'none':
|
||||
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
|
||||
else:
|
||||
activations = None
|
||||
|
||||
if activations is not None:
|
||||
def optimum_quanto_freeze(model, op=None, sd_model=None): # pylint: disable=unused-argument
|
||||
quanto.freeze(model)
|
||||
return model
|
||||
if shared.opts.diffusers_offload_mode == "model":
|
||||
sd_model.enable_model_cpu_offload(device=devices.device)
|
||||
if hasattr(sd_model, "encode_prompt"):
|
||||
original_encode_prompt = sd_model.encode_prompt
|
||||
def encode_prompt(*args, **kwargs):
|
||||
embeds = original_encode_prompt(*args, **kwargs)
|
||||
sd_model.maybe_free_model_hooks() # Diffusers keeps the TE on VRAM
|
||||
return embeds
|
||||
sd_model.encode_prompt = encode_prompt
|
||||
else:
|
||||
sd_models.move_model(sd_model, devices.device)
|
||||
with quanto.Calibration(momentum=0.9):
|
||||
sd_model(prompt="dummy prompt", num_inference_steps=10)
|
||||
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_freeze, shared.opts.optimum_quanto_weights, op="optimum-quanto-freeze")
|
||||
if shared.opts.diffusers_offload_mode == "model":
|
||||
sd_models.disable_offload(sd_model)
|
||||
sd_models.move_model(sd_model, devices.cpu)
|
||||
if hasattr(sd_model, "encode_prompt"):
|
||||
sd_model.encode_prompt = original_encode_prompt
|
||||
devices.torch_gc(force=True, reason='quanto')
|
||||
|
||||
t1 = time.time()
|
||||
log.info(f"Quantization: type=Optimum.quanto time={t1-t0:.2f}")
|
||||
except Exception as e:
|
||||
log.warning(f"Quantization: type=Optimum.quanto {e}")
|
||||
return sd_model
|
||||
|
||||
|
||||
def torchao_quantization(sd_model):
|
||||
from modules import shared, devices, sd_models
|
||||
torchao = load_torchao()
|
||||
q = torchao.quantization
|
||||
|
||||
fn = getattr(q, shared.opts.torchao_quantization_type, None)
|
||||
if fn is None:
|
||||
log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported")
|
||||
return sd_model
|
||||
def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
|
||||
q.quantize_(model, fn(), device=devices.device)
|
||||
return model
|
||||
|
||||
log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}")
|
||||
try:
|
||||
t0 = time.time()
|
||||
sd_models.apply_function_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao")
|
||||
t1 = time.time()
|
||||
log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}")
|
||||
except Exception as e:
|
||||
log.error(f"Quantization: type=TorchAO {e}")
|
||||
setup_logging() # torchao uses dynamo which messes with logging so reset is needed
|
||||
return sd_model
|
||||
|
||||
|
||||
def get_dit_args(load_config:dict=None, module:str=None, device_map:bool=False, allow_quant:bool=True, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
from modules import shared, devices
|
||||
config = {} if load_config is None else load_config.copy()
|
||||
|
|
@ -810,12 +515,6 @@ def do_post_load_quant(sd_model, allow=True):
|
|||
if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
|
||||
log.debug('Load model: post_quant=sdnq')
|
||||
sd_model = sdnq_quantize_weights(sd_model)
|
||||
if len(shared.opts.optimum_quanto_weights) > 0:
|
||||
log.debug('Load model: post_quant=quanto')
|
||||
sd_model = optimum_quanto_weights(sd_model)
|
||||
if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
|
||||
log.debug('Load model: post_quant=torchao')
|
||||
sd_model = torchao_quantization(sd_model)
|
||||
if shared.opts.layerwise_quantization:
|
||||
log.debug('Load model: post_quant=layerwise')
|
||||
apply_layerwise(sd_model)
|
||||
|
|
|
|||
|
|
@ -62,16 +62,6 @@ def load_t5(name=None, cache_dir=None):
|
|||
elif 'fp16' in name.lower():
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'fp4' in name.lower():
|
||||
model_quant.load_bnb('Load model: type=T5')
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'fp8' in name.lower():
|
||||
model_quant.load_bnb('Load model: type=T5')
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'int8' in name.lower():
|
||||
from modules.model_quant import create_sdnq_config
|
||||
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
|
||||
|
|
@ -84,18 +74,6 @@ def load_t5(name=None, cache_dir=None):
|
|||
if quantization_config is not None:
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'qint4' in name.lower():
|
||||
model_quant.load_quanto('Load model: type=T5')
|
||||
quantization_config = transformers.QuantoConfig(weights='int4')
|
||||
if quantization_config is not None:
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'qint8' in name.lower():
|
||||
model_quant.load_quanto('Load model: type=T5')
|
||||
quantization_config = transformers.QuantoConfig(weights='int8')
|
||||
if quantization_config is not None:
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif '/' in name:
|
||||
log.debug(f'Load model: type=T5 repo={name}')
|
||||
quant_config = model_quant.create_config(module='TE')
|
||||
|
|
|
|||
|
|
@ -166,26 +166,6 @@ def create_settings(cmd_opts):
|
|||
"nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
|
||||
"nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),
|
||||
|
||||
"bnb_quantization_sep": OptionInfo("<h2>BitsAndBytes</h2>", "", gr.HTML),
|
||||
"bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "VAE"]}),
|
||||
"bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ["nf4", "fp8", "fp4"]}),
|
||||
"bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]}),
|
||||
|
||||
"quanto_quantization_sep": OptionInfo("<h2>Optimum Quanto</h2>", "", gr.HTML),
|
||||
"quanto_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM"]}),
|
||||
"quanto_quantization_type": OptionInfo("int8", "Quantization weights type", gr.Dropdown, {"choices": ["float8", "int8", "int4", "int2"]}),
|
||||
|
||||
"optimum_quanto_sep": OptionInfo("<h2>Optimum Quanto: post-load</h2>", "", gr.HTML),
|
||||
"optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "Control", "VAE"]}),
|
||||
"optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ["qint8", "qfloat8_e4m3fn", "qfloat8_e5m2", "qint4", "qint2"]}),
|
||||
"optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ["none", "qint8", "qfloat8_e4m3fn", "qfloat8_e5m2"]}),
|
||||
"optimum_quanto_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
|
||||
|
||||
"torchao_sep": OptionInfo("<h2>TorchAO</h2>", "", gr.HTML),
|
||||
"torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "Control", "VAE"]}),
|
||||
"torchao_quantization_mode": OptionInfo("auto", "Quantization mode", gr.Dropdown, {"choices": ["auto", "pre", "post"]}),
|
||||
"torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ["int4_weight_only", "int8_dynamic_activation_int4_weight", "int8_weight_only", "int8_dynamic_activation_int8_weight", "float8_weight_only", "float8_dynamic_activation_float8_weight", "float8_static_activation_float8_weight"]}),
|
||||
|
||||
"layerwise_quantization_sep": OptionInfo("<h2>Layerwise Casting</h2>", "", gr.HTML),
|
||||
"layerwise_quantization": OptionInfo([], "Layerwise casting enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
|
||||
"layerwise_quantization_storage": OptionInfo("float8_e4m3fn", "Layerwise casting storage", gr.Dropdown, {"choices": ["float8_e4m3fn", "float8_e5m2"]}),
|
||||
|
|
@ -194,14 +174,6 @@ def create_settings(cmd_opts):
|
|||
"trt_quantization_sep": OptionInfo("<h2>TensorRT</h2>", "", gr.HTML),
|
||||
"trt_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model"]}),
|
||||
"trt_quantization_type": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": ["int8", "int4", "fp8", "nf4", "nvfp4"]}),
|
||||
|
||||
"nncf_compress_sep": OptionInfo("<h2>NNCF: Neural Network Compression Framework</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
|
||||
"nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
|
||||
"nncf_compress_weights_mode": OptionInfo("INT8_SYM", "Quantization type", gr.Dropdown, {"choices": ["INT8", "INT8_SYM", "FP8", "MXFP8", "INT4_ASYM", "INT4_SYM", "FP4", "MXFP4", "NF4"], "visible": cmd_opts.use_openvino}),
|
||||
"nncf_compress_weights_raito": OptionInfo(0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
|
||||
"nncf_compress_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": cmd_opts.use_openvino}),
|
||||
"nncf_quantize": OptionInfo([], "Static Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
|
||||
"nncf_quantize_mode": OptionInfo("INT8", "OpenVINO activations mode", gr.Dropdown, {"choices": ["INT8", "FP8_E4M3", "FP8_E5M2"], "visible": cmd_opts.use_openvino}),
|
||||
}))
|
||||
# --- VAE & Text Encoder ---
|
||||
options_templates.update(options_section(('vae_encoder', "Variational Auto Encoder"), {
|
||||
|
|
|
|||
|
|
@ -1,25 +0,0 @@
|
|||
import diffusers
|
||||
import transformers
|
||||
from modules import devices, model_quant
|
||||
|
||||
|
||||
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
|
||||
transformer = None
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
model_quant.load_bnb('Load model: type=FLUX')
|
||||
quant = model_quant.get_quant(repo_path)
|
||||
if quant == 'fp8':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'fp4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'nf4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
else:
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
|
||||
return transformer
|
||||
|
|
@ -1,361 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import diffusers
|
||||
import transformers
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def load_flux_quanto(checkpoint_info):
|
||||
transformer, text_encoder_2 = None, None
|
||||
quanto = model_quant.load_quanto('Load model: type=FLUX')
|
||||
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
|
||||
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
|
||||
dtype = state_dict['context_embedder.bias'].dtype
|
||||
with torch.device("meta"):
|
||||
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
|
||||
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
transformer_dtype = transformer.dtype
|
||||
if transformer_dtype != devices.dtype:
|
||||
try:
|
||||
transformer = transformer.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX Quanto:')
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
|
||||
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
|
||||
t5_config = transformers.T5Config(**json.load(f))
|
||||
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
|
||||
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
|
||||
with torch.device("meta"):
|
||||
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
|
||||
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
text_encoder_2_dtype = text_encoder_2.dtype
|
||||
if text_encoder_2_dtype != devices.dtype:
|
||||
try:
|
||||
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX Quanto:')
|
||||
|
||||
return transformer, text_encoder_2
|
||||
|
||||
|
||||
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
|
||||
transformer, text_encoder_2 = None, None
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
model_quant.load_bnb('Load model: type=FLUX')
|
||||
quant = model_quant.get_quant(repo_path)
|
||||
try:
|
||||
if quant == 'fp8':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'fp4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4')
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'nf4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4')
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
else:
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
|
||||
transformer, text_encoder_2 = None, None
|
||||
if debug:
|
||||
errors.display(e, 'FLUX:')
|
||||
return transformer, text_encoder_2
|
||||
|
||||
|
||||
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
|
||||
try:
|
||||
diffusers_load_config = {
|
||||
"torch_dtype": devices.dtype,
|
||||
"cache_dir": cache_dir,
|
||||
}
|
||||
if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
|
||||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_repo = None
|
||||
if 'flux.1-kontext' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
|
||||
elif 'flux.1-dev' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
|
||||
elif 'flux.1-schnell' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
|
||||
elif 'flux.1-fill' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
|
||||
elif 'flux.1-depth' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
|
||||
elif 'shuttle' in repo_id.lower():
|
||||
nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
|
||||
else:
|
||||
log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
|
||||
if nunchaku_repo is not None:
|
||||
log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
|
||||
kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype, cache_dir=cache_dir)
|
||||
kwargs['transformer'].quantization_method = 'SVDQuant'
|
||||
if shared.opts.nunchaku_attention:
|
||||
kwargs['transformer'].set_attention_impl("nunchaku-fp16")
|
||||
if 'transformer' not in kwargs and model_quant.check_quant('Model'):
|
||||
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
|
||||
kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
|
||||
if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
|
||||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
|
||||
log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
|
||||
kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
|
||||
kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
|
||||
if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
|
||||
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
|
||||
kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
|
||||
except Exception as e:
|
||||
log.error(f'Quantization: {e}')
|
||||
errors.display(e, 'Quantization:')
|
||||
return kwargs
|
||||
|
||||
|
||||
def load_transformer(file_path): # triggered by opts.sd_unet change
|
||||
if file_path is None or not os.path.exists(file_path):
|
||||
return None
|
||||
transformer = None
|
||||
quant = model_quant.get_quant(file_path)
|
||||
diffusers_load_config = {
|
||||
"torch_dtype": devices.dtype,
|
||||
"cache_dir": shared.opts.hfcache_dir,
|
||||
}
|
||||
if quant is not None and quant != 'none':
|
||||
log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
|
||||
if 'gguf' in file_path.lower():
|
||||
from modules import ggml
|
||||
_transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif quant == "fp8":
|
||||
_transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif quant in {'qint8', 'qint4'}:
|
||||
_transformer, _text_encoder_2 = load_flux_quanto(file_path)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif quant in {'fp8', 'fp4', 'nf4'}:
|
||||
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif 'nf4' in quant:
|
||||
from pipelines.flux.flux_nf4 import load_flux_nf4
|
||||
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
else:
|
||||
quant_args = model_quant.create_bnb_config({})
|
||||
if quant_args:
|
||||
log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
|
||||
from pipelines.flux.flux_nf4 import load_flux_nf4
|
||||
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
|
||||
if transformer is not None:
|
||||
return transformer
|
||||
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
|
||||
log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
|
||||
if transformer is None:
|
||||
log.error('Failed to load UNet model')
|
||||
shared.opts.sd_unet = 'Default'
|
||||
return transformer
|
||||
|
||||
|
||||
def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
sd_models.hf_auth_check(checkpoint_info)
|
||||
allow_post_quant = False
|
||||
|
||||
prequantized = model_quant.get_quant(checkpoint_info.path)
|
||||
log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
|
||||
debug(f'Load model: type=FLUX config={diffusers_load_config}')
|
||||
|
||||
transformer = None
|
||||
text_encoder_1 = None
|
||||
text_encoder_2 = None
|
||||
vae = None
|
||||
|
||||
# unload current model
|
||||
sd_models.unload_model_weights()
|
||||
shared.sd_model = None
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
|
||||
if shared.opts.teacache_enabled:
|
||||
from modules import teacache
|
||||
log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
|
||||
diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
|
||||
|
||||
# load overrides if any
|
||||
if shared.opts.sd_unet != 'Default':
|
||||
try:
|
||||
debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
|
||||
transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
|
||||
if transformer is None:
|
||||
shared.opts.sd_unet = 'Default'
|
||||
sd_unet.failed_unet.append(shared.opts.sd_unet)
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load UNet: {e}")
|
||||
shared.opts.sd_unet = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'FLUX UNet:')
|
||||
if shared.opts.sd_text_encoder != 'Default':
|
||||
try:
|
||||
debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
|
||||
from modules.model_te import load_t5, load_vit_l
|
||||
if 'vit-l' in shared.opts.sd_text_encoder.lower():
|
||||
text_encoder_1 = load_vit_l()
|
||||
else:
|
||||
text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load T5: {e}")
|
||||
shared.opts.sd_text_encoder = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'FLUX T5:')
|
||||
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
|
||||
try:
|
||||
debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
|
||||
from modules import sd_vae
|
||||
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
|
||||
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
|
||||
if os.path.exists(vae_file):
|
||||
vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
|
||||
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load VAE: {e}")
|
||||
shared.opts.sd_vae = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'FLUX VAE:')
|
||||
|
||||
# load quantized components if any
|
||||
if prequantized == 'nf4':
|
||||
try:
|
||||
from pipelines.flux.flux_nf4 import load_flux_nf4
|
||||
_transformer, _text_encoder = load_flux_nf4(checkpoint_info)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
if _text_encoder is not None:
|
||||
text_encoder_2 = _text_encoder
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX NF4:')
|
||||
if prequantized == 'qint8' or prequantized == 'qint4':
|
||||
try:
|
||||
_transformer, _text_encoder = load_flux_quanto(checkpoint_info)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
if _text_encoder is not None:
|
||||
text_encoder_2 = _text_encoder
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX Quanto:')
|
||||
|
||||
# initialize pipeline with pre-loaded components
|
||||
kwargs = {}
|
||||
if transformer is not None:
|
||||
kwargs['transformer'] = transformer
|
||||
sd_unet.loaded_unet = shared.opts.sd_unet
|
||||
if text_encoder_1 is not None:
|
||||
kwargs['text_encoder'] = text_encoder_1
|
||||
model_te.loaded_te = shared.opts.sd_text_encoder
|
||||
if text_encoder_2 is not None:
|
||||
kwargs['text_encoder_2'] = text_encoder_2
|
||||
model_te.loaded_te = shared.opts.sd_text_encoder
|
||||
if vae is not None:
|
||||
kwargs['vae'] = vae
|
||||
if repo_id == 'sayakpaul/flux.1-dev-nf4':
|
||||
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
|
||||
if 'Fill' in repo_id:
|
||||
cls = diffusers.FluxFillPipeline
|
||||
elif 'Canny' in repo_id:
|
||||
cls = diffusers.FluxControlPipeline
|
||||
elif 'Depth' in repo_id:
|
||||
cls = diffusers.FluxControlPipeline
|
||||
elif 'Kontext' in repo_id:
|
||||
cls = diffusers.FluxKontextPipeline
|
||||
from diffusers import pipelines
|
||||
pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
|
||||
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
|
||||
pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
|
||||
|
||||
else:
|
||||
cls = diffusers.FluxPipeline
|
||||
log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
|
||||
for c in kwargs:
|
||||
if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
|
||||
log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
|
||||
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
|
||||
try:
|
||||
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
|
||||
log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
|
||||
fn = checkpoint_info.path
|
||||
if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
|
||||
kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
|
||||
if fn.endswith('.safetensors') and os.path.isfile(fn):
|
||||
pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
else:
|
||||
pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
|
||||
|
||||
if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
|
||||
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
|
||||
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
|
||||
|
||||
# release memory
|
||||
transformer = None
|
||||
text_encoder_1 = None
|
||||
text_encoder_2 = None
|
||||
vae = None
|
||||
for k in kwargs.keys():
|
||||
kwargs[k] = None
|
||||
sd_hijack_te.init_hijack(pipe)
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
return pipe, allow_post_quant
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
"""
|
||||
Copied from: https://github.com/huggingface/diffusers/issues/9165
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.quantizers.quantizers_utils import get_module_from_name
|
||||
from huggingface_hub import hf_hub_download
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
|
||||
import safetensors.torch
|
||||
from modules import shared, devices, model_quant
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
|
||||
|
||||
|
||||
def _replace_with_bnb_linear(
|
||||
model,
|
||||
method="nf4",
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
bnb = model_quant.load_bnb('Load model: type=FLUX')
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, nn.Linear):
|
||||
with init_empty_weights():
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
|
||||
if method == "llm_int8":
|
||||
model._modules[name] = bnb.nn.Linear8bitLt( # pylint: disable=protected-access
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
has_been_replaced = True
|
||||
else:
|
||||
model._modules[name] = bnb.nn.Linear4bit( # pylint: disable=protected-access
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
compute_dtype=devices.dtype,
|
||||
compress_statistics=False,
|
||||
quant_type="nf4",
|
||||
)
|
||||
has_been_replaced = True
|
||||
# Store the module class in case we need to transpose the weight later
|
||||
model._modules[name].source_cls = type(module) # pylint: disable=protected-access
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False) # pylint: disable=protected-access
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_bnb_linear(
|
||||
module,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def check_quantized_param(
|
||||
model,
|
||||
param_name: str,
|
||||
) -> bool:
|
||||
bnb = model_quant.load_bnb('Load model: type=FLUX')
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): # pylint: disable=protected-access
|
||||
# Add here check for loaded components' dtypes once serialization is implemented
|
||||
return True
|
||||
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
|
||||
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
|
||||
# but it would wrongly use uninitialized weight there.
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def create_quantized_param(
|
||||
model,
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict=None,
|
||||
unexpected_keys=None,
|
||||
pre_quantized=False
|
||||
):
|
||||
bnb = model_quant.load_bnb('Load model: type=FLUX')
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if tensor_name not in module._parameters: # pylint: disable=protected-access
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
if tensor_name == "bias":
|
||||
if param_value is None:
|
||||
new_value = old_value.to(target_device)
|
||||
else:
|
||||
new_value = param_value.to(target_device)
|
||||
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
|
||||
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
|
||||
return
|
||||
|
||||
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): # pylint: disable=protected-access
|
||||
raise ValueError("this function only loads `Linear4bit components`")
|
||||
if (
|
||||
old_value.device == torch.device("meta")
|
||||
and target_device not in ["meta", torch.device("meta")]
|
||||
and param_value is None
|
||||
):
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
|
||||
|
||||
if pre_quantized:
|
||||
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
|
||||
raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
|
||||
quantized_stats = {}
|
||||
for k, v in state_dict.items():
|
||||
# `startswith` to counter for edge cases where `param_name`
|
||||
# substring can be present in multiple places in the `state_dict`
|
||||
if param_name + "." in k and k.startswith(param_name):
|
||||
quantized_stats[k] = v
|
||||
if unexpected_keys is not None and k in unexpected_keys:
|
||||
unexpected_keys.remove(k)
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=param_value,
|
||||
quantized_stats=quantized_stats,
|
||||
requires_grad=False,
|
||||
device=target_device,
|
||||
)
|
||||
else:
|
||||
new_value = param_value.to("cpu")
|
||||
kwargs = old_value.__dict__
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
|
||||
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
|
||||
|
||||
|
||||
def load_flux_nf4(checkpoint_info, prequantized: bool = True):
|
||||
transformer = None
|
||||
text_encoder_2 = None
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
if os.path.exists(repo_path) and os.path.isfile(repo_path):
|
||||
ckpt_path = repo_path
|
||||
elif os.path.exists(repo_path) and os.path.isdir(repo_path) and os.path.exists(os.path.join(repo_path, "diffusion_pytorch_model.safetensors")):
|
||||
ckpt_path = os.path.join(repo_path, "diffusion_pytorch_model.safetensors")
|
||||
else:
|
||||
ckpt_path = hf_hub_download(repo_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
|
||||
if 'sayakpaul' in repo_path:
|
||||
converted_state_dict = original_state_dict # already converted
|
||||
else:
|
||||
try:
|
||||
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX Failed to convert UNET: {e}")
|
||||
if debug:
|
||||
from modules import errors
|
||||
errors.display(e, 'FLUX convert:')
|
||||
converted_state_dict = original_state_dict
|
||||
|
||||
with init_empty_weights():
|
||||
from diffusers import FluxTransformer2DModel
|
||||
config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
|
||||
transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype)
|
||||
expected_state_dict_keys = list(transformer.state_dict().keys())
|
||||
|
||||
_replace_with_bnb_linear(transformer, "nf4")
|
||||
|
||||
try:
|
||||
for param_name, param in converted_state_dict.items():
|
||||
if param_name not in expected_state_dict_keys:
|
||||
continue
|
||||
is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
|
||||
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
|
||||
param = param.to(devices.dtype)
|
||||
if not check_quantized_param(transformer, param_name):
|
||||
set_module_tensor_to_device(transformer, param_name, device=0, value=param)
|
||||
else:
|
||||
create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
|
||||
except Exception as e:
|
||||
transformer, text_encoder_2 = None, None
|
||||
log.error(f"Load model: type=FLUX failed to load UNET: {e}")
|
||||
if debug:
|
||||
from modules import errors
|
||||
errors.display(e, 'FLUX:')
|
||||
|
||||
del original_state_dict
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
return transformer, text_encoder_2
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import diffusers
|
||||
import transformers
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared, errors, devices, sd_models, model_quant
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def load_flux_quanto(checkpoint_info):
|
||||
transformer, text_encoder_2 = None, None
|
||||
quanto = model_quant.load_quanto('Load model: type=FLUX')
|
||||
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
|
||||
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
|
||||
dtype = state_dict['context_embedder.bias'].dtype
|
||||
with torch.device("meta"):
|
||||
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
|
||||
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
transformer_dtype = transformer.dtype
|
||||
if transformer_dtype != devices.dtype:
|
||||
try:
|
||||
transformer = transformer.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX Quanto:')
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
|
||||
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
|
||||
t5_config = transformers.T5Config(**json.load(f))
|
||||
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
|
||||
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
|
||||
with torch.device("meta"):
|
||||
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
|
||||
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
text_encoder_2_dtype = text_encoder_2.dtype
|
||||
if text_encoder_2_dtype != devices.dtype:
|
||||
try:
|
||||
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
|
||||
except Exception as e:
|
||||
log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'FLUX Quanto:')
|
||||
|
||||
return transformer, text_encoder_2
|
||||
|
|
@ -41,18 +41,6 @@ def load_flux(checkpoint_info, diffusers_load_config=None):
|
|||
transformer = None
|
||||
text_encoder_2 = None
|
||||
|
||||
# handle prequantized models
|
||||
prequantized = model_quant.get_quant(checkpoint_info.path)
|
||||
if prequantized == 'nf4':
|
||||
from pipelines.flux.flux_nf4 import load_flux_nf4
|
||||
transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
|
||||
elif prequantized == 'qint8' or prequantized == 'qint4':
|
||||
from pipelines.flux.flux_quanto import load_flux_quanto
|
||||
transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
|
||||
elif prequantized == 'fp4' or prequantized == 'fp8':
|
||||
from pipelines.flux.flux_bnb import load_flux_bnb
|
||||
transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)
|
||||
|
||||
# handle transformer svdquant if available, t5 is handled inside load_text_encoder
|
||||
if transformer is None and model_quant.check_nunchaku('Model'):
|
||||
from pipelines.flux.flux_nunchaku import load_flux_nunchaku
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
|
||||
# 5X10U2V2: fixed geometry (5×10 stride-2), no SD conv matches — disabled
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
|
||||
|
|
@ -117,7 +117,7 @@ RDNA2: Dict[str, str] = {
|
|||
# FWD / FWD1X1: FP32/FP16 forward — enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
|
||||
# FWD11X11: requires 11×11 kernel — no SD match — disabled
|
||||
# FWD11X11: requires 11*11 kernel — no SD match — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
|
||||
# FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs — disabled
|
||||
|
|
|
|||
|
|
@ -230,4 +230,3 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
|
|||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4",
|
||||
]),
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue