# sd_dreambooth_extension/dreambooth/optimization.py
#
# (Repository viewer listing: 793 lines, 28 KiB, Python)

# A rework of 'optimization.py' from the original HF diffusers repo, modified to call the
# actual pytorch scheduler these are based on - providing a much bigger set of tuning params
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimizations for diffusion models."""
import math
import traceback
from enum import Enum
from typing import Optional, Union, List
from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler, DDPMScheduler
from diffusers.utils import logging
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
LambdaLR,
ConstantLR,
LinearLR,
CosineAnnealingLR,
CosineAnnealingWarmRestarts,
)
logger = logging.get_logger(__name__)
class SchedulerType(Enum):
    """
    Identifiers of the learning rate schedules that `get_scheduler` can build.

    Each member's value is the plain string name; `SchedulerType("cosine")` and
    `SchedulerType.COSINE` refer to the same member.
    """

    REX = "rex"  # built by get_rex_scheduler
    LINEAR = "linear"  # built by get_linear_schedule (torch LinearLR)
    LINEAR_WITH_WARMUP = "linear_with_warmup"  # built by get_linear_schedule_with_warmup
    COSINE = "cosine"  # built by get_cosine_schedule_with_warmup
    COSINE_ANNEALING = "cosine_annealing"  # torch CosineAnnealingLR
    COSINE_ANNEALING_WITH_RESTARTS = "cosine_annealing_with_restarts"  # torch CosineAnnealingWarmRestarts
    COSINE_WITH_RESTARTS = "cosine_with_restarts"  # built by get_cosine_with_hard_restarts_schedule_with_warmup
    POLYNOMIAL = "polynomial"  # built by get_polynomial_decay_schedule_with_warmup
    CONSTANT = "constant"  # built by get_constant_schedule (torch ConstantLR)
    CONSTANT_WITH_WARMUP = "constant_with_warmup"  # built by get_constant_schedule_with_warmup
def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
    """
    Create a REX learning rate schedule (see https://arxiv.org/abs/2107.04197).

    The lr multiplier starts at 1 and decays along the curve
    `(1 - p) / ((1 - d) + d * (1 - p))`, where `p` is the fraction of training completed and `d = 0.9`.
    From `total_training_steps` onwards the multiplier stays at 1e-8.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        total_training_steps (`int`):
            The total number of training steps.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    peak_scale = 1
    floor_scale = 0.00000001
    d = 0.9

    def lr_lambda(current_step: int):
        if current_step < total_training_steps:
            progress = current_step / total_training_steps
            remaining = 1 - progress
            div = (1 - d) + (d * remaining)
            return floor_scale + (peak_scale - floor_scale) * (remaining / div)
        # Training budget exhausted: hold the multiplier at its floor.
        return floor_scale

    return LambdaLR(optimizer, lr_lambda)
# region Newer Schedulers
def get_cosine_annealing_scheduler(
    optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
):
    """
    Create a cosine annealing schedule backed by `torch.optim.lr_scheduler.CosineAnnealingLR`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        max_iter (`int`, *optional*, defaults to 500):
            Forwarded as `T_max`, the number of steps in one annealing half-period.
        eta_min (`float`, *optional*, defaults to 1e-6):
            Forwarded as `eta_min`, the lowest learning rate of the cosine curve.

    Return:
        `torch.optim.lr_scheduler.CosineAnnealingLR` with the appropriate schedule.
    """
    annealing_kwargs = {"T_max": max_iter, "eta_min": eta_min}
    return CosineAnnealingLR(optimizer, **annealing_kwargs)
def get_cosine_annealing_warm_restarts_scheduler(
    optimizer: Optimizer, t_0: int = 25, t_mult: int = 1, eta_min: float = 1e-6
):
    """
    Create a cosine annealing schedule with warm restarts, backed by
    `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        t_0 (`int`, *optional*, defaults to 25):
            Forwarded as `T_0`, the number of iterations before the first restart.
        t_mult (`int`, *optional*, defaults to 1):
            Forwarded as `T_mult`, the factor by which the cycle length grows after each restart.
        eta_min (`float`, *optional*, defaults to 1e-6):
            Forwarded as `eta_min`, the lowest learning rate of each cosine cycle.

    Return:
        `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts` with the appropriate schedule.
    """
    restart_kwargs = {"T_0": t_0, "T_mult": t_mult, "eta_min": eta_min}
    return CosineAnnealingWarmRestarts(optimizer, **restart_kwargs)
def get_linear_schedule(
    optimizer: Optimizer, start_factor: float = 0.5, total_iters: int = 500
):
    """
    Create a schedule backed by `torch.optim.lr_scheduler.LinearLR`: the learning rate multiplier moves
    linearly from `start_factor` to LinearLR's default end factor over `total_iters` steps and is held
    constant afterwards.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        start_factor (`float`, *optional*, defaults to 0.5):
            The value the LR is multiplied by at the start of training.
        total_iters (`int`, *optional*, defaults to 500):
            The number of steps over which the multiplier is interpolated.

    Return:
        `torch.optim.lr_scheduler.LinearLR` with the appropriate schedule.
    """
    linear_kwargs = {"start_factor": start_factor, "total_iters": total_iters}
    return LinearLR(optimizer, **linear_kwargs)
def get_constant_schedule(
    optimizer: Optimizer, factor: float = 1.0, total_iters: int = 500
):
    """
    Create a schedule backed by `torch.optim.lr_scheduler.ConstantLR`: the learning rate set in the
    optimizer is multiplied by `factor` until `total_iters` steps have elapsed, after which the
    unscaled learning rate is used.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        factor (`float`, *optional*, defaults to 1.0):
            The value the LR is multiplied by until `total_iters` is reached.
        total_iters (`int`, *optional*, defaults to 500):
            The step at which the multiplier stops being applied.

    Return:
        `torch.optim.lr_scheduler.ConstantLR` with the appropriate schedule.
    """
    constant_kwargs = {"factor": factor, "total_iters": total_iters}
    return ConstantLR(optimizer, **constant_kwargs)
# endregion
# region originals
def get_constant_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, min_lr: float
):
    """
    Create a schedule whose lr multiplier ramps linearly towards 1 during a warmup period and then stays
    at 1, i.e. at the learning rate set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        min_lr (`float`):
            Lower bound applied to the lr multiplier during the warmup phase.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            # Linear ramp, never dropping below the configured floor.
            warmup_scale = float(current_step) / float(max(1, num_warmup_steps))
            return max(min_lr, warmup_scale)
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
def get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1
):
    """
    Create a schedule whose lr multiplier ramps linearly towards 1 during warmup and then decreases
    linearly to 0 at `num_training_steps`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr (`float`):
            Lower bound applied to the lr multiplier during the warmup phase only; the decay phase is
            bounded below by 0.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            warmup_scale = float(current_step) / float(max(1, num_warmup_steps))
            return max(min_lr, warmup_scale)
        steps_left = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, steps_left / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr: float,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule whose lr multiplier ramps linearly towards 1 during warmup and then follows a
    cosine curve over the remaining training steps.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr (`float`):
            Lower bound applied to the lr multiplier during the warmup phase only; the cosine phase is
            bounded below by 0.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine waves after warmup (the default is a single half-cosine, going from the
            max value down to 0).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            warmup_scale = float(current_step) / float(max(1, num_warmup_steps))
            return max(min_lr, warmup_scale)
        steps_done = float(current_step - num_warmup_steps)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = steps_done / decay_span
        cosine_value = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine_value))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr: float,
    num_cycles: int = 1,
    last_epoch: int = -1,
):
    """
    Create a schedule whose lr multiplier ramps linearly towards 1 during warmup and then repeatedly
    follows a half-cosine from 1 down to 0, jumping back to 1 at the start of every cycle (hard restart).
    Once training progress reaches 100% the multiplier is 0.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr (`float`):
            Lower bound applied to the lr multiplier during the warmup phase only.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of cosine cycles (and therefore hard restarts) after warmup.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            warmup_scale = float(current_step) / float(max(1, num_warmup_steps))
            return max(min_lr, warmup_scale)
        steps_done = float(current_step - num_warmup_steps)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = steps_done / decay_span
        if progress >= 1.0:
            return 0.0
        # Position inside the current cycle, in [0, 1); wrapping to 0 is the hard restart.
        cycle_position = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps,
    num_training_steps,
    min_lr: float,
    lr_end=1e-7,
    power=1.0,
    last_epoch=-1,
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which the lr multiplier increases
    linearly towards 1.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr (`float`):
            Lower bound applied to the lr multiplier during the warmup phase only.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR, used from `num_training_steps` onwards.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If the optimizer's default lr is not strictly greater than `lr_end`.
    """
    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(
            f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})"
        )

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return max(min_lr, float(current_step) / float(max(1, num_warmup_steps)))
        # `>=` rather than `>`: at exactly num_training_steps the decay has finished, and taking this
        # branch keeps the division below from ever seeing decay_steps == 0 (which happened when
        # num_warmup_steps == num_training_steps).
        if current_step >= num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        # Here num_warmup_steps <= current_step < num_training_steps, so decay_steps >= 1.
        lr_range = lr_init - lr_end
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining ** power + lr_end
        return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
# endregion
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    total_training_steps: Optional[int] = None,
    min_lr: float = 1e-6,
    min_lr_scale: float = 0,
    num_cycles: int = 1,
    power: float = 1.0,
    factor: float = 0.5,
    scale_pos: float = 0.5,
    unet_lr: float = 1.0,
    tenc_lr: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps. Required by the `*_WITH_WARMUP`, `COSINE`, `COSINE_WITH_RESTARTS` and
            `POLYNOMIAL` schedulers.
        total_training_steps (`int`, *optional*):
            The number of training steps. Required by every scheduler except `CONSTANT_WITH_WARMUP`.
        min_lr (`float`, *optional*, defaults to 1e-6):
            The minimum learning rate (`eta_min`) of the `COSINE_ANNEALING` and
            `COSINE_ANNEALING_WITH_RESTARTS` schedulers.
        min_lr_scale (`float`, *optional*, defaults to 0):
            Lower bound of the lr multiplier during warmup, forwarded as `min_lr` to the warmup schedulers.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of cycles forwarded to the `COSINE` and `COSINE_WITH_RESTARTS` schedulers.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        factor (`float`, *optional*, defaults to 0.5):
            Multiplication factor for the `CONSTANT` and `LINEAR` schedulers.
        scale_pos (`float`, *optional*, defaults to 0.5):
            If a lr scheduler has an adjustment point, this is the fraction of training steps at which to
            adjust the LR.
        unet_lr (`float`, *optional*, defaults to 1.0):
            Accepted for interface compatibility; not used by this function.
        tenc_lr (`float`, *optional*, defaults to 1.0):
            Accepted for interface compatibility; not used by this function.

    Return:
        The learning rate scheduler instance matching `name`.

    Raises:
        ValueError: If `name` is not a valid `SchedulerType`, or if an argument required by the selected
            scheduler is left unset.
    """
    name = SchedulerType(name)

    warmup_schedulers = (
        SchedulerType.CONSTANT_WITH_WARMUP,
        SchedulerType.LINEAR_WITH_WARMUP,
        SchedulerType.COSINE_WITH_RESTARTS,
        SchedulerType.POLYNOMIAL,
        SchedulerType.COSINE,
    )
    break_point_schedulers = (
        SchedulerType.CONSTANT,
        SchedulerType.LINEAR,
        SchedulerType.COSINE_ANNEALING,
        SchedulerType.COSINE_ANNEALING_WITH_RESTARTS,
    )

    # Validate up front so a missing argument fails with a clear message instead of a TypeError raised
    # from deep inside a schedule, and so schedulers that do not need an argument are not rejected for it.
    if name != SchedulerType.CONSTANT_WITH_WARMUP and total_training_steps is None:
        raise ValueError(
            f"{name} requires `total_training_steps`, please provide that argument."
        )
    if name in warmup_schedulers and num_warmup_steps is None:
        raise ValueError(
            f"{name} requires `num_warmup_steps`, please provide that argument."
        )

    # Newer schedulers
    if name in break_point_schedulers:
        break_steps = int(total_training_steps * scale_pos)

        if name == SchedulerType.CONSTANT:
            return get_constant_schedule(optimizer, factor, break_steps)

        if name == SchedulerType.LINEAR:
            return get_linear_schedule(optimizer, factor, break_steps)

        if name == SchedulerType.COSINE_ANNEALING:
            # A zero-length period makes the torch scheduler divide by zero on its first step.
            return get_cosine_annealing_scheduler(
                optimizer, max(1, break_steps), min_lr
            )

        # COSINE_ANNEALING_WITH_RESTARTS: the first cycle length must be a positive integer.
        return get_cosine_annealing_warm_restarts_scheduler(
            optimizer, max(1, int(break_steps / 2)), eta_min=min_lr
        )

    # OG schedulers
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, min_lr=min_lr_scale
        )

    if name == SchedulerType.LINEAR_WITH_WARMUP:
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps, total_training_steps, min_lr=min_lr_scale
        )

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_training_steps,
            min_lr=min_lr_scale,
            num_cycles=num_cycles,
        )

    if name == SchedulerType.POLYNOMIAL:
        return get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_training_steps,
            min_lr=min_lr_scale,
            power=power,
        )

    if name == SchedulerType.COSINE:
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_training_steps,
            min_lr=min_lr_scale,
            num_cycles=num_cycles,
        )

    if name == SchedulerType.REX:
        return get_rex_scheduler(
            optimizer,
            total_training_steps=total_training_steps,
        )

    # Only reachable if a SchedulerType member is added without a matching branch above.
    raise ValueError(f"Unsupported scheduler type: {name}")
class UniversalScheduler:
    """
    Thin wrapper around the scheduler returned by `get_scheduler` that keeps its own step counter and
    forwards the usual scheduler methods (`step`, `state_dict`, `load_state_dict`, `get_last_lr`, `get_lr`).
    """

    def __init__(
        self,
        name: Union[str, SchedulerType],
        optimizer: Optional[Optimizer],
        num_warmup_steps: int,
        total_training_steps: int,
        total_epochs: int,
        num_cycles: int = 1,
        power: float = 1.0,
        factor: float = 0.5,
        min_lr: float = 1e-6,
        scale_pos: float = 0.5,
        unet_lr: float = 1.0,
        tenc_lr: float = 1.0,
    ):
        """
        Args:
            name (`str` or `SchedulerType`):
                The name of the scheduler to use.
            optimizer (`torch.optim.Optimizer`):
                The optimizer whose learning rate is scheduled.
            num_warmup_steps (`int`):
                The number of warmup steps, forwarded to `get_scheduler`.
            total_training_steps (`int`):
                The number of training steps, forwarded to `get_scheduler`.
            total_epochs (`int`):
                The number of epochs; stored as `total_steps` for the warmup-based schedulers.
            num_cycles, power, factor, min_lr, scale_pos, unet_lr, tenc_lr:
                Forwarded unchanged to `get_scheduler`.

        Raises:
            ValueError: If `name` is not a valid `SchedulerType`.
        """
        self.current_step = 0
        og_schedulers = [
            "constant_with_warmup",
            "linear_with_warmup",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
        ]

        # Compare by string value so that a SchedulerType member is classified exactly like its name;
        # a plain `name in og_schedulers` is always False for enum members.
        scheduler_name = SchedulerType(name).value
        self.is_torch_scheduler = scheduler_name in og_schedulers

        self.total_steps = total_training_steps if not self.is_torch_scheduler else total_epochs

        self.scheduler = get_scheduler(
            name=name,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            total_training_steps=total_training_steps,
            min_lr=min_lr,
            num_cycles=num_cycles,
            power=power,
            factor=factor,
            scale_pos=scale_pos,
            unet_lr=unet_lr,
            tenc_lr=tenc_lr,
        )

    def step(self, steps: int = 1, is_epoch: bool = False):
        """
        Advance the internal counter by `steps` and move the wrapped scheduler to that position.

        `is_epoch` is accepted for interface compatibility; the counter is advanced the same way
        whether or not it is set.
        """
        self.current_step += steps
        self.scheduler.step(self.current_step)

    def state_dict(self) -> dict:
        """Return the wrapped scheduler's state dict."""
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Load `state_dict` into the wrapped scheduler."""
        self.scheduler.load_state_dict(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return the wrapped scheduler's `get_last_lr()` result."""
        return self.scheduler.get_last_lr()

    def get_lr(self) -> List[float]:
        """Return the wrapped scheduler's `get_lr()` result."""
        return self.scheduler.get_lr()
# Temp conditional for dadapt optimizer console logging
def log_dadapt(disable: bool = True):
    """
    Return the `log_every` value handed to the D-Adaptation optimizers.

    Args:
        disable (`bool`, *optional*, defaults to `True`):
            When truthy the result is 0, otherwise 5.
    """
    return 0 if disable else 5
def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, params_to_optimize):
    """
    Instantiate the optimizer selected by name.

    The optimizer classes are imported lazily so that only the package backing the selected optimizer
    has to be installed. If the import or the construction fails for any reason, the error is reported
    and torch's `AdamW` is returned instead. Names that match none of the branches (for example
    "Torch AdamW") also yield torch's `AdamW`.

    Args:
        optimizer (`str`):
            Display name of the optimizer, e.g. "8bit AdamW", "Lion" or "Prodigy".
        learning_rate (`float`):
            Learning rate passed to the optimizer as `lr`.
        weight_decay (`float`):
            Weight decay passed to the optimizer.
        params_to_optimize:
            Parameters (or parameter groups) handed to the optimizer constructor.

    Return:
        The constructed optimizer instance.
    """
    try:
        if optimizer == "Adafactor":
            from transformers.optimization import Adafactor
            return Adafactor(
                params_to_optimize,
                lr=learning_rate,
                clip_threshold=1.0,
                decay_rate=-0.8,
                weight_decay=weight_decay,
                relative_step=False,
                scale_parameter=True,
                warmup_init=False,
            )

        elif optimizer == "CAME":
            from pytorch_optimizer import CAME
            return CAME(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                weight_decouple=True,
                fixed_decay=False,
                clip_threshold=1.0,
                ams_bound=False,
            )

        elif optimizer == "8bit AdamW":
            from bitsandbytes.optim import AdamW8bit
            return AdamW8bit(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                percentile_clipping=100,
                min_8bit_size=4096,
                block_wise=True,
                amsgrad=False,
                is_paged=False,
            )

        elif optimizer == "Paged 8bit AdamW":
            from bitsandbytes.optim import PagedAdamW8bit
            # NOTE(review): confirm the installed bitsandbytes version accepts a `paged` keyword here;
            # if it does not, the resulting TypeError is swallowed below and torch AdamW is used instead.
            return PagedAdamW8bit(
                params_to_optimize,
                lr=learning_rate,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=weight_decay,
                percentile_clipping=100,
                block_wise=True,
                amsgrad=False,
                paged=True,
            )

        elif optimizer == "Apollo":
            from pytorch_optimizer import Apollo
            return Apollo(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                weight_decay_type="l2",
                init_lr=None,
                rebound="constant",
            )

        elif optimizer == "Lion":
            from pytorch_optimizer import Lion
            return Lion(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                weight_decouple=True,
                fixed_decay=False,
                use_gc=False,
                adanorm=False,
            )

        elif optimizer == "8bit Lion":
            from bitsandbytes.optim import Lion8bit
            return Lion8bit(
                params_to_optimize,
                lr=learning_rate,
                betas=(0.9, 0.99),
                weight_decay=weight_decay,
                is_paged=False,
                percentile_clipping=100,
                block_wise=True,
                min_8bit_size=4096,
            )

        elif optimizer == "Paged 8bit Lion":
            from bitsandbytes.optim import PagedLion8bit
            # NOTE(review): confirm the installed bitsandbytes version accepts an `is_paged` keyword
            # here; if it does not, the TypeError is swallowed below and torch AdamW is used instead.
            return PagedLion8bit(
                params_to_optimize,
                lr=learning_rate,
                betas=(0.9, 0.99),
                # Honour the caller's weight decay like every other branch (was hard-coded to 0).
                weight_decay=weight_decay,
                percentile_clipping=100,
                block_wise=True,
                is_paged=True,
                min_8bit_size=4096,
            )

        elif optimizer == "AdamW Dadaptation":
            from dadaptation import DAdaptAdam
            return DAdaptAdam(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                decouple=True,
                use_bias_correction=True,
                log_every=log_dadapt(True),
                fsdp_in_use=False,
            )

        elif optimizer == "Lion Dadaptation":
            from dadaptation import DAdaptLion
            return DAdaptLion(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                log_every=log_dadapt(True),
                fsdp_in_use=False,
                d0=0.000001,
            )

        elif optimizer == "Adan Dadaptation":
            from dadaptation import DAdaptAdan
            return DAdaptAdan(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                log_every=log_dadapt(True),
                no_prox=False,
                d0=0.000001,
            )

        elif optimizer == "AdanIP Dadaptation":
            from dadaptation.experimental import DAdaptAdanIP
            return DAdaptAdanIP(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                log_every=log_dadapt(True),
                no_prox=False,
                d0=0.000001,
            )

        elif optimizer == "SGD Dadaptation":
            from dadaptation import DAdaptSGD
            return DAdaptSGD(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                log_every=log_dadapt(True),
                momentum=0.0,
                fsdp_in_use=False,
                d0=0.000001,
            )

        elif optimizer == "Prodigy":
            from pytorch_optimizer import Prodigy
            return Prodigy(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                safeguard_warmup=False,
                d0=1e-6,
                d_coef=1.0,
                bias_correction=False,
                fixed_decay=False,
                weight_decouple=True,
            )

        elif optimizer == "Sophia":
            from pytorch_optimizer import SophiaH
            return SophiaH(
                params_to_optimize,
                lr=learning_rate,
                weight_decay=weight_decay,
                weight_decouple=True,
                fixed_decay=False,
                hessian_distribution="gaussian",
                p=0.01,
            )

        elif optimizer == "Tiger":
            from pytorch_optimizer import Tiger
            return Tiger(
                params_to_optimize,
                lr=learning_rate,
                beta=0.965,
                # Honour the caller's weight decay like every other branch (was hard-coded to 0.01).
                weight_decay=weight_decay,
                weight_decouple=True,
                fixed_decay=False,
            )

    except Exception as e:
        # Deliberate best-effort: a missing package or an incompatible optimizer version must not
        # abort training, so report the problem and fall through to the torch default below.
        logger.warning(f"Exception importing {optimizer}: {e}")
        traceback.print_exc()
        print(str(e))
        print("WARNING: Using default optimizer (AdamW from Torch)")

    from torch.optim import AdamW
    return AdamW(
        params_to_optimize,
        lr=learning_rate,
        weight_decay=weight_decay,
    )
def get_noise_scheduler(args):
    """
    Load the noise scheduler selected by `args.noise_scheduler` from the pretrained model.

    "DEIS" selects `DEISMultistepScheduler`, "UniPC" selects `UniPCMultistepScheduler`, and any other
    value selects `DDPMScheduler`. The scheduler is loaded with `from_pretrained` from the "scheduler"
    subfolder of the path returned by `args.get_pretrained_model_name_or_path()`.
    """
    scheduler_class = DDPMScheduler
    if args.noise_scheduler == "DEIS":
        scheduler_class = DEISMultistepScheduler
    elif args.noise_scheduler == "UniPC":
        scheduler_class = UniPCMultistepScheduler

    model_path = args.get_pretrained_model_name_or_path()
    return scheduler_class.from_pretrained(model_path, subfolder="scheduler")