Add freeze_spectral_norm option

See https://arxiv.org/abs/2303.06296

This adds an option to reparametrize the model weights using the spectral norm so that the overall norm of each weight can't change. This helps to stabilize training at high learning rates.
pull/1441/head
Ross Morgan-Linial 2024-01-17 09:24:13 -08:00
parent 8207ccd854
commit 573d1c92bc
9 changed files with 94 additions and 0 deletions

View File

@ -73,6 +73,7 @@ class DreamboothConfig(BaseModel):
min_snr_gamma: float = 0.0
use_dream: bool = False
dream_detail_preservation: float = 0.5
freeze_spectral_norm: bool = False
mixed_precision: str = "fp16"
model_dir: str = ""
model_name: str = ""

View File

@ -35,6 +35,8 @@ from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import logging as dl
from diffusers.utils.torch_utils import randn_tensor
from torch.cuda.profiler import profile
from torch.nn.utils.parametrizations import _SpectralNorm
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.utils.data import Dataset
from transformers import AutoTokenizer
@ -102,6 +104,38 @@ class ConditionalAccumulator:
def __exit__(self, exc_type, exc_value, traceback):
self.stack.__exit__(exc_type, exc_value, traceback)
# This implements spectral norm reparametrization. Unlike the PyTorch
# built-in version, it computes the current spectral norm of the parameter
# when added and normalizes so that the norm remains constant.
class FrozenSpectralNorm(_SpectralNorm):
    """Reparametrize a weight so its norm stays frozen at registration time.

    See https://arxiv.org/abs/2303.06296. Unlike
    ``torch.nn.utils.parametrizations.spectral_norm`` (which rescales to
    norm 1), this captures the norm once in the ``_sigma`` buffer and
    rescales the live weight so its norm cannot drift during training.
    """

    @torch.autograd.no_grad()
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        # Parent initializer sets up the power-iteration vectors (_u, _v)
        # used to estimate the top singular value for ndim > 1 weights.
        super().__init__(weight, n_power_iterations, dim, eps)
        if weight.ndim == 1:
            # BUG FIX: the frozen "sigma" of a vector is its Euclidean
            # norm, not the normalized vector. Storing F.normalize(weight)
            # here made forward() return the elementwise product of two
            # unit vectors instead of a norm-preserving rescale.
            sigma = torch.linalg.vector_norm(weight)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            # Current estimate of the largest singular value, via the
            # u/v vectors produced by the parent's power iterations.
            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
        # Buffer, not parameter: travels with state_dict()/.to() but is
        # never updated by the optimizer — that is what "frozen" means.
        self.register_buffer('_sigma', sigma)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` rescaled so its norm equals the frozen value."""
        if weight.ndim == 1:
            return self._sigma * F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                # Refine the singular-vector estimates as the weight moves.
                self._power_method(weight_mat, self.n_power_iterations)
            # Clone so autograd differentiates through the current weight
            # only, not through the power-iteration buffers.
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            sigma = torch.dot(u, torch.mv(weight_mat, v))
            return weight * (self._sigma / sigma)
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
@ -259,6 +293,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if not args.pad_tokens and args.max_token_length > 75:
logger.warning("Cannot raise token length limit above 75 when pad_tokens=False")
if args.use_lora and args.freeze_spectral_norm:
logger.warning("freeze_spectral_norm is not compatible with LORA")
args.freeze_spectral_norm = False
verify_locon_installed(args)
precision = args.mixed_precision if not shared.force_cpu else "no"
@ -514,6 +552,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
unet, device=accelerator.device, dtype=torch.float32
)
def add_spectral_reparametrization(unet):
    # Attach a FrozenSpectralNorm parametrization to the weight of every
    # linear and conv layer so its overall norm cannot drift in training.
    target_types = (torch.nn.Linear, torch.nn.Conv2d)
    for submodule in unet.modules():
        if not isinstance(submodule, target_types):
            continue
        current_weight = getattr(submodule, "weight", None)
        register_parametrization(
            submodule, "weight", FrozenSpectralNorm(current_weight)
        )
def remove_spectral_reparametrization(unet):
    # Strip the spectral reparametrization from every linear/conv layer,
    # baking the adjusted values back into the plain ``weight`` tensors
    # (leave_parametrized=True keeps the current effective weights).
    target_types = (torch.nn.Linear, torch.nn.Conv2d)
    for submodule in unet.modules():
        if isinstance(submodule, target_types):
            remove_parametrizations(submodule, "weight", leave_parametrized=True)
# Add spectral norm reparametrization. See https://arxiv.org/abs/2303.06296
# This can't be done until after the EMA model has been created, because the EMA model
# needs to get the standard parametrization.
if args.freeze_spectral_norm:
add_spectral_reparametrization(unet)
# Create shared unet/tenc learning rate variables
learning_rate = args.learning_rate
@ -1059,6 +1115,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")
if args.freeze_spectral_norm:
remove_spectral_reparametrization(unet)
optim_to(profiler, optimizer)
if profiler is None:
@ -1435,6 +1494,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if os.path.isfile(lora_save_file):
os.remove(lora_save_file)
if args.freeze_spectral_norm:
add_spectral_reparametrization(unet)
printm("Completed saving weights.")
pbar2.reset()

View File

@ -85,6 +85,15 @@ class EMAModel(object):
ema_state_dict = {}
ema_params = self.params
for key, param in new_model.state_dict().items():
if ".parametrizations." in key:
if ".parametrizations.weight.original" in key:
# Handle reparametrized parameters
param = new_model.get_submodule(key.replace(".parametrizations.weight.original", "")).weight
key = key.replace(".parametrizations.weight.original", ".weight")
else:
# Skip extra values used in reparametrization
continue
try:
ema_param = ema_params[key]
except KeyError:

View File

@ -308,6 +308,13 @@
data-step="0.01" id="dream_detail_preservation" data-value="0.5"
data-label="DREAM detail preservation"></div>
</div>
<div class="form-group">
<div class="form-check form-switch">
<input class="dbInput form-check-input" type="checkbox"
id="freeze_spectral_norm" name="freeze_spectral_norm">
<label class="form-check-label" for="freeze_spectral_norm">Freeze Spectral Norm</label>
</div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="75" data-max="300"
data-step="75" id="max_token_length" data-value="75"

View File

@ -330,6 +330,7 @@ let db_titles = {
"Use Concepts List": "Train multiple concepts from a JSON file or string.",
"Use DREAM": "Enable DREAM (http://arxiv.org/abs/2312.00210). This may provide better results, but trains slower.",
"DREAM detail preservation": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM.",
"Freeze Spectral Norm": "Prevents the overall magnitude of weights from changing during training. This helps stabilize training at high learning rates. Not compatible with LORA.",
"Use EMA": "Enabling this will provide better results and editability, but cost more VRAM.",
"Use EMA for prediction": "",
"Use EMA Weights for Inference": "Enabling this will save the EMA unet weights as the 'normal' model weights and ignore the regular unet weights.",

View File

@ -652,6 +652,9 @@ def on_ui_tabs():
value=0.5,
visible=True,
)
db_freeze_spectral_norm = gr.Checkbox(
label="Freeze Spectral Norm", value=False
)
db_pad_tokens = gr.Checkbox(
label="Pad Tokens", value=True
)
@ -1361,6 +1364,7 @@ def on_ui_tabs():
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_freeze_spectral_norm,
db_pad_tokens,
db_strict_tokens,
db_max_token_length,
@ -1496,6 +1500,7 @@ def on_ui_tabs():
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_freeze_spectral_norm,
db_mixed_precision,
db_model_name,
db_model_path,

View File

@ -161,6 +161,9 @@
"max": 1,
"step": 0.01
},
"freeze_spectral_norm": {
"value": false
},
"max_token_length": {
"value": 75,
"min": 75,

View File

@ -40,6 +40,7 @@
"min_snr_gamma": 0.0,
"use_dream": false,
"dream_detail_preservation": 0.5,
"freeze_spectral_norm": false,
"mixed_precision": "fp16",
"noise_scheduler": "DDPM",
"num_train_epochs": 200,

View File

@ -198,6 +198,11 @@
"title": "Select how much detail DREAM preserves.",
"description": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM."
},
"freeze_spectral_norm": {
"label": "Freeze Spectral Norm",
"title": "Freeze spectral norm of the model weights.",
"description": "Prevents the overall magnitude of weights from changing during training. This helps stabilize training at high learning rates. Not compatible with LORA."
},
"train_unet": {
"label": "Train UNET",
"title": "Train UNET as an additional module.",