fix(lora): handle kohya-format alpha keys in Flux2/Klein LoRA loading

pull/4700/head
CalamitousFelicitousness 2026-03-22 21:03:14 +00:00
parent 091f31d4bf
commit 9719290ceb
1 changed file with 47 additions and 7 deletions


@@ -9,6 +9,7 @@ Installed via apply_patch() during pipeline loading.
import os
import time
import torch
from modules import shared, sd_models
from modules.logger import log
from modules.lora import network, network_lokr, lora_convert
@@ -38,6 +39,43 @@ F2_QKV_MAP = {
}
def apply_lora_alphas(state_dict):
    """Bake kohya-format .alpha scaling into lora_down weights and remove alpha keys.

    Diffusers' Flux2 converter only handles lora_A/lora_B (or lora_down/lora_up) keys.
    Kohya-format LoRAs store per-layer alpha values as separate .alpha keys that the
    converter doesn't consume, causing a ValueError on leftover keys. This matches the
    approach used by _convert_kohya_flux_lora_to_diffusers for Flux 1.
    """
    alpha_keys = [k for k in state_dict if k.endswith('.alpha')]
    if not alpha_keys:
        return state_dict
    for alpha_key in alpha_keys:
        base = alpha_key[:-len('.alpha')]
        down_key = f'{base}.lora_down.weight'
        if down_key not in state_dict:
            continue
        down_weight = state_dict[down_key]
        rank = down_weight.shape[0]
        alpha = state_dict.pop(alpha_key).item()
        scale = alpha / rank
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        state_dict[down_key] = down_weight * scale_down
        up_key = f'{base}.lora_up.weight'
        if up_key in state_dict:
            state_dict[up_key] = state_dict[up_key] * scale_up
    remaining = [k for k in state_dict if k.endswith('.alpha')]
    if remaining:
        log.debug(f'Network load: type=LoRA stripped {len(remaining)} orphaned alpha keys')
        for k in remaining:
            del state_dict[k]
    return state_dict

def preprocess_f2_keys(state_dict):
    """Add 'diffusion_model.' prefix to bare BFL-format keys so
    Flux2LoraLoaderMixin's format detection routes them to the converter."""
@@ -45,7 +83,7 @@ def preprocess_f2_keys(state_dict):
        return state_dict
    if any(k.startswith(p) for k in state_dict for p in BARE_FLUX_PREFIXES):
        log.debug('Network load: type=LoRA adding diffusion_model prefix for bare BFL-format keys')
        return {f"diffusion_model.{k}": v for k, v in state_dict.items()}
        state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
    return state_dict
@@ -154,6 +192,7 @@ def apply_patch():
    def patched_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = preprocess_f2_keys(pretrained_model_name_or_path_or_dict)
            pretrained_model_name_or_path_or_dict = apply_lora_alphas(pretrained_model_name_or_path_or_dict)
        elif isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
            path = str(pretrained_model_name_or_path_or_dict)
            if path.endswith('.safetensors'):
@@ -161,15 +200,16 @@ def apply_patch():
                    from safetensors import safe_open
                    with safe_open(path, framework="pt") as f:
                        keys = list(f.keys())
                    needs_prefix = (
                        not any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in keys)
                        and any(k.startswith(p) for k in keys for p in BARE_FLUX_PREFIXES)
                    needs_load = (
                        any(k.endswith('.alpha') for k in keys)
                        or (not any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in keys)
                        and any(k.startswith(p) for k in keys for p in BARE_FLUX_PREFIXES))
                    )
                    if needs_prefix:
                        log.debug('Network load: type=LoRA adding diffusion_model prefix for bare BFL-format keys')
                    if needs_load:
                        from safetensors.torch import load_file
                        sd = load_file(path)
                        pretrained_model_name_or_path_or_dict = {f"diffusion_model.{k}": v for k, v in sd.items()}
                        sd = preprocess_f2_keys(sd)
                        pretrained_model_name_or_path_or_dict = apply_lora_alphas(sd)
                except Exception:
                    pass
        return original_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs)
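
For reference, a minimal standalone sketch of the two pre-processing steps this patch applies before handing the state dict to diffusers. It is not part of the commit; the toy key name, shapes, and alpha value are made up for illustration. It checks that baking alpha/rank into the weights leaves the effective LoRA delta unchanged.

import torch

rank, features = 4, 8
sd = {
    'double_blocks.0.img_attn.qkv.lora_down.weight': torch.randn(rank, features),
    'double_blocks.0.img_attn.qkv.lora_up.weight': torch.randn(features, rank),
    'double_blocks.0.img_attn.qkv.alpha': torch.tensor(2.0),
}

# step 1: bare BFL-format keys get the 'diffusion_model.' prefix (what preprocess_f2_keys does)
sd = {f'diffusion_model.{k}': v for k, v in sd.items()}

# step 2: bake alpha/rank into the weights and drop the .alpha key (what apply_lora_alphas does)
base = 'diffusion_model.double_blocks.0.img_attn.qkv'
down = sd[f'{base}.lora_down.weight']
up = sd[f'{base}.lora_up.weight']
alpha = sd.pop(f'{base}.alpha').item()
expected = up @ down * (alpha / rank)  # effective delta the LoRA would apply at inference

scale = alpha / rank
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:  # same magnitude-balancing split as the patch uses
    scale_down *= 2
    scale_up /= 2
sd[f'{base}.lora_down.weight'] = down * scale_down
sd[f'{base}.lora_up.weight'] = up * scale_up

baked = sd[f'{base}.lora_up.weight'] @ sd[f'{base}.lora_down.weight']
assert torch.allclose(expected, baked)  # delta preserved, no .alpha keys left for the converter

Splitting the factor between the two matrices instead of folding it all into lora_down keeps their magnitudes balanced when scale is small, which is the same behaviour as the Flux 1 kohya converter referenced in the docstring.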