Replace some optimizers

pull/1244/head
ArrowM 2023-05-29 15:37:14 -05:00
parent 912d86704a
commit 8541193626
11 changed files with 27 additions and 549 deletions

View File

@@ -1,241 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim
import pdb
import logging
import os

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


def to_real(x):
    if torch.is_complex(x):
        return x.real
    else:
        return x


class DAdaptAdan(torch.optim.Optimizer):
    r"""
    Implements Adan with D-Adaptation automatic step-sizes. Leave LR set to 1 unless you encounter instability.

    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0.02).
        no_prox (boolean):
            How to perform the decoupled weight decay (default: False).
        log_every (int):
            Log using print every k steps, default 0 (no logging).
        d0 (float):
            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float):
            Prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
            rate warmup effect.
    """

    def __init__(self, params, lr=1.0,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.02,
                 no_prox=False,
                 log_every=0, d0=1e-6,
                 growth_rate=float('inf')):
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        no_prox=no_prox,
                        d=d0,
                        k=0,
                        gsq_weighted=0.0,
                        log_every=log_every,
                        growth_rate=growth_rate)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    # Experimental implementation of Adan's restart strategy
    @torch.no_grad()
    def restart_opt(self):
        for group in self.param_groups:
            group['gsq_weighted'] = 0.0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization
                    state['step'] = 0
                    state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        g_sq = 0.0
        sksq_weighted = 0.0
        sk_l1 = 0.0

        ngroups = len(self.param_groups)

        group = self.param_groups[0]
        gsq_weighted = group['gsq_weighted']
        d = group['d']
        lr = group['lr']
        dlr = d * lr

        no_prox = group['no_prox']
        growth_rate = group['growth_rate']
        log_every = group['log_every']
        beta1, beta2, beta3 = group['betas']

        for group in self.param_groups:
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                # State initialization
                if 'step' not in state:
                    state['step'] = 0
                    state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach()

                if state['step'] == 0:
                    # Previous gradient values
                    state['pre_grad'] = grad.clone()

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                grad_diff = grad - state['pre_grad']
                grad_grad = to_real(grad * grad.conj())
                update = grad + beta2 * grad_diff
                update_update = to_real(update * update.conj())

                exp_avg.mul_(beta1).add_(grad, alpha=dlr * (1. - beta1))
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=dlr * (1. - beta2))
                exp_avg_sq.mul_(beta3).add_(update_update, alpha=1. - beta3)

                denom = exp_avg_sq.sqrt().add_(eps)
                g_sq += grad_grad.div_(denom).sum().item()

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=dlr * (1. - beta3))
                sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
                sk_l1 += s.abs().sum().item()

        ######

        gsq_weighted = beta3 * gsq_weighted + g_sq * (dlr ** 2) * (1 - beta3)
        d_hat = d

        # If we have not made any progress, return.
        # If any gradients are available, sk_l1 > 0 (unless \|g\| = 0).
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat = (sksq_weighted / (1 - beta3) - gsq_weighted) / sk_l1
            d = max(d, min(d_hat, d * growth_rate))

        if log_every > 0 and k % log_every == 0:
            print(f"ng: {ngroups} lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sksq_weighted={sksq_weighted:1.1e} sk_l1={sk_l1:1.1e} gsq_weighted={gsq_weighted:1.1e}")

        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['d'] = d

            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                state['step'] += 1

                denom = exp_avg_sq.sqrt().add_(eps)
                denom = denom.type(p.type())

                update = (exp_avg + beta2 * exp_avg_diff).div_(denom)

                # Take step
                if no_prox:
                    p.data.mul_(1 - dlr * decay)
                    p.add_(update, alpha=-1)
                else:
                    p.add_(update, alpha=-1)
                    p.data.div_(1 + dlr * decay)

                state['pre_grad'].copy_(grad)

            group['k'] = k + 1

        return loss
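
Everything deleted above funnels into a single scalar decision at the end of step(): the new D estimate. A minimal sketch of that update with made-up statistics, purely to show the mechanics of the growth cap:

# Toy sketch of the D-Adaptation estimate update from step() above.
# All statistic values are illustrative, not from any real run.
beta3 = 0.99
growth_rate = 1.02          # caps how fast d may grow per step
d, lr = 1e-6, 1.0           # d0 and the user-facing LR (normally left at 1)

sksq_weighted = 4e-9        # sum over params of |s|^2 / denom
gsq_weighted = 1e-9         # running weighted gradient-norm term
sk_l1 = 2e-3                # L1 norm of the s accumulator

d_hat = (sksq_weighted / (1 - beta3) - gsq_weighted) / sk_l1
d = max(d, min(d_hat, d * growth_rate))   # d never shrinks; growth is capped
print(f"d_hat={d_hat:1.3e} d={d:1.3e}")   # here d grows by at most 2%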

View File

@@ -1,264 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim
import pdb
import logging
import os

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


def to_real(x):
    if torch.is_complex(x):
        return x.real
    else:
        return x


class DAdaptAdanIP(torch.optim.Optimizer):
    r"""
    Implements Adan with D-Adaptation automatic step-sizes. Leave LR set to 1 unless you encounter instability.

    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    This IP variant uses a tighter bound than the non-IP version,
    and so will typically choose larger step sizes. It has not
    been as extensively tested.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0.02).
        no_prox (boolean):
            How to perform the decoupled weight decay (default: False).
        log_every (int):
            Log using print every k steps, default 0 (no logging).
        d0 (float):
            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float):
            Prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
            rate warmup effect.
    """

    def __init__(self, params, lr=1.0,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.02,
                 no_prox=False,
                 log_every=0, d0=1e-6,
                 growth_rate=float('inf')):
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 2: {}".format(betas[2]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        no_prox=no_prox,
                        d=d0,
                        k=0,
                        numerator_weighted=0.0,
                        log_every=log_every,
                        growth_rate=growth_rate)
        self.d0 = d0
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    # Experimental implementation of Adan's restart strategy
    @torch.no_grad()
    def restart_opt(self):
        for group in self.param_groups:
            group['numerator_weighted'] = 0.0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization
                    state['step'] = 0
                    state['s'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(
                        to_real(p.data), memory_format=torch.preserve_format).detach()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        g_sq = 0.0
        sksq_weighted = 0.0
        sk_l1 = 0.0

        ngroups = len(self.param_groups)

        group = self.param_groups[0]
        numerator_weighted = group['numerator_weighted']
        d = group['d']
        lr = group['lr']
        dlr = d * lr

        no_prox = group['no_prox']
        growth_rate = group['growth_rate']
        log_every = group['log_every']
        beta1, beta2, beta3 = group['betas']

        numerator_acum = 0.0

        for group in self.param_groups:
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                # State initialization
                if 'step' not in state:
                    state['step'] = 0
                    state['s'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(
                        to_real(p.data), memory_format=torch.preserve_format).detach()

                if state['step'] == 0:
                    # Previous gradient values
                    state['pre_grad'] = grad.clone()

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                grad_diff = grad - state['pre_grad']
                update = grad + beta2 * grad_diff
                update_update = to_real(update * update.conj())

                s = state['s']
                denom = exp_avg_sq.sqrt().add_(eps)
                numerator_acum += dlr * torch.dot(grad.flatten(), s.div(denom).flatten())

                exp_avg.mul_(beta1).add_(grad, alpha=dlr * (1. - beta1))
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=dlr * (1. - beta2))
                exp_avg_sq.mul_(beta3).add_(update_update, alpha=1. - beta3)

                s.mul_(beta3).add_(grad, alpha=dlr * (1. - beta3))
                sk_l1 += s.abs().sum().item()

        ######

        numerator_weighted = beta3 * numerator_weighted + (1 - beta3) * numerator_acum
        d_hat = d

        # If we have not made any progress, return.
        # If any gradients are available, sk_l1 > 0 (unless \|g\| = 0).
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat = 2 * (beta3 / (1 - beta3)) * numerator_weighted / sk_l1
            d = max(d, min(d_hat, d * growth_rate))

        if log_every > 0 and k % log_every == 0:
            print(
                f"ng: {ngroups} lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={sk_l1:1.1e} numerator_weighted={numerator_weighted:1.1e}")

        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                state['step'] += 1

                denom = exp_avg_sq.sqrt().add_(eps)
                denom = denom.type(p.type())

                update = (exp_avg + beta2 * exp_avg_diff).div_(denom)

                # Take step
                if no_prox:
                    p.data.mul_(1 - dlr * decay)
                    p.add_(update, alpha=-1)
                else:
                    p.add_(update, alpha=-1)
                    p.data.div_(1 + dlr * decay)

                state['pre_grad'].copy_(grad)

            group['k'] = k + 1

        return loss
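
The vendored classes above (and the dadaptation package versions this commit switches to) follow the standard torch.optim.Optimizer protocol, so usage is ordinary. A minimal sketch, assuming the dadaptation package is installed; the model and training loop are placeholders:

import torch
from dadaptation import DAdaptAdan

model = torch.nn.Linear(10, 1)   # placeholder model
# Per the docstring above: leave lr at 1.0 and let D-Adaptation find the step size.
opt = DAdaptAdan(model.parameters(), lr=1.0, weight_decay=0.02)

for _ in range(3):
    x = torch.randn(8, 10)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()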

View File

@@ -2,8 +2,9 @@ import json
 import logging
 import os
 import traceback
-from typing import List, Dict
 from pathlib import Path
+from typing import List, Dict
+from pydantic import BaseModel
 from dreambooth import shared # noqa

View File

@@ -49,7 +49,7 @@ class SchedulerType(Enum):
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
-def get_dadapt_with_warmup(optimizer, num_warmup_steps: int=0, unet_lr: int=1.0, tenc_lr: int=1.0):
+def get_dadapt_with_warmup(optimizer, num_warmup_steps: int = 0, unet_lr: float = 1.0, tenc_lr: float = 1.0):
     """
     Adjust LR from initial rate to the minimum specified LR over the maximum number of steps.
     See <a href='https://miro.medium.com/max/828/1*Bk4xhtvg_Su42GmiVtvigg.webp'> for an example.
@@ -59,16 +59,16 @@ def get_dadapt_with_warmup(optimizer, num_warmup_steps: int=0, unet_lr: int=1.0, tenc_lr: int=1.0):
         num_warmup_steps (`int`, *optional*, defaults to 0):
             The number of steps for the warmup phase.
         unet_lr (`float`, *optional*, defaults to 1.0):
-            The learning rate used to to control d-dadaption for the UNET
+            The learning rate used to control d-adaptation for the UNET
         tenc_lr (`float`, *optional*, defaults to 1.0):
-            The learning rate used to to control d-dadaption for the TENC
+            The learning rate used to control d-adaptation for the TENC
     Return:
         `torch.optim.lr_scheduler.LambdaLR` with the appropriate LR schedules for TENC and UNET.
     """
     def unet_lambda(current_step: int):
         if current_step < num_warmup_steps:
-            return (float(current_step) / float(max(unet_lr, num_warmup_steps)))
+            return float(current_step) / float(max(unet_lr, num_warmup_steps))
         else:
             return unet_lr
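
The matching tenc_lambda falls outside this hunk; by symmetry it presumably mirrors unet_lambda with tenc_lr. A sketch of how a pair of lambdas like these would be wired into torch.optim.lr_scheduler.LambdaLR, one per param group (the tiny optimizer here is a stand-in):

import torch

num_warmup_steps, unet_lr, tenc_lr = 100, 1.0, 1.0

def unet_lambda(current_step: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(unet_lr, num_warmup_steps))
    return unet_lr

def tenc_lambda(current_step: int):
    # Assumed mirror of unet_lambda; the real body is outside this hunk.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(tenc_lr, num_warmup_steps))
    return tenc_lr

p1, p2 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([{"params": [p1]}, {"params": [p2]}], lr=1.0)
# One lambda per param group: UNET first, then TENC.
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, [unet_lambda, tenc_lambda])
for _ in range(3):
    opt.step()
    scheduler.step()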
@@ -432,9 +432,9 @@ def get_scheduler(
             If a lr scheduler has an adjustment point, this is the percentage of training steps at which to
             adjust the LR.
         unet_lr (`float`, *optional*, defaults to 1e-6):
-            The learning rate used to to control d-dadaption for the UNET
+            The learning rate used to control d-adaptation for the UNET
         tenc_lr (`float`, *optional*, defaults to 1e-6):
-            The learning rate used to to control d-dadaption for the TENC
+            The learning rate used to control d-adaptation for the TENC
     """
@@ -565,10 +565,10 @@ class UniversalScheduler:
         return self.scheduler.get_lr()
-#Temp conditional for dadapt optimizer console logging
+# Temp conditional for dadapt optimizer console logging
 def log_dadapt(disable: bool = True):
     if disable:
-        return 0
+        return 0
     else:
         return 5
@@ -583,14 +583,6 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
             weight_decay=weight_decay,
         )
-    elif optimizer == "Lion":
-        from lion_pytorch import Lion
-        return Lion(
-            params_to_optimize,
-            lr=learning_rate,
-            weight_decay=weight_decay,
-        )
     elif optimizer == "AdamW Dadaptation":
         from dadaptation import DAdaptAdam
         return DAdaptAdam(
@@ -601,18 +593,18 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
             log_every=log_dadapt(True)
         )
-    elif optimizer == "AdanIP Dadaptation":
-        from dreambooth.dadapt_adan_ip import DAdaptAdanIP
-        return DAdaptAdanIP(
+    elif optimizer == "Adan Dadaptation":
+        from dadaptation import DAdaptAdan
+        return DAdaptAdan(
             params_to_optimize,
             lr=learning_rate,
             weight_decay=weight_decay,
-            log_every=log_dadapt(True)
+            log_every=log_dadapt(True),
         )
-    elif optimizer == "Adan Dadaptation":
-        from dreambooth.dadapt_adan import DAdaptAdan
-        return DAdaptAdan(
+    elif optimizer == "SGD Dadaptation":
+        from dadaptation import DAdaptSGD
+        return DAdaptSGD(
             params_to_optimize,
             lr=learning_rate,
             weight_decay=weight_decay,
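
After this hunk every D-Adaptation label resolves to the dadaptation package rather than the vendored modules. A hypothetical call into the dispatcher; the truncated final parameter of get_optimizer is assumed here to be the params_to_optimize iterable referenced in its body:

import torch

model = torch.nn.Linear(4, 4)   # stand-in for the real UNET/TENC parameters
opt = get_optimizer(
    "Adan Dadaptation",
    learning_rate=1.0,           # D-Adaptation variants expect lr near 1
    weight_decay=0.02,
    params_to_optimize=model.parameters(),   # assumed parameter name
)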

View File

@@ -123,12 +123,6 @@ def list_optimizer():
     except:
         pass
-    try:
-        from lion_pytorch import Lion
-        optimizer_list.append("Lion")
-    except:
-        pass
     try:
         from dadaptation import DAdaptAdam
         optimizer_list.append("AdamW Dadaptation")
@@ -136,14 +130,14 @@ def list_optimizer():
         pass
     try:
-        from dreambooth.dadapt_adan import DAdaptAdan
+        from dadaptation import DAdaptAdan
         optimizer_list.append("Adan Dadaptation")
     except:
         pass
     try:
-        from dreambooth.dadapt_adan_ip import DAdaptAdanIP
-        optimizer_list.append("AdanIP Dadaptation")
+        from dadaptation import DAdaptSGD
+        optimizer_list.append("SGD Dadaptation")
     except:
         pass
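
Each availability check above uses the same idiom: attempt the import, append the UI label, swallow the failure. A self-contained sketch of that pattern, catching ImportError explicitly instead of the bare except used in the file:

optimizer_list = []
try:
    from dadaptation import DAdaptSGD  # noqa: F401 -- import check only
    optimizer_list.append("SGD Dadaptation")
except ImportError:
    pass
print(optimizer_list)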

View File

@@ -19,7 +19,7 @@ from dreambooth.utils import image_utils
 from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
 from dreambooth.utils.model_utils import get_checkpoint_match, \
     reload_system_models, \
-    enable_safe_unpickle, disable_safe_unpickle, unload_system_models, xformerify
+    enable_safe_unpickle, disable_safe_unpickle, unload_system_models
 from helpers.mytqdm import mytqdm
 from lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
     get_target_module

View File

@@ -8,7 +8,6 @@ from PIL import Image
 from matplotlib import axes
 from pandas import DataFrame
 from pandas.plotting._matplotlib.style import get_standard_colors
-from tensorboard.compat.proto import event_pb2
 from dreambooth.shared import status

View File

@@ -391,7 +391,7 @@ let db_titles = {
     "Number of Hard Resets": "Number of hard resets of the lr in cosine_with_restarts scheduler.",
     "Number of Samples to Generate": "How many samples to generate per subject.",
     "Offset Noise": "Allows the model to learn brightness and contrast with greater detail during training. Value controls the strength of the effect, 0 disables it.",
-    "Optimizer": "Optimizer algorithm.",
+    "Optimizer": "Optimizer algorithm.\nRecommended settings (LR = Learning Rate, WD = Weight Decay):\nTorch / 8Bit AdamW - LR: 2e-6, WD: 0.01\nAdamW Adapt - LR: 0.05, WD: 0\nSGD Adapt - LR: 1, WD: 0\nAdan Adapt - LR: 0.2, WD: 0.01",
     "Pad Tokens": "Pad the input images token length to this amount. You probably want to do this.",
     "Pause After N Epochs": "Number of epochs after which training will be paused for the specified time. Useful if you want to give your GPU a rest.",
     "Performance Wizard (WIP)": "Attempt to automatically set training parameters based on total VRAM. Still under development.",

View File

@@ -9,14 +9,14 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Union, Dict
 import torch
-from fastapi import FastAPI
-import scripts.api
 from core.handlers.config import ConfigHandler
 from core.handlers.models import ModelHandler, ModelManager
 from core.handlers.status import StatusHandler
 from core.handlers.websocket import SocketHandler
 from core.modules.base.module_base import BaseModule
+from fastapi import FastAPI
+import scripts.api
 from dreambooth import shared
 from dreambooth.dataclasses.db_config import DreamboothConfig, from_file
 from dreambooth.sd_to_diff import extract_checkpoint

View File

@ -1,8 +1,6 @@
import logging
import os.path
import re
from collections import defaultdict
logger = logging.getLogger(__name__)

View File

@@ -5,11 +5,10 @@ diffusers~=0.16.1
 discord-webhook~=1.1.0
 fastapi~=0.94.1
 gitpython~=3.1.31
-lion-pytorch~=0.1.2
 Pillow==9.5.0
-tqdm~=4.64.1
+tqdm==4.65.0
 tomesd~=0.1.2
-transformers~=4.28.1 # > 4.26.x causes issues (db extension #1110)
+transformers~=4.29.2 # > 4.26.x causes issues (db extension #1110)
 # Tensor
 tensorboard==2.13.0; sys_platform != 'darwin' or platform_machine != 'arm64'