Add freeze_spectral_norm option
See https://arxiv.org/abs/2303.06296. This adds an option to reparametrize the model weights using the spectral norm so that the overall norm of each weight can't change. This helps to stabilize training at high learning rates.
parent
8207ccd854
commit
573d1c92bc
|
|
@ -73,6 +73,7 @@ class DreamboothConfig(BaseModel):
|
|||
min_snr_gamma: float = 0.0
|
||||
use_dream: bool = False
|
||||
dream_detail_preservation: float = 0.5
|
||||
freeze_spectral_norm: bool = False
|
||||
mixed_precision: str = "fp16"
|
||||
model_dir: str = ""
|
||||
model_name: str = ""
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ from diffusers.training_utils import unet_lora_state_dict
|
|||
from diffusers.utils import logging as dl
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from torch.cuda.profiler import profile
|
||||
from torch.nn.utils.parametrizations import _SpectralNorm
|
||||
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
|
@ -102,6 +104,38 @@ class ConditionalAccumulator:
|
|||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.stack.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
# This implements spectral norm reparametrization. Unlike the pytorch
# built-in version, it computes the current spectral norm of the parameter
# when added and normalizes so that the norm remains constant.
class FrozenSpectralNorm(_SpectralNorm):
    """Weight reparametrization that freezes the spectral norm at its value
    when the parametrization is registered (https://arxiv.org/abs/2303.06296).

    The parametrized weight is ``weight * (sigma_init / sigma_current)``:
    training may change the direction/shape of the weight but cannot grow or
    shrink its spectral norm.
    """

    @torch.autograd.no_grad()
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        # The parent initializer runs the power iterations that estimate the
        # dominant singular vectors (_u, _v) for >=2-dim weights.
        super().__init__(weight, n_power_iterations, dim, eps)

        if weight.ndim == 1:
            # For vectors the spectral norm is the Euclidean norm. Store the
            # *scalar* norm (the original code stored the normalized vector,
            # which made forward() an elementwise product of two unit vectors
            # instead of rescaling the weight back to its original magnitude).
            sigma = torch.linalg.vector_norm(weight)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            # Rayleigh-quotient estimate of the largest singular value.
            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
        # Frozen reference norm; forward() keeps the weight at this value.
        self.register_buffer('_sigma', sigma)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Rescale the current direction to the frozen norm.
            return self._sigma * F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                # Refine the singular-vector estimates in place (runs under
                # no_grad inside _power_method).
                self._power_method(weight_mat, self.n_power_iterations)
            # Clone u/v so gradients flow only through the weight itself,
            # mirroring the built-in SpectralNorm implementation.
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            sigma = torch.dot(u, torch.mv(weight_mat, v))
            return weight * (self._sigma / sigma)
|
||||
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
|
@ -258,6 +292,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
args.max_token_length = int(args.max_token_length)
|
||||
if not args.pad_tokens and args.max_token_length > 75:
|
||||
logger.warning("Cannot raise token length limit above 75 when pad_tokens=False")
|
||||
|
||||
if args.use_lora and args.freeze_spectral_norm:
|
||||
logger.warning("freeze_spectral_norm is not compatible with LORA")
|
||||
args.freeze_spectral_norm = False
|
||||
|
||||
verify_locon_installed(args)
|
||||
|
||||
|
|
@ -514,6 +552,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
unet, device=accelerator.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def add_spectral_reparametrization(unet):
    # Attach a FrozenSpectralNorm parametrization to the weight of every
    # linear/conv layer so its spectral norm stays fixed during training.
    target_types = (torch.nn.Linear, torch.nn.Conv2d)
    for layer in unet.modules():
        if not isinstance(layer, target_types):
            continue
        layer_weight = getattr(layer, "weight", None)
        register_parametrization(layer, "weight", FrozenSpectralNorm(layer_weight))
|
||||
|
||||
def remove_spectral_reparametrization(unet):
    # Remove the spectral reparametrization and set all parameters to their adjusted versions
    target_types = (torch.nn.Linear, torch.nn.Conv2d)
    for layer in unet.modules():
        if isinstance(layer, target_types):
            # leave_parametrized=True bakes the parametrized (normalized)
            # value back in as the plain weight parameter.
            remove_parametrizations(layer, "weight", leave_parametrized=True)
|
||||
|
||||
# Add spectral norm reparametrization. See https://arxiv.org/abs/2303.06296
|
||||
# This can't be done until after the EMA model has been created, because the EMA model
|
||||
# needs to get the standard parametrization.
|
||||
if args.freeze_spectral_norm:
|
||||
add_spectral_reparametrization(unet)
|
||||
|
||||
# Create shared unet/tenc learning rate variables
|
||||
|
||||
learning_rate = args.learning_rate
|
||||
|
|
@ -1059,6 +1115,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
|
||||
cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")
|
||||
|
||||
if args.freeze_spectral_norm:
|
||||
remove_spectral_reparametrization(unet)
|
||||
|
||||
optim_to(profiler, optimizer)
|
||||
|
||||
if profiler is None:
|
||||
|
|
@ -1435,6 +1494,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if os.path.isfile(lora_save_file):
|
||||
os.remove(lora_save_file)
|
||||
|
||||
if args.freeze_spectral_norm:
|
||||
add_spectral_reparametrization(unet)
|
||||
|
||||
printm("Completed saving weights.")
|
||||
pbar2.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -85,6 +85,15 @@ class EMAModel(object):
|
|||
ema_state_dict = {}
|
||||
ema_params = self.params
|
||||
for key, param in new_model.state_dict().items():
|
||||
if ".parametrizations." in key:
|
||||
if ".parametrizations.weight.original" in key:
|
||||
# Handle reparametrized parameters
|
||||
param = new_model.get_submodule(key.replace(".parametrizations.weight.original", "")).weight
|
||||
key = key.replace(".parametrizations.weight.original", ".weight")
|
||||
else:
|
||||
# Skip extra values used in reparametrization
|
||||
continue
|
||||
|
||||
try:
|
||||
ema_param = ema_params[key]
|
||||
except KeyError:
|
||||
|
|
|
|||
|
|
@ -308,6 +308,13 @@
|
|||
data-step="0.01" id="dream_detail_preservation" data-value="0.5"
|
||||
data-label="DREAM detail preservation"></div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="form-check form-switch">
|
||||
<input class="dbInput form-check-input" type="checkbox"
|
||||
id="freeze_spectral_norm" name="freeze_spectral_norm">
|
||||
<label class="form-check-label" for="freeze_spectral_norm">Freeze Spectral Norm</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="dbInput db-slider" data-min="75" data-max="300"
|
||||
data-step="75" id="max_token_length" data-value="75"
|
||||
|
|
|
|||
|
|
@ -330,6 +330,7 @@ let db_titles = {
|
|||
"Use Concepts List": "Train multiple concepts from a JSON file or string.",
|
||||
"Use DREAM": "Enable DREAM (http://arxiv.org/abs/2312.00210). This may provide better results, but trains slower.",
|
||||
"DREAM detail preservation": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM.",
|
||||
"Freeze Spectral Norm": "Prevents the overall magnitude of weights from changing during training. This helps stabilize training at high learning rates. Not compatible with LORA.",
|
||||
"Use EMA": "Enabling this will provide better results and editability, but cost more VRAM.",
|
||||
"Use EMA for prediction": "",
|
||||
"Use EMA Weights for Inference": "Enabling this will save the EMA unet weights as the 'normal' model weights and ignore the regular unet weights.",
|
||||
|
|
|
|||
|
|
@ -652,6 +652,9 @@ def on_ui_tabs():
|
|||
value=0.5,
|
||||
visible=True,
|
||||
)
|
||||
db_freeze_spectral_norm = gr.Checkbox(
|
||||
label="Freeze Spectral Norm", value=False
|
||||
)
|
||||
db_pad_tokens = gr.Checkbox(
|
||||
label="Pad Tokens", value=True
|
||||
)
|
||||
|
|
@ -1361,6 +1364,7 @@ def on_ui_tabs():
|
|||
db_min_snr_gamma,
|
||||
db_use_dream,
|
||||
db_dream_detail_preservation,
|
||||
db_freeze_spectral_norm,
|
||||
db_pad_tokens,
|
||||
db_strict_tokens,
|
||||
db_max_token_length,
|
||||
|
|
@ -1496,6 +1500,7 @@ def on_ui_tabs():
|
|||
db_min_snr_gamma,
|
||||
db_use_dream,
|
||||
db_dream_detail_preservation,
|
||||
db_freeze_spectral_norm,
|
||||
db_mixed_precision,
|
||||
db_model_name,
|
||||
db_model_path,
|
||||
|
|
|
|||
|
|
@ -161,6 +161,9 @@
|
|||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"freeze_spectral_norm": {
|
||||
"value": false
|
||||
},
|
||||
"max_token_length": {
|
||||
"value": 75,
|
||||
"min": 75,
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
"min_snr_gamma": 0.0,
|
||||
"use_dream": false,
|
||||
"dream_detail_preservation": 0.5,
|
||||
"freeze_spectral_norm": false,
|
||||
"mixed_precision": "fp16",
|
||||
"noise_scheduler": "DDPM",
|
||||
"num_train_epochs": 200,
|
||||
|
|
|
|||
|
|
@ -198,6 +198,11 @@
|
|||
"title": "Select how much detail DREAM preserves.",
|
||||
"description": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM."
|
||||
},
|
||||
"freeze_spectral_norm": {
|
||||
"label": "Freeze Spectral Norm",
|
||||
"title": "Freeze spectral norm of the model weights.",
|
||||
"description": "Prevents the overall magnitude of weights from changing during training. This helps stabilize training at high learning rates. Not compatible with LORA."
|
||||
},
|
||||
"train_unet": {
|
||||
"label": "Train UNET",
|
||||
"title": "Train UNET as an additional module.",
|
||||
|
|
|
|||
Loading…
Reference in New Issue