mirror of https://github.com/vladmandic/automatic
Merge pull request #4700 from vladmandic/feat/flux2-klein-lora
add Flux2/Klein LoRA support (branch: pull/4703/head)
commit
a6cae50a67
|
|
@ -257,7 +257,11 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
|
|||
shared.compiled_model_state.lora_model.append(f"{name}:{lora_scale}")
|
||||
lora_method = lora_overrides.get_method(shorthash)
|
||||
if lora_method == 'diffusers':
|
||||
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
|
||||
if shared.sd_model_type == 'f2':
|
||||
from pipelines.flux import flux2_lora
|
||||
net = flux2_lora.try_load_lokr(name, network_on_disk, lora_scale)
|
||||
if net is None:
|
||||
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
|
||||
elif lora_method == 'nunchaku':
|
||||
pass # handled directly from extra_networks_lora.load_nunchaku
|
||||
else:
|
||||
|
|
@ -272,7 +276,11 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
|
|||
continue
|
||||
if net is None:
|
||||
failed_to_load_networks.append(name)
|
||||
log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} not found')
|
||||
lora_ver = network_on_disk.sd_version if network_on_disk is not None else None
|
||||
if lora_ver in ('f1', '') and shared.sd_model_type == 'f2':
|
||||
log.error(f'Network load: type=LoRA name="{name}" incompatible: Flux1 LoRA cannot be used with Flux2/Klein')
|
||||
else:
|
||||
log.error(f'Network load: type=LoRA name="{name}" detected={lora_ver} not found')
|
||||
continue
|
||||
if hasattr(sd_model, 'embedding_db'):
|
||||
sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
|
||||
|
|
@ -309,6 +317,12 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
|
|||
errors.display(e, 'LoRA')
|
||||
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, force=True, silent=True) # some layers may end up on cpu without hook
|
||||
|
||||
# Activate native modules loaded via diffusers path (e.g., LoKR on Flux2)
|
||||
native_nets = [net for net in l.loaded_networks if len(net.modules) > 0]
|
||||
if native_nets:
|
||||
from modules.lora import networks
|
||||
networks.network_activate()
|
||||
|
||||
if len(l.loaded_networks) > 0 and l.debug:
|
||||
log.debug(f'Network load: type=LoRA loaded={[n.name for n in l.loaded_networks]} cache={list(lora_cache)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
|
||||
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ class NetworkOnDisk:
|
|||
return 'sc'
|
||||
if base.startswith("sd3"):
|
||||
return 'sd3'
|
||||
if base.startswith("flux2") or "klein" in base:
|
||||
return 'f2'
|
||||
if base.startswith("flux"):
|
||||
return 'f1'
|
||||
if base.startswith("hunyuan_video"):
|
||||
|
|
@ -75,6 +77,8 @@ class NetworkOnDisk:
|
|||
return 'xl'
|
||||
if arch.startswith("stable-cascade"):
|
||||
return 'sc'
|
||||
if arch.startswith("flux2") or "klein" in arch:
|
||||
return 'f2'
|
||||
if arch.startswith("flux"):
|
||||
return 'f1'
|
||||
if arch.startswith("hunyuan-video"):
|
||||
|
|
@ -86,6 +90,8 @@ class NetworkOnDisk:
|
|||
return 'sd1'
|
||||
if str(self.metadata.get('ss_v2', "")) == "True":
|
||||
return 'sd2'
|
||||
if 'klein' in self.name.lower() or 'klein' in self.fullname.lower():
|
||||
return 'f2'
|
||||
if 'flux' in self.name.lower():
|
||||
return 'f1'
|
||||
if 'xl' in self.name.lower():
|
||||
|
|
|
|||
|
|
@ -55,3 +55,40 @@ class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-metho
|
|||
output_shape = target.shape
|
||||
updown = make_kron(output_shape, w1, w2)
|
||||
return self.finalize_updown(updown, target, output_shape)
|
||||
|
||||
|
||||
class NetworkModuleLokrChunk(NetworkModuleLokr):
    """LoKR module that returns one chunk of the Kronecker product.

    Used when a LoKR adapter targets a fused weight (e.g., QKV) but the model
    has separate modules (Q, K, V). Computes kron(w1, w2) on-the-fly and
    returns only the designated chunk, keeping memory usage minimal.
    """
    def __init__(self, net, weights, chunk_index, num_chunks):
        super().__init__(net, weights)
        # which row-slice of the fused delta this module represents (0-based)
        self.chunk_index = chunk_index
        # total number of equal chunks the fused weight splits into (e.g. 3 for QKV)
        self.num_chunks = num_chunks

    def calc_updown(self, target):
        # Rebuild factor w1 on the target's device/dtype; it is stored either
        # whole (self.w1) or factored as w1a @ w1b (attributes set by the parent
        # NetworkModuleLokr, not fully visible here -- see network_lokr).
        if self.w1 is not None:
            w1 = self.w1.to(target.device, dtype=target.dtype)
        else:
            w1a = self.w1a.to(target.device, dtype=target.dtype)
            w1b = self.w1b.to(target.device, dtype=target.dtype)
            w1 = w1a @ w1b
        # w2: whole, plain two-factor product, or three-factor form combined
        # via lyco_helpers.make_weight_cp when a t2 core tensor is present
        if self.w2 is not None:
            w2 = self.w2.to(target.device, dtype=target.dtype)
        elif self.t2 is None:
            w2a = self.w2a.to(target.device, dtype=target.dtype)
            w2b = self.w2b.to(target.device, dtype=target.dtype)
            w2 = w2a @ w2b
        else:
            t2 = self.t2.to(target.device, dtype=target.dtype)
            w2a = self.w2a.to(target.device, dtype=target.dtype)
            w2b = self.w2b.to(target.device, dtype=target.dtype)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        # Materialize the full fused delta kron(w1, w2), then keep only this
        # module's chunk along dim 0; the full tensor is a temporary and is
        # freed after the chunk is taken.
        full_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        updown = make_kron(full_shape, w1, w2)
        updown = torch.chunk(updown, self.num_chunks, dim=0)[self.chunk_index]
        output_shape = list(updown.shape)
        return self.finalize_updown(updown, target, output_shape)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,217 @@
|
|||
"""Flux2/Klein-specific LoRA loading.
|
||||
|
||||
Handles:
|
||||
- Bare BFL-format keys in state dicts (adds diffusion_model. prefix for converter)
|
||||
- LoKR adapters via native module loading (bypasses diffusers PEFT system)
|
||||
|
||||
Installed via apply_patch() during pipeline loading.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from modules import shared, sd_models
|
||||
from modules.logger import log
|
||||
from modules.lora import network, network_lokr, lora_convert
|
||||
from modules.lora import lora_common as l
|
||||
|
||||
|
||||
# Key prefixes that identify a bare BFL-format Flux state dict (no
# 'diffusion_model.' / 'base_model.model.' wrapper).
# NOTE(review): the last entry ends with '_' while the others end with '.' --
# confirm this is intentional and matches the actual key naming.
BARE_FLUX_PREFIXES = ("single_blocks.", "double_blocks.", "img_in.", "txt_in.",
                      "final_layer.", "time_in.", "single_stream_modulation.",
                      "double_stream_modulation_")

# BFL -> diffusers module path mapping for Flux2/Klein
# single-block BFL sub-module name -> diffusers path inside single_transformer_blocks.{i}
F2_SINGLE_MAP = {
    'linear1': 'attn.to_qkv_mlp_proj',
    'linear2': 'attn.to_out',
}
# double-block BFL sub-module name -> diffusers path inside transformer_blocks.{i}
F2_DOUBLE_MAP = {
    'img_attn.proj': 'attn.to_out.0',
    'txt_attn.proj': 'attn.to_add_out',
    'img_mlp.0': 'ff.linear_in',
    'img_mlp.2': 'ff.linear_out',
    'txt_mlp.0': 'ff_context.linear_in',
    'txt_mlp.2': 'ff_context.linear_out',
}
# fused BFL QKV weight -> (attention prefix, per-projection diffusers modules);
# order of the projection list defines the chunk index used when splitting
F2_QKV_MAP = {
    'img_attn.qkv': ('attn', ['to_q', 'to_k', 'to_v']),
    'txt_attn.qkv': ('attn', ['add_q_proj', 'add_k_proj', 'add_v_proj']),
}
|
||||
|
||||
|
||||
def apply_lora_alphas(state_dict):
    """Bake kohya-format .alpha scaling into lora_down weights and remove alpha keys.

    Diffusers' Flux2 converter only handles lora_A/lora_B (or lora_down/lora_up) keys.
    Kohya-format LoRAs store per-layer alpha values as separate .alpha keys that the
    converter doesn't consume, causing a ValueError on leftover keys. This matches the
    approach used by _convert_kohya_flux_lora_to_diffusers for Flux 1.
    """
    suffix = '.alpha'
    to_bake = [key for key in state_dict if key.endswith(suffix)]
    if not to_bake:
        return state_dict
    for key in to_bake:
        stem = key[:-len(suffix)]
        down_name = f'{stem}.lora_down.weight'
        if down_name not in state_dict:
            continue  # orphan alpha: removed by the sweep below
        down = state_dict[down_name]
        # overall factor is alpha / rank (rank = first dim of the down weight)
        factor = state_dict.pop(key).item() / down.shape[0]
        # split the factor between down and up to keep both near unity (precision)
        down_factor, up_factor = factor, 1.0
        while down_factor * 2 < up_factor:
            down_factor *= 2
            up_factor /= 2
        state_dict[down_name] = down * down_factor
        up_name = f'{stem}.lora_up.weight'
        if up_name in state_dict:
            state_dict[up_name] = state_dict[up_name] * up_factor
    leftovers = [key for key in state_dict if key.endswith(suffix)]
    if leftovers:
        log.debug(f'Network load: type=LoRA stripped {len(leftovers)} orphaned alpha keys')
        for key in leftovers:
            del state_dict[key]
    return state_dict
|
||||
|
||||
|
||||
def preprocess_f2_keys(state_dict):
    """Add 'diffusion_model.' prefix to bare BFL-format keys so
    Flux2LoraLoaderMixin's format detection routes them to the converter."""
    wrapped_prefixes = ("diffusion_model.", "base_model.model.")
    # already in a recognized wrapped format -> nothing to do
    if any(key.startswith(wrapped_prefixes) for key in state_dict):
        return state_dict
    # not a bare BFL dict either -> leave untouched
    if not any(key.startswith(BARE_FLUX_PREFIXES) for key in state_dict):
        return state_dict
    log.debug('Network load: type=LoRA adding diffusion_model prefix for bare BFL-format keys')
    return {f"diffusion_model.{key}": value for key, value in state_dict.items()}
|
||||
|
||||
|
||||
def try_load_lokr(name, network_on_disk, lora_scale):
    """Try loading a Flux2/Klein LoRA as LoKR native modules.

    Returns a Network with native modules if the state dict contains LoKR keys,
    or None to fall through to the generic diffusers path.
    """
    started = time.time()
    weights = sd_models.read_state_dict(network_on_disk.filename, what='network')
    has_lokr = any('.lokr_w1' in key for key in weights)
    if not has_lokr:
        return None
    net = load_lokr_native(name, network_on_disk, weights)
    if not net.modules:
        log.error(f'Network load: type=LoKR name="{name}" no modules matched')
        return None
    log.debug(f'Network load: type=LoKR name="{name}" native modules={len(net.modules)} scale={lora_scale}')
    l.timer.activate += time.time() - started
    return net
|
||||
|
||||
|
||||
def load_lokr_native(name, network_on_disk, state_dict):
    """Load Flux2 LoKR as native modules applied at inference time.

    Stores only the compact LoKR factors (w1, w2) and computes kron(w1, w2)
    on-the-fly during weight application. For fused QKV modules in double
    blocks, NetworkModuleLokrChunk computes the full Kronecker product and
    returns only its designated Q/K/V chunk, then frees the temporary.
    """
    prefix = "diffusion_model."
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
    lora_convert.assign_network_names_to_compvis_modules(sd_model)
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    for key in list(state_dict.keys()):
        # anchor on the w1 key: exactly one per LoKR layer
        if not key.endswith('.lokr_w1'):
            continue
        if not key.startswith(prefix):
            continue

        base = key[len(prefix):].rsplit('.lokr_w1', 1)[0]
        # collect every factor tensor belonging to this layer
        lokr_weights = {}
        for suffix in ['lokr_w1', 'lokr_w2', 'lokr_w1_a', 'lokr_w1_b', 'lokr_w2_a', 'lokr_w2_b', 'lokr_t2', 'alpha']:
            full_key = f'{prefix}{base}.{suffix}'
            if full_key in state_dict:
                lokr_weights[suffix] = state_dict[full_key]

        parts = base.split('.')
        # fix: keys outside the block structure (e.g. 'img_in') previously raised
        # IndexError on parts[1]; they have no mapping, so skip them with a warning
        if len(parts) < 2:
            log.warning(f'Network load: type=LoKR unsupported key structure: {base}')
            continue
        block_type, block_idx, module_suffix = parts[0], parts[1], '.'.join(parts[2:])

        targets = []  # (module_path, chunk_index, num_chunks)
        if block_type == 'single_blocks' and module_suffix in F2_SINGLE_MAP:
            path = f'single_transformer_blocks.{block_idx}.{F2_SINGLE_MAP[module_suffix]}'
            targets.append((path, None, None))
        elif block_type == 'double_blocks':
            if module_suffix in F2_DOUBLE_MAP:
                path = f'transformer_blocks.{block_idx}.{F2_DOUBLE_MAP[module_suffix]}'
                targets.append((path, None, None))
            elif module_suffix in F2_QKV_MAP:
                # BFL fuses QKV into one weight; diffusers keeps separate q/k/v
                # modules, so emit one chunked target per projection
                attn_prefix, proj_keys = F2_QKV_MAP[module_suffix]
                for i, proj_key in enumerate(proj_keys):
                    path = f'transformer_blocks.{block_idx}.{attn_prefix}.{proj_key}'
                    targets.append((path, i, len(proj_keys)))

        for module_path, chunk_index, num_chunks in targets:
            network_key = "lora_transformer_" + module_path.replace(".", "_")
            sd_module = sd_model.network_layer_mapping.get(network_key)
            if sd_module is None:
                log.warning(f'Network load: type=LoKR module not found in mapping: {network_key}')
                continue
            weights = network.NetworkWeights(
                network_key=network_key,
                sd_key=network_key,
                w=dict(lokr_weights),  # copy so modules don't alias each other's dict
                sd_module=sd_module,
            )
            if chunk_index is not None:
                net.modules[network_key] = network_lokr.NetworkModuleLokrChunk(net, weights, chunk_index, num_chunks)
            else:
                net.modules[network_key] = network_lokr.NetworkModuleLokr(net, weights)

    return net
|
||||
|
||||
|
||||
patched = False  # module-level guard: the mixin must only be patched once


def apply_patch():
    """Patch Flux2LoraLoaderMixin.lora_state_dict to handle bare BFL-format keys.

    When a LoRA file has bare BFL keys (no diffusion_model. prefix), the original
    lora_state_dict won't detect them as AI toolkit format. This patch checks for
    bare keys after the original returns and adds the prefix + re-runs conversion.
    """
    global patched
    if patched:
        return
    patched = True

    from diffusers.loaders.lora_pipeline import Flux2LoraLoaderMixin
    original_lora_state_dict = Flux2LoraLoaderMixin.lora_state_dict.__func__

    @classmethod # pylint: disable=no-self-argument
    def patched_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = preprocess_f2_keys(pretrained_model_name_or_path_or_dict)
            pretrained_model_name_or_path_or_dict = apply_lora_alphas(pretrained_model_name_or_path_or_dict)
        elif isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
            path = str(pretrained_model_name_or_path_or_dict)
            if path.endswith('.safetensors'):
                try:
                    from safetensors import safe_open
                    # peek at key names only; fully load the file just when rewriting is needed
                    with safe_open(path, framework="pt") as f:
                        keys = list(f.keys())
                    needs_load = (
                        any(k.endswith('.alpha') for k in keys)
                        or (not any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in keys)
                            and any(k.startswith(p) for k in keys for p in BARE_FLUX_PREFIXES))
                    )
                    if needs_load:
                        from safetensors.torch import load_file
                        sd = load_file(path)
                        sd = preprocess_f2_keys(sd)
                        pretrained_model_name_or_path_or_dict = apply_lora_alphas(sd)
                except Exception as e:
                    # fix: was a silent `pass` -- keep the best-effort fallback to the
                    # original loader, but leave a trace for debugging
                    log.debug(f'Network load: type=LoRA flux2 key preprocessing failed: {e}')
        return original_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs)

    Flux2LoraLoaderMixin.lora_state_dict = patched_lora_state_dict
|
||||
|
|
@ -31,6 +31,9 @@ def load_flux2(checkpoint_info, diffusers_load_config=None):
|
|||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline
|
||||
|
||||
from pipelines.flux import flux2_lora
|
||||
flux2_lora.apply_patch()
|
||||
|
||||
del text_encoder
|
||||
del transformer
|
||||
sd_hijack_te.init_hijack(pipe)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ def load_flux2_klein(checkpoint_info, diffusers_load_config=None):
|
|||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline
|
||||
|
||||
from pipelines.flux import flux2_lora
|
||||
flux2_lora.apply_patch()
|
||||
|
||||
del text_encoder
|
||||
del transformer
|
||||
sd_hijack_te.init_hijack(pipe)
|
||||
|
|
|
|||
Loading…
Reference in New Issue