From 573d1c92bcb4cdb675eb47f6cf7e86a6ebbbbaf1 Mon Sep 17 00:00:00 2001 From: Ross Morgan-Linial Date: Wed, 17 Jan 2024 09:24:13 -0800 Subject: [PATCH] Add freeze_spectral_norm option See https://arxiv.org/abs/2303.06296 This adds an option to reparametrize the model weights using the spectral norm so that the overall norm of each weight can't change. This helps to stabilize training at high learning rates. --- dreambooth/dataclasses/db_config.py | 1 + dreambooth/train_dreambooth.py | 62 +++++++++++++++++++ helpers/ema_model.py | 9 +++ index.html | 7 +++ javascript/dreambooth.js | 1 + scripts/main.py | 5 ++ templates/defaults/defaults.json | 3 + .../defaults/dreambooth_model_config.json | 1 + templates/locales/titles_en.json | 5 ++ 9 files changed, 94 insertions(+) diff --git a/dreambooth/dataclasses/db_config.py b/dreambooth/dataclasses/db_config.py index abdaac9..253017e 100644 --- a/dreambooth/dataclasses/db_config.py +++ b/dreambooth/dataclasses/db_config.py @@ -73,6 +73,7 @@ class DreamboothConfig(BaseModel): min_snr_gamma: float = 0.0 use_dream: bool = False dream_detail_preservation: float = 0.5 + freeze_spectral_norm: bool = False mixed_precision: str = "fp16" model_dir: str = "" model_name: str = "" diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index 912afaa..784e75c 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -35,6 +35,8 @@ from diffusers.training_utils import unet_lora_state_dict from diffusers.utils import logging as dl from diffusers.utils.torch_utils import randn_tensor from torch.cuda.profiler import profile +from torch.nn.utils.parametrizations import _SpectralNorm +from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations from torch.utils.data import Dataset from transformers import AutoTokenizer @@ -102,6 +104,38 @@ class ConditionalAccumulator: def __exit__(self, exc_type, exc_value, traceback): self.stack.__exit__(exc_type, exc_value, traceback) 
+# This implements spectral norm reparametrization. Unlike the pytorch
+# built-in version, it computes the current spectral norm of the parameter
+# when added and normalizes so that the norm remains constant.
+class FrozenSpectralNorm(_SpectralNorm):
+    @torch.autograd.no_grad()
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        n_power_iterations: int = 1,
+        dim: int = 0,
+        eps: float = 1e-12
+    ) -> None:
+        super().__init__(weight, n_power_iterations, dim, eps)
+
+        if weight.ndim == 1:
+            sigma = torch.linalg.vector_norm(weight)
+        else:
+            weight_mat = self._reshape_weight_to_matrix(weight)
+            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
+        self.register_buffer('_sigma', sigma)
+
+    def forward(self, weight: torch.Tensor) -> torch.Tensor:
+        if weight.ndim == 1:
+            return self._sigma * F.normalize(weight, dim=0, eps=self.eps)
+        else:
+            weight_mat = self._reshape_weight_to_matrix(weight)
+            if self.training:
+                self._power_method(weight_mat, self.n_power_iterations)
+            u = self._u.clone(memory_format=torch.contiguous_format)
+            v = self._v.clone(memory_format=torch.contiguous_format)
+            sigma = torch.dot(u, torch.mv(weight_mat, v))
+            return weight * (self._sigma / sigma)
 
 def text_encoder_lora_state_dict(text_encoder):
     state_dict = {}
@@ -258,6 +292,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
     args.max_token_length = int(args.max_token_length)
     if not args.pad_tokens and args.max_token_length > 75:
         logger.warning("Cannot raise token length limit above 75 when pad_tokens=False")
+
+    if args.use_lora and args.freeze_spectral_norm:
+        logger.warning("freeze_spectral_norm is not compatible with LORA")
+        args.freeze_spectral_norm = False
 
     verify_locon_installed(args)
 
@@ -514,6 +552,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
             unet, device=accelerator.device, dtype=torch.float32
         )
 
+    def add_spectral_reparametrization(unet):
+        for module in unet.modules():
+            if isinstance(module,
(torch.nn.Linear, torch.nn.Conv2d)): + weight = getattr(module, "weight", None) + register_parametrization(module, "weight", FrozenSpectralNorm(weight)) + + def remove_spectral_reparametrization(unet): + # Remove the spectral reparametrization and set all parameters to their adjusted versions + for module in unet.modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + remove_parametrizations(module, "weight", leave_parametrized=True) + + # Add spectral norm reparametrization. See https://arxiv.org/abs/2303.06296 + # This can't be done until after the EMA model has been created, because the EMA model + # needs to get the standard parametrization. + if args.freeze_spectral_norm: + add_spectral_reparametrization(unet) + # Create shared unet/tenc learning rate variables learning_rate = args.learning_rate @@ -1059,6 +1115,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda") cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu") + if args.freeze_spectral_norm: + remove_spectral_reparametrization(unet) + optim_to(profiler, optimizer) if profiler is None: @@ -1435,6 +1494,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR if os.path.isfile(lora_save_file): os.remove(lora_save_file) + if args.freeze_spectral_norm: + add_spectral_reparametrization(unet) + printm("Completed saving weights.") pbar2.reset() diff --git a/helpers/ema_model.py b/helpers/ema_model.py index d3605a0..33c1e8f 100644 --- a/helpers/ema_model.py +++ b/helpers/ema_model.py @@ -85,6 +85,15 @@ class EMAModel(object): ema_state_dict = {} ema_params = self.params for key, param in new_model.state_dict().items(): + if ".parametrizations." 
in key: + if ".parametrizations.weight.original" in key: + # Handle reparametrized parameters + param = new_model.get_submodule(key.replace(".parametrizations.weight.original", "")).weight + key = key.replace(".parametrizations.weight.original", ".weight") + else: + # Skip extra values used in reparametrization + continue + try: ema_param = ema_params[key] except KeyError: diff --git a/index.html b/index.html index c0e9df8..300ce3a 100644 --- a/index.html +++ b/index.html @@ -308,6 +308,13 @@ data-step="0.01" id="dream_detail_preservation" data-value="0.5" data-label="DREAM detail preservation"> +
+
+ + +
+