Add native LoRA loading for Flux2/Klein models

Load Flux2/Klein LoRAs as native NetworkModuleLora objects, bypassing
diffusers PEFT. Handles kohya (lora_unet_), AI toolkit (diffusion_model.),
diffusers PEFT (transformer.), and bare BFL key formats with automatic
QKV splitting for double block fused attention weights.

Includes shape validation to reject architecture-mismatched LoRAs early.
Respects lora_force_diffusers setting to fall back to PEFT when needed.
pull/4708/head
CalamitousFelicitousness 2026-03-25 03:48:16 +00:00
parent 10942032a3
commit 18568db41c
3 changed files with 269 additions and 3 deletions

View File

@ -251,7 +251,8 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
infotext(p)
prompt(p)
if has_changed and len(include) == 0: # print only once
log.info(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} method={load_method} mode={"fuse" if shared.opts.lora_fuse_native else "backup"} te={te_multipliers} unet={unet_multipliers} time={l.timer.summary}')
actual_method = 'native' if any(len(n.modules) > 0 for n in l.loaded_networks) else load_method
log.info(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} method={actual_method} mode={"fuse" if shared.opts.lora_fuse_native else "backup"} te={te_multipliers} unet={unet_multipliers} time={l.timer.summary}')
def deactivate(self, p, force=False):
if len(lora_diffusers.diffuser_loaded) > 0 and (shared.opts.lora_force_reload or force):

View File

@ -260,6 +260,8 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
if shared.sd_model_type == 'f2':
from pipelines.flux import flux2_lora
net = flux2_lora.try_load_lokr(name, network_on_disk, lora_scale)
if net is None and not shared.opts.lora_force_diffusers:
net = flux2_lora.try_load_lora(name, network_on_disk, lora_scale)
if net is None:
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
elif lora_method == 'nunchaku':

View File

@ -1,17 +1,19 @@
"""Flux2/Klein-specific LoRA loading.
Handles:
- Bare BFL-format keys in state dicts (adds diffusion_model. prefix for converter)
- Kohya-format LoRA via native module loading (lora_unet_ prefix keys)
- LoKR adapters via native module loading (bypasses diffusers PEFT system)
- Bare BFL-format keys in state dicts (adds diffusion_model. prefix for converter)
Installed via apply_patch() during pipeline loading.
"""
import os
import time
import torch
from modules import shared, sd_models
from modules.logger import log
from modules.lora import network, network_lokr, lora_convert
from modules.lora import network, network_lokr, network_lora, lora_convert
from modules.lora import lora_common as l
@ -38,6 +40,267 @@ F2_QKV_MAP = {
}
# Kohya underscore suffix -> BFL dot suffix (last underscore becomes dot)
# Used to convert kohya key fragments to look up F2_DOUBLE_MAP / F2_QKV_MAP
KOHYA_SUFFIX_MAP = {
'img_attn_proj': 'img_attn.proj',
'txt_attn_proj': 'txt_attn.proj',
'img_attn_qkv': 'img_attn.qkv',
'txt_attn_qkv': 'txt_attn.qkv',
'img_mlp_0': 'img_mlp.0',
'img_mlp_2': 'img_mlp.2',
'txt_mlp_0': 'txt_mlp.0',
'txt_mlp_2': 'txt_mlp.2',
}
def try_load_lora(name, network_on_disk, lora_scale):
"""Try loading a Flux2/Klein LoRA as native modules.
Handles three key formats:
- Kohya: lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
- AI toolkit (BFL): diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight
- Diffusers PEFT: transformer.single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.weight
Returns a Network with native modules, or None to fall through to the diffusers path.
"""
t0 = time.time()
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
has_lora = any('.lora_down.' in k or '.lora_up.' in k or '.lora_A.' in k or '.lora_B.' in k for k in state_dict)
if not has_lora:
return None
is_f2_keys = any(
k.startswith(('lora_unet_single_blocks_', 'lora_unet_double_blocks_',
'diffusion_model.single_blocks.', 'diffusion_model.double_blocks.',
'transformer.single_transformer_blocks.', 'transformer.transformer_blocks.'))
for k in state_dict
)
if not is_f2_keys:
return None
net = load_lora_native(name, network_on_disk, state_dict)
if len(net.modules) == 0:
return None
log.debug(f'Network load: type=LoRA name="{name}" native modules={len(net.modules)} scale={lora_scale}')
l.timer.activate += time.time() - t0
return net
def _group_lora_keys(state_dict):
"""Group LoRA state dict keys into (targets, weights_dict) pairs.
Normalizes all three formats into a common structure. Weight keys are
normalized to lora_down.weight / lora_up.weight regardless of input naming.
Returns list of (targets, weights_dict) where targets come from BFL->diffusers mapping.
"""
# Detect format from first relevant key
sample = next((k for k in state_dict if '.lora_' in k), None)
if sample is None:
return []
if sample.startswith('lora_unet_'):
return _group_kohya(state_dict)
elif sample.startswith('diffusion_model.'):
return _group_bfl(state_dict)
elif sample.startswith('transformer.'):
return _group_peft(state_dict)
# Bare BFL keys (no prefix)
if any(k.startswith(p) for k in state_dict for p in BARE_FLUX_PREFIXES):
return _group_bfl(state_dict, prefix='')
return []
def _normalize_weight_key(suffix):
"""lora_A.weight -> lora_down.weight, lora_B.weight -> lora_up.weight"""
return suffix.replace('lora_A.', 'lora_down.').replace('lora_B.', 'lora_up.')
def _group_kohya(state_dict):
"""Group kohya-format keys (lora_unet_ prefix, underscored module names)."""
groups = {}
for key, weight in state_dict.items():
if not key.startswith('lora_unet_'):
continue
base, _, suffix = key.partition('.')
if not suffix:
continue
if base not in groups:
groups[base] = {}
groups[base][_normalize_weight_key(suffix)] = weight
results = []
for base, weights_dict in groups.items():
if 'lora_down.weight' not in weights_dict:
continue
stripped = base[len('lora_unet_'):]
targets = _kohya_key_to_targets(stripped)
if targets:
results.append((targets, weights_dict))
return results
def _group_bfl(state_dict, prefix='diffusion_model.'):
"""Group BFL/AI-toolkit-format keys (dot-separated module names)."""
groups = {}
for key, weight in state_dict.items():
if prefix and not key.startswith(prefix):
continue
stripped = key[len(prefix):]
# Split at lora boundary: double_blocks.0.img_attn.proj.lora_A.weight
for marker in ('.lora_A.', '.lora_B.', '.lora_down.', '.lora_up.', '.alpha', '.dora_scale'):
pos = stripped.find(marker)
if pos != -1:
base = stripped[:pos]
suffix = stripped[pos + 1:] if stripped[pos + 1:] else marker[1:] # handle bare .dora_scale / .alpha
break
else:
continue
if base not in groups:
groups[base] = {}
groups[base][_normalize_weight_key(suffix)] = weight
results = []
for base, weights_dict in groups.items():
if 'lora_down.weight' not in weights_dict:
continue
targets = _bfl_key_to_targets(base)
if targets:
results.append((targets, weights_dict))
return results
def _group_peft(state_dict):
"""Group diffusers PEFT-format keys (transformer. prefix, diffusers module names)."""
groups = {}
for key, weight in state_dict.items():
if not key.startswith('transformer.'):
continue
stripped = key[len('transformer.'):]
for marker in ('.lora_A.', '.lora_B.', '.lora_down.', '.lora_up.', '.alpha'):
pos = stripped.find(marker)
if pos != -1:
module_path = stripped[:pos]
suffix = stripped[pos + 1:]
break
else:
continue
if module_path not in groups:
groups[module_path] = {}
groups[module_path][_normalize_weight_key(suffix)] = weight
results = []
for module_path, weights_dict in groups.items():
if 'lora_down.weight' not in weights_dict:
continue
# Already in diffusers path format — direct target, no mapping needed
results.append(([(module_path, None, None)], weights_dict))
return results
def load_lora_native(name, network_on_disk, state_dict):
"""Load Flux2/Klein LoRA as native modules from any supported key format."""
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
lora_convert.assign_network_names_to_compvis_modules(sd_model)
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
for targets, weights_dict in _group_lora_keys(state_dict):
for module_path, chunk_index, num_chunks in targets:
network_key = "lora_transformer_" + module_path.replace(".", "_")
sd_module = sd_model.network_layer_mapping.get(network_key)
if sd_module is None:
continue
w = {}
if chunk_index is not None:
up = weights_dict['lora_up.weight']
chunks = torch.chunk(up, num_chunks, dim=0)
w['lora_up.weight'] = chunks[chunk_index].contiguous()
w['lora_down.weight'] = weights_dict['lora_down.weight']
else:
w['lora_up.weight'] = weights_dict['lora_up.weight']
w['lora_down.weight'] = weights_dict['lora_down.weight']
# Validate dimensions match the target module
if hasattr(sd_module, 'weight'):
if hasattr(sd_module, 'sdnq_dequantizer'):
mod_shape = sd_module.sdnq_dequantizer.original_shape
else:
mod_shape = sd_module.weight.shape
if w['lora_down.weight'].shape[1] != mod_shape[1] or w['lora_up.weight'].shape[0] != mod_shape[0]:
log.warning(f'Network load: type=LoRA shape mismatch: {network_key} lora={w["lora_down.weight"].shape[1]}x{w["lora_up.weight"].shape[0]} module={mod_shape[1]}x{mod_shape[0]}')
continue
if 'alpha' in weights_dict:
w['alpha'] = weights_dict['alpha']
if 'dora_scale' in weights_dict:
w['dora_scale'] = weights_dict['dora_scale']
nw = network.NetworkWeights(network_key=network_key, sd_key=network_key, w=w, sd_module=sd_module)
net.modules[network_key] = network_lora.NetworkModuleLora(net, nw)
return net
def _kohya_key_to_targets(stripped):
"""Map a stripped kohya key to (diffusers_module_path, chunk_index, num_chunks) targets.
Input examples: 'double_blocks_0_img_attn_proj', 'single_blocks_5_linear1'
"""
targets = []
if stripped.startswith('single_blocks_'):
rest = stripped[len('single_blocks_'):]
idx, _, suffix = rest.partition('_')
if suffix in F2_SINGLE_MAP:
targets.append((f'single_transformer_blocks.{idx}.{F2_SINGLE_MAP[suffix]}', None, None))
elif stripped.startswith('double_blocks_'):
rest = stripped[len('double_blocks_'):]
idx, _, kohya_suffix = rest.partition('_')
bfl_suffix = KOHYA_SUFFIX_MAP.get(kohya_suffix)
if bfl_suffix is None:
return targets
if bfl_suffix in F2_DOUBLE_MAP:
targets.append((f'transformer_blocks.{idx}.{F2_DOUBLE_MAP[bfl_suffix]}', None, None))
elif bfl_suffix in F2_QKV_MAP:
attn_prefix, proj_keys = F2_QKV_MAP[bfl_suffix]
for i, proj_key in enumerate(proj_keys):
targets.append((f'transformer_blocks.{idx}.{attn_prefix}.{proj_key}', i, len(proj_keys)))
return targets
def _bfl_key_to_targets(base):
"""Map a BFL dot-separated key to (diffusers_module_path, chunk_index, num_chunks) targets.
Input examples: 'double_blocks.0.img_attn.proj', 'single_blocks.5.linear1'
Same mapping as LoKR uses.
"""
targets = []
parts = base.split('.')
if len(parts) < 3:
return targets
block_type, block_idx, module_suffix = parts[0], parts[1], '.'.join(parts[2:])
if block_type == 'single_blocks' and module_suffix in F2_SINGLE_MAP:
path = f'single_transformer_blocks.{block_idx}.{F2_SINGLE_MAP[module_suffix]}'
targets.append((path, None, None))
elif block_type == 'double_blocks':
if module_suffix in F2_DOUBLE_MAP:
path = f'transformer_blocks.{block_idx}.{F2_DOUBLE_MAP[module_suffix]}'
targets.append((path, None, None))
elif module_suffix in F2_QKV_MAP:
attn_prefix, proj_keys = F2_QKV_MAP[module_suffix]
for i, proj_key in enumerate(proj_keys):
path = f'transformer_blocks.{block_idx}.{attn_prefix}.{proj_key}'
targets.append((path, i, len(proj_keys)))
return targets
def apply_lora_alphas(state_dict):
"""Bake kohya-format .alpha scaling into lora_down weights and remove alpha keys.