diff --git a/dreambooth/dataclasses/db_config.py b/dreambooth/dataclasses/db_config.py
index abdaac9..253017e 100644
--- a/dreambooth/dataclasses/db_config.py
+++ b/dreambooth/dataclasses/db_config.py
@@ -73,6 +73,7 @@ class DreamboothConfig(BaseModel):
     min_snr_gamma: float = 0.0
     use_dream: bool = False
     dream_detail_preservation: float = 0.5
+    freeze_spectral_norm: bool = False
     mixed_precision: str = "fp16"
     model_dir: str = ""
     model_name: str = ""
diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py
index 912afaa..784e75c 100644
--- a/dreambooth/train_dreambooth.py
+++ b/dreambooth/train_dreambooth.py
@@ -35,6 +35,8 @@
 from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import logging as dl
 from diffusers.utils.torch_utils import randn_tensor
 from torch.cuda.profiler import profile
+from torch.nn.utils.parametrizations import _SpectralNorm
+from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
 from torch.utils.data import Dataset
 from transformers import AutoTokenizer
@@ -102,6 +104,38 @@ class ConditionalAccumulator:
     def __exit__(self, exc_type, exc_value, traceback):
         self.stack.__exit__(exc_type, exc_value, traceback)

+# This implements spectral norm reparametrization. Unlike the pytorch
+# built-in version, it computes the current spectral norm of the parameter
+# when added and normalizes so that the norm remains constant.
+class FrozenSpectralNorm(_SpectralNorm):
+    @torch.autograd.no_grad()
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        n_power_iterations: int = 1,
+        dim: int = 0,
+        eps: float = 1e-12
+    ) -> None:
+        super().__init__(weight, n_power_iterations, dim, eps)
+
+        if weight.ndim == 1:
+            sigma = F.normalize(weight, dim=0, eps=self.eps)
+        else:
+            weight_mat = self._reshape_weight_to_matrix(weight)
+            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
+        self.register_buffer('_sigma', sigma)
+
+    def forward(self, weight: torch.Tensor) -> torch.Tensor:
+        if weight.ndim == 1:
+            return self._sigma * F.normalize(weight, dim=0, eps=self.eps)
+        else:
+            weight_mat = self._reshape_weight_to_matrix(weight)
+            if self.training:
+                self._power_method(weight_mat, self.n_power_iterations)
+            u = self._u.clone(memory_format=torch.contiguous_format)
+            v = self._v.clone(memory_format=torch.contiguous_format)
+            sigma = torch.dot(u, torch.mv(weight_mat, v))
+            return weight * (self._sigma / sigma)

 def text_encoder_lora_state_dict(text_encoder):
     state_dict = {}
@@ -258,6 +292,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
     args.max_token_length = int(args.max_token_length)
     if not args.pad_tokens and args.max_token_length > 75:
         logger.warning("Cannot raise token length limit above 75 when pad_tokens=False")
+
+    if args.use_lora and args.freeze_spectral_norm:
+        logger.warning("freeze_spectral_norm is not compatible with LORA")
+        args.freeze_spectral_norm = False

     verify_locon_installed(args)

@@ -514,6 +552,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
             unet, device=accelerator.device, dtype=torch.float32
         )

+    def add_spectral_reparametrization(unet):
+        for module in unet.modules():
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+                weight = getattr(module, "weight", None)
+                register_parametrization(module, "weight", FrozenSpectralNorm(weight))
+
+    def remove_spectral_reparametrization(unet):
+        # Remove the spectral reparametrization and set all parameters to their adjusted versions
+        for module in unet.modules():
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+                remove_parametrizations(module, "weight", leave_parametrized=True)
+
+    # Add spectral norm reparametrization. See https://arxiv.org/abs/2303.06296
+    # This can't be done until after the EMA model has been created, because the EMA model
+    # needs to get the standard parametrization.
+    if args.freeze_spectral_norm:
+        add_spectral_reparametrization(unet)
+
     # Create shared unet/tenc learning rate variables
     learning_rate = args.learning_rate
@@ -1059,6 +1115,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
             cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
             cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")

+            if args.freeze_spectral_norm:
+                remove_spectral_reparametrization(unet)
+
             optim_to(profiler, optimizer)

             if profiler is None:
@@ -1435,6 +1494,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
                     if os.path.isfile(lora_save_file):
                         os.remove(lora_save_file)

+                    if args.freeze_spectral_norm:
+                        add_spectral_reparametrization(unet)
+
                     printm("Completed saving weights.")
                     pbar2.reset()

diff --git a/helpers/ema_model.py b/helpers/ema_model.py
index d3605a0..33c1e8f 100644
--- a/helpers/ema_model.py
+++ b/helpers/ema_model.py
@@ -85,6 +85,15 @@ class EMAModel(object):
         ema_state_dict = {}
         ema_params = self.params
         for key, param in new_model.state_dict().items():
+            if ".parametrizations." in key:
+                if ".parametrizations.weight.original" in key:
+                    # Handle reparametrized parameters
+                    param = new_model.get_submodule(key.replace(".parametrizations.weight.original", "")).weight
+                    key = key.replace(".parametrizations.weight.original", ".weight")
+                else:
+                    # Skip extra values used in reparametrization
+                    continue
+
             try:
                 ema_param = ema_params[key]
             except KeyError:
diff --git a/index.html b/index.html
index c0e9df8..300ce3a 100644
--- a/index.html
+++ b/index.html
@@ -308,6 +308,13 @@
                        data-step="0.01" id="dream_detail_preservation" data-value="0.5"
                        data-label="DREAM detail preservation">
+
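For reference, the core trick in FrozenSpectralNorm is to record the spectral norm sigma = u^T W v once, at registration time, and then rescale the live weight by sigma_frozen / sigma_current on every forward pass, so training can change the direction of the weight but not its largest singular value. The sketch below illustrates that idea against stock PyTorch only; the class name FrozenNorm, the toy Linear layer, and the use of torch.linalg.matrix_norm in place of power iteration are simplifications for illustration, not part of this patch.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class FrozenNorm(nn.Module):
    """Simplified stand-in for FrozenSpectralNorm: capture the spectral norm
    once at registration time, then rescale the live weight so its largest
    singular value never drifts from that value."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        # Exact spectral norm via SVD; the patch uses power iteration instead.
        self.register_buffer("_sigma", torch.linalg.matrix_norm(weight.detach(), ord=2))

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Rescale so the visible weight keeps the frozen spectral norm.
        return weight * (self._sigma / torch.linalg.matrix_norm(weight, ord=2))


lin = nn.Linear(16, 16)
parametrize.register_parametrization(lin, "weight", FrozenNorm(lin.weight))

with torch.no_grad():
    # Perturb the raw parameter, as an optimizer step would.
    lin.parametrizations.weight.original.mul_(3.0)

# The computed weight still has the spectral norm captured at registration,
# even though the underlying "original" tensor grew by 3x.
print(torch.linalg.matrix_norm(lin.weight, ord=2))
print(torch.linalg.matrix_norm(lin.parametrizations.weight.original, ord=2))
```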
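The add/remove helpers and the EMAModel key handling follow the standard torch.nn.utils.parametrize lifecycle: registering a parametrization moves the raw tensor to module.parametrizations.weight.original (hence the state_dict key rewriting in ema_model.py), and remove_parametrizations(..., leave_parametrized=True) writes the computed weight back to module.weight so checkpoints are saved in the normal format. Below is a short sketch of that lifecycle using the built-in spectral_norm parametrization as a stand-in for FrozenSpectralNorm; the Tiny module and its submodule names are made up for illustration.

```python
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.conv = nn.Conv2d(4, 4, 3)


net = Tiny()

# Register: each weight becomes a computed property; the trainable tensor now
# lives at <module>.parametrizations.weight.original, which is why the EMA
# code remaps those keys back to plain ".weight" entries.
for m in net.modules():
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        spectral_norm(m, name="weight")

print([k for k in net.state_dict() if "parametrizations" in k])
# e.g. ['proj.parametrizations.weight.original', 'proj.parametrizations.weight.0._u', ...]

# Remove before saving: leave_parametrized=True bakes the current computed
# weight back into <module>.weight, so the checkpoint holds plain tensors again.
for m in net.modules():
    if parametrize.is_parametrized(m, "weight"):
        parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)

print([k for k in net.state_dict() if "parametrizations" in k])  # []
```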