Commit to see if Dev branch is broken

SDXLOptimizers
saunderez 2023-09-06 23:37:54 +10:00
parent f5fa07bddc
commit 6fce4da365
9 changed files with 273 additions and 1596 deletions

View File

@ -50,6 +50,7 @@ class SchedulerType(Enum):
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
"""
Returns a learning rate scheduler based on the REX (Reflected Exponential) schedule.
@ -61,6 +62,7 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
Returns:
A PyTorch LambdaLR scheduler that wraps the given optimizer and scales its learning rate by the REX multiplier computed in lr_lambda.
"""
def lr_lambda(current_step: int):
# https://arxiv.org/abs/2107.04197
max_lr = 1
@ -68,7 +70,7 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
d = 0.9
if current_step < total_training_steps:
progress = (current_step / total_training_steps)
progress = current_step / total_training_steps
div = (1 - d) + (d * (1 - progress))
return min_lr + (max_lr - min_lr) * ((1 - progress) / div)
else:
@ -77,11 +79,9 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
return LambdaLR(optimizer, lr_lambda)
# region Newer Schedulers
def get_cosine_annealing_scheduler(
optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
):
"""
Adjust LR from initial rate to the minimum specified LR over the maximum number of steps.
@ -101,7 +101,7 @@ def get_cosine_annealing_scheduler(
def get_cosine_annealing_warm_restarts_scheduler(
optimizer: Optimizer, t_0: int = 25, t_mult: int = 1, eta_min: float = 1e-6
optimizer: Optimizer, t_0: int = 25, t_mult: int = 1, eta_min: float = 1e-6
):
"""
Adjust LR from initial rate to the minimum specified LR over the maximum number of steps.
@ -125,7 +125,7 @@ def get_cosine_annealing_warm_restarts_scheduler(
def get_linear_schedule(
optimizer: Optimizer, start_factor: float = 0.5, total_iters: int = 500
optimizer: Optimizer, start_factor: float = 0.5, total_iters: int = 500
):
"""
Create a schedule with a learning rate that decreases at a linear rate until it reaches the number of total iters,
@ -147,7 +147,7 @@ def get_linear_schedule(
def get_constant_schedule(
optimizer: Optimizer, factor: float = 1.0, total_iters: int = 500
optimizer: Optimizer, factor: float = 1.0, total_iters: int = 500
):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
@ -168,9 +168,10 @@ def get_constant_schedule(
# endregion
# region originals
def get_constant_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, min_lr: float
optimizer: Optimizer, num_warmup_steps: int, min_lr: float
):
"""
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
@ -198,7 +199,7 @@ def get_constant_schedule_with_warmup(
def get_linear_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1
optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1
):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
@ -234,12 +235,12 @@ def get_linear_schedule_with_warmup(
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
@ -279,12 +280,12 @@ def get_cosine_schedule_with_warmup(
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr: float,
num_cycles: int = 1,
last_epoch: int = -1,
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr: float,
num_cycles: int = 1,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
@ -326,13 +327,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
def get_polynomial_decay_schedule_with_warmup(
optimizer,
num_warmup_steps,
num_training_steps,
min_lr: float,
lr_end=1e-7,
power=1.0,
last_epoch=-1,
optimizer,
num_warmup_steps,
num_training_steps,
min_lr: float,
lr_end=1e-7,
power=1.0,
last_epoch=-1,
):
"""
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
@ -490,11 +491,9 @@ def get_scheduler(
)
if name == SchedulerType.REX:
return get_rex_scheduler(
optimizer,
total_training_steps=total_training_steps
)
return get_rex_scheduler(optimizer, total_training_steps=total_training_steps)
class UniversalScheduler:
def __init__(
self,
@ -571,8 +570,8 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
try:
if optimizer == "Adafactor":
from transformers.optimization import Adafactor
adafactor = Adafactor(
params=params_to_optimize,
return Adafactor(
params_to_optimize,
lr=learning_rate,
clip_threshold=1.0,
decay_rate=-0.8,
@ -581,25 +580,23 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
scale_parameter=True,
warmup_init=False,
)
return adafactor
elif optimizer == "CAME":
from pytorch_optimizer import CAME
came = CAME(
params=params_to_optimize,
return CAME(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
weight_decouple=True,
fixed_decay=False,
clip_threshold=1.0,
ams_bound=False,
)
return came
)
elif optimizer == "8bit AdamW":
from bitsandbytes.optim import AdamW8bit
adamw8bit = AdamW8bit(
params=params_to_optimize,
return AdamW8bit(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
percentile_clipping=100,
@ -608,12 +605,11 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
amsgrad=False,
is_paged=False,
)
return adamw8bit
elif optimizer == "Paged 8bit AdamW":
elif optimizer == "Paged 8bit AdamW":
from bitsandbytes.optim import PagedAdamW8bit
pagedadamw8bit = PagedAdamW8bit(
params=params_to_optimize,
return PagedAdamW8bit(
params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.999),
eps=1e-8,
@ -623,51 +619,47 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
amsgrad=False,
paged=True,
)
return pagedadamw8bit
elif optimizer == "Apollo":
from pytorch_optimizer import Apollo
apollo = Apollo(
params=params_to_optimize,
return Apollo(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
eight_decay_type='l2',
eight_decay_type="l2",
init_lr=None,
rebound='constant',
rebound="constant",
)
return apollo
elif optimizer == "Lion":
from pytorch_optimizer import Lion
lion = Lion(
params=params_to_optimize,
return Lion(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
weight_decouple=True,
fixed_decay=False,
use_gc=False,
adanorm=False
adanorm=False,
)
return lion
elif optimizer == "8bit Lion":
from bitsandbytes.optim import Lion8bit
lion8bit = Lion8bit(
params=params_to_optimize,
return Lion8bit(
params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.99),
betas=(0.9, 0.99),
weight_decay=weight_decay,
is_paged=False,
percentile_clipping=100,
block_wise=True,
min_8bit_size=4096,
)
return lion8bit
elif optimizer == "Paged 8bit Lion":
from bitsandbytes.optim import PagedLion8bit
pagedLion8bit = PagedLion8bit(
params=params_to_optimize,
return PagedLion8bit(
params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.99),
weight_decay=0,
@ -676,12 +668,11 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
is_paged=True,
min_8bit_size=4096,
)
return pagedLion8bit
elif optimizer == "AdamW Dadaptation":
from dadaptation import DAdaptAdam
dadaptadam = DAdaptAdam(
params=params_to_optimize,
return DAdaptAdam(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
decouple=True,
@ -689,77 +680,33 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
log_every=log_dadapt(True),
fsdp_in_use=False,
)
return dadaptadam
elif optimizer == "Lion Dadaptation":
from dadaptation import DAdaptLion
dadaptlion = DAdaptLion(
params=params_to_optimize,
return DAdaptLion(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
log_every=log_dadapt(True),
fsdp_in_use=False,
d0=0.000001,
)
return dadaptlion
elif optimizer == "Adan Dadaptation":
from dadaptation import DAdaptAdan
dadaptadan = DAdaptAdan(
params=params_to_optimize,
return DAdaptAdan(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
log_every=log_dadapt(True),
no_prox=False,
d0=0.000001,
)
return dadaptadan
elif optimizer == "AdanIP Dadaptation":
from dadaptation.experimental import DAdaptAdanIP
dadaptadanip = DAdaptAdanIP(
params=params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
log_every=log_dadapt(True),
no_prox=False,
d0=0.000001
)
return dadaptadanip
elif optimizer == "SGD Dadaptation":
from dadaptation import DAdaptSGD
dadaptsgd = DAdaptSGD(
params=params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
log_every=log_dadapt(True),
momentum=0.0,
fsdp_in_use=False,
d0=0.000001,
)
return dadaptsgd
elif optimizer == "Prodigy":
from pytorch_optimizer import Prodigy
prodigy = Prodigy(
params=params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
safeguard_warmup=False,
d0=1e-6,
d_coef=1.0,
bias_correction=False,
fixed_decay=False,
weight_decouple=True,
)
return prodigy
elif optimizer == "Sophia":
from pytorch_optimizer import SophiaH
sophia = SophiaH(
params=params_to_optimize,
return DAdaptSGD(
params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
weight_decouple=True,
@ -767,20 +714,18 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
hessian_distribution="gaussian",
p=0.01,
)
return sophia
elif optimizer == "Tiger":
from pytorch_optimizer import Tiger
tiger = Tiger(
params=params_to_optimize,
return Tiger(
params_to_optimize,
lr=learning_rate,
beta = 0.965,
beta=0.965,
weight_decay=0.01,
weight_decouple=True,
fixed_decay=False,
)
return tiger
except Exception as e:
logger.warning(f"Exception importing {optimizer}: {e}")
traceback.print_exc()
@ -796,7 +741,6 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
)
def get_noise_scheduler(args):
if args.noise_scheduler == "DEIS":
scheduler_class = DEISMultistepScheduler

View File

@ -43,8 +43,7 @@ def load_auto_settings():
lowvram = ws.cmd_opts.lowvram
config = ws.cmd_opts.config
device = ws.device
current_epoch = 0
current_step = 0
def set_model(new_model):
global sd_model
@ -162,8 +161,7 @@ class DreamState:
time_left_force_display = False
active = False
new_ui = False
current_epoch = 0
current_step = 0
def interrupt(self):
if self.status_handler:
@ -195,8 +193,6 @@ class DreamState:
"last_status": self.textinfo,
"sample_prompts": self.sample_prompts,
"active": self.active,
"current_epoch": self.current_epoch,
"current_step": self.current_step,
}
return obj
@ -311,7 +307,7 @@ def load_vars(root_path = None):
data_path, show_progress_every_n_steps, parallel_processing_allowed, dataset_filename_word_regex, dataset_filename_join_string, \
device_id, state, disable_safe_unpickle, ckptfix, medvram, lowvram, debug, profile_db, sub_quad_q_chunk_size, sub_quad_kv_chunk_size, \
sub_quad_chunk_threshold, CLIP_stop_at_last_layers, sd_model, config, force_cpu, paths, is_auto, device, orig_tensor_to, orig_layer_norm, \
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state, current_epoch, current_step
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state
script_path = os.sep.join(__file__.split(os.sep)[0:-4]) if root_path is None else root_path
models_path = os.path.join(script_path, "models")
@ -332,8 +328,7 @@ def load_vars(root_path = None):
medvram = False
lowvram = False
debug = False
current_epoch = 0
current_step = 0
profile_db = False
sub_quad_q_chunk_size = 1024
sub_quad_kv_chunk_size = None
@ -405,8 +400,7 @@ ckptfix = False
medvram = False
lowvram = False
debug = False
current_epoch = 0
current_step = 0
profile_db = False
sub_quad_q_chunk_size = 1024
sub_quad_kv_chunk_size = None

View File

@ -13,16 +13,13 @@ import traceback
from contextlib import ExitStack
from decimal import Decimal
from pathlib import Path
from accelerate.utils.megatron_lm import prepare_scheduler
from numpy import dtype, float32
from pandas import Float32Dtype
import tomesd
import torch
import torch.backends.cuda
import torch.backends.cudnn
import torch.nn.functional as F
from accelerate import Accelerator, cpu_offload
from accelerate import Accelerator
from accelerate.utils.random import set_seed as set_seed2
from diffusers import (
AutoencoderKL,
@ -59,10 +56,9 @@ from dreambooth.utils.model_utils import (
enable_safe_unpickle,
xformerify,
torch2ify, unet_attn_processors_state_dict,
)
from dreambooth.utils.text_utils import encode_hidden_state
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
patch_accelerator_for_fp16_training)
from dreambooth.webhook import send_training_update
from dreambooth.xattention import optim_to
@ -345,12 +341,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
pbar2.update()
pbar2.set_postfix(refresh=True)
# Load models and create wrapper for stable diffusion
text_encoder = text_encoder_cls.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder",
revision=args.revision,
torch_dtype=torch.float32,
)
if args.full_mixed_precision:
text_encoder = text_encoder_cls.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder",
revision=args.revision,
torch_dtype=weight_dtype,
)
else:
text_encoder = text_encoder_cls.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder",
revision=args.revision,
torch_dtype=torch.float32,
)
if args.model_type == "SDXL":
# import correct text encoder class
@ -362,11 +366,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
pbar2.update()
pbar2.set_postfix(refresh=True)
# Load models and create wrapper for stable diffusion
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder_2",
revision=args.revision,
torch_dtype=torch.float32, )
if args.full_mixed_precision:
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder_2",
revision=args.revision,
torch_dtype=weight_dtype,
)
else:
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="text_encoder_2",
revision=args.revision,
torch_dtype=torch.float32,
)
printm("Created tenc")
pbar2.set_description("Loading VAE...")
@ -376,12 +389,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
pbar2.set_description("Loading unet...")
pbar2.update()
unet = UNet2DConditionModel.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="unet",
revision=args.revision,
torch_dtype=torch.float32,
)
if args.full_mixed_precision:
unet = UNet2DConditionModel.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="unet",
revision=args.revision,
torch_dtype=weight_dtype,
)
else:
unet = UNet2DConditionModel.from_pretrained(
args.get_pretrained_model_name_or_path(),
subfolder="unet",
revision=args.revision,
torch_dtype=torch.float32,
)
if args.attention == "xformers" and not shared.force_cpu:
xformerify(unet, use_lora=args.use_lora)
@ -469,6 +490,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
)
# Create shared unet/tenc learning rate variables
learning_rate = args.learning_rate
txt_learning_rate = args.txt_learning_rate
if args.use_lora:
@ -871,8 +893,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
modules.shared.cmd_opts.disable_safe_unpickle = no_safe
global_step = resume_step = args.revision
resume_from_checkpoint = True
first_epoch = args.lifetime_epoch
global_epoch = args.lifetime_epoch
first_epoch = args.epoch
global_epoch = first_epoch
except Exception as lex:
logger.warning(f"Exception loading checkpoint: {lex}")
logger.debug(" ***** Running training *****")
@ -891,13 +913,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
logger.debug(f" First resume step: {resume_step}")
logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}")
logger.debug(f" Min SNR Gamma: {args.min_snr_gamma}")
logger.debug(f" EMA: {args.use_ema}")
logger.debug(f" UNET: {args.train_unet}")
logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
logger.debug(f" Unet LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
logger.debug(f" Tenc LR{' (Lora)' if args.use_lora and stop_text_percentage != 0 else ''}: {tenc_learning_rate}")
logger.debug(f" Full Mixed Precision: {args.full_mixed_precision}")
logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
if stop_text_percentage > 0:
logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}")
logger.debug(f" V2: {args.v2}")
os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1)
@ -911,10 +932,14 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
save_canceled = status.interrupted
save_image = False
save_model = False
save_lora = False
if not save_canceled and not save_completed:
# Check whether the number of epochs since the last save has reached the save interval
if 0 < save_model_interval <= session_epoch - last_model_save:
save_model = True
if args.use_lora:
save_lora = True
last_model_save = session_epoch
# Repeat for sample images
@ -927,9 +952,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if global_step > 0:
save_image = True
save_model = True
save_lora = True
save_snapshot = False
save_lora = args.use_lora
if is_epoch_check:
if shared.status.do_save_samples:
@ -937,6 +962,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
shared.status.do_save_samples = False
if shared.status.do_save_model:
if args.use_lora:
save_lora = True
save_model = True
shared.status.do_save_model = False
@ -1012,7 +1039,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
printm("Creating pipeline.")
if args.model_type == "SDXL":
s_pipeline = StableDiffusionXLPipeline.from_pretrained(
args.get_pretrained_model_name_or_path(),
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
@ -1025,6 +1051,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
vae=vae.to(accelerator.device),
torch_dtype=weight_dtype,
revision=args.revision,
safety_checker=None,
requires_safety_checker=None,
)
xformerify(s_pipeline.unet,use_lora=args.use_lora)
else:
@ -1037,6 +1065,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
vae=vae,
torch_dtype=weight_dtype,
revision=args.revision,
safety_checker=None,
requires_safety_checker=None,
)
xformerify(s_pipeline.unet,use_lora=args.use_lora)
xformerify(s_pipeline.vae,use_lora=args.use_lora)
@ -1124,10 +1154,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
s_pipeline.scheduler.config)
s_pipeline.scheduler.config.solver_type = "bh2"
save_lora = False
elif save_diffusers:
# We are saving weights, we need to ensure revision is saved
args.save()
if "_tmp" not in weights_dir:
args.save()
try:
out_file = None
status.textinfo = (
@ -1527,7 +1559,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.s(latents, noise, timesteps)
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@ -1544,9 +1576,17 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if args.model_type != "SDXL":
# TODO: set a prior preservation flag and use that to ensure this only happens in dreambooth
if not args.split_loss and not with_prior_preservation:
loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(),
reduction="mean")
loss *= batch["loss_avg"]
if args.min_snr_gamma == 0.0:
loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
loss *= batch["loss_avg"]
else:
# Predict the noise residual
if model_pred.shape[1] == 6:
@ -1556,7 +1596,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
if args.min_snr_gamma == 0.0:
# Compute instance loss
@ -1570,13 +1610,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(),
reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
# Compute loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
if args.min_snr_gamma == 0.0:
# Compute loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Calculate loss with min snr
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
else:
if with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
@ -1598,27 +1649,19 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
if args.min_snr_gamma == 0.0:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
accelerator.backward(loss)
if accelerator.sync_gradients and not args.use_lora:

File diff suppressed because it is too large Load Diff

View File

@ -7,12 +7,18 @@ import random
import re
import sys
from io import StringIO
from diffusers.schedulers import KarrasDiffusionSchedulers
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from PIL import features, PngImagePlugin, Image, ExifTags
import os
from typing import List, Tuple, Dict, Union
import numpy as np
import torch
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from helpers.mytqdm import mytqdm
@ -441,18 +447,8 @@ def load_image_directory(db_dir, concept: Concept, is_class: bool = True) -> Lis
return list(zip(img_paths, captions))
def open_image(image_path: str, return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
if return_pil:
return Image.open(image_path)
else:
return np.array(Image.open(image_path))
def trim_image(image: Union[np.ndarray, Image.Image], reso: Tuple[int, int]) -> Union[np.ndarray, Image.Image]:
return image[:reso[0], :reso[1]]
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image]:
# Open image with PIL
image = Image.open(image_path)
image = rotate_image_straight(image)

View File

@ -7,8 +7,7 @@ import sys
from typing import Dict
import torch
from diffusers.utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import PretrainedConfig
from dreambooth import shared # noqa
@ -293,4 +292,4 @@ def torch2ify(unet):
return unet
def is_xformers_available():
import xformers
pass

View File

@ -123,13 +123,13 @@ def list_optimizer():
pass
try:
if shared.device.type != "mps":
from bitsandbytes.optim.adamw import AdamW8bit
from bitsandbytes.optim import AdamW8bit
optimizer_list.append("8bit AdamW")
except:
pass
try:
from bitsandbytes.optim.adamw import PagedAdamW8bit
from bitsandbytes.optim import PagedAdamW8bit
optimizer_list.append("Paged 8bit AdamW")
except:
pass
@ -171,13 +171,13 @@ def list_optimizer():
pass
try:
from bitsandbytes.optim.lion import Lion8bit
from bitsandbytes.optim import Lion8bit
optimizer_list.append("8bit Lion")
except:
pass
try:
from bitsandbytes.optim.lion import PagedLion8bit
from bitsandbytes.optim import PagedLion8bit
optimizer_list.append("Paged 8bit Lion")
except:
pass

View File

@ -33,7 +33,9 @@ def actual_install():
except:
revision = ""
print("If submitting an issue on github, please provide the full startup log for debugging purposes.")
print(
"If submitting an issue on github, please provide the full startup log for debugging purposes."
)
print("")
print("Initializing Dreambooth")
print(f"Dreambooth revision: {revision}")
@ -53,7 +55,7 @@ def pip_install(*args):
output = subprocess.check_output(
[sys.executable, "-m", "pip", "install"] + list(args),
stderr=subprocess.STDOUT,
)
)
for line in output.decode().split("\n"):
if "Successfully installed" in line:
print(line)
@ -61,7 +63,9 @@ def pip_install(*args):
def install_requirements():
dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False)
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
req_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "requirements.txt"
)
req_file_startup_arg = os.environ.get("REQS_FILE", "requirements_versions.txt")
if dreambooth_skip_install or req_file == req_file_startup_arg:
@ -74,12 +78,20 @@ def install_requirements():
try:
pip_install("-r", req_file)
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
if (
has_diffusers
and has_tqdm
and Version(transformers_version) < Version("4.26.1")
):
print()
print("Does your project take forever to startup?")
print("Repetitive dependency installation may be the reason.")
print("Automatic1111's base project sets strict requirements on outdated dependencies.")
print("If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.")
print(
"Automatic1111's base project sets strict requirements on outdated dependencies."
)
print(
"If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup."
)
print()
except subprocess.CalledProcessError as grepexc:
error_msg = grepexc.stdout.decode()
@ -118,23 +130,34 @@ def check_bitsandbytes():
if bitsandbytes_version != "0.41.1":
try:
print("Installing bitsandbytes")
pip_install("--force-install","==prefer-binary","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui","bitsandbytes==0.41.1")
pip_install(
"--force-install",
"==prefer-binary",
"--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui",
"bitsandbytes==0.41.1",
)
except:
print("Bitsandbytes 0.41.1 installation failed.")
print("Some features such as 8bit optimizers will be unavailable")
print("Please install manually with")
print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'")
print(
"'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'"
)
pass
else:
if bitsandbytes_version != "0.41.1":
try:
print("Installing bitsandbytes")
pip_install("--force-install","--prefer-binary","bitsandbytes==0.41.1")
pip_install(
"--force-install", "--prefer-binary", "bitsandbytes==0.41.1"
)
except:
print("Bitsandbytes 0.41.1 installation failed")
print("Some features such as 8bit optimizers will be unavailable")
print("Install manually with")
print("'python -m pip install bitsandbytes==0.41.1 --prefer-binary --force-install'")
print(
"'python -m pip install bitsandbytes==0.41.1 --prefer-binary --force-install'"
)
pass
@ -149,9 +172,10 @@ class Dependency:
def check_versions():
import platform
from sys import platform as sys_platform
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
#Probably a bad idea but update ALL the dependencies
is_mac = sys_platform == "darwin" and platform.machine() == "arm64"
# Probably a bad idea but update ALL the dependencies
dependencies = [
Dependency(module="xformers", version="0.0.21", required=False),
Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"),
@ -159,7 +183,7 @@ def check_versions():
Dependency(module="accelerate", version="0.22.0"),
Dependency(module="diffusers", version="0.20.1"),
Dependency(module="transformers", version="4.25.1"),
Dependency(module="bitsandbytes", version="0.41.1"),
Dependency(module="bitsandbytes", version="0.41.1"),
]
launch_errors = []
@ -180,13 +204,21 @@ def check_versions():
required_version = dependency.version
required_comparison = dependency.version_comparison
if required_comparison == "min" and Version(installed_ver) < Version(required_version):
if required_comparison == "min" and Version(installed_ver) < Version(
required_version
):
if dependency.required:
launch_errors.append(f"{module} is below the required {required_version} version.")
launch_errors.append(
f"{module} is below the required {required_version} version."
)
print(f"[!] {module} version {installed_ver} installed.")
elif required_comparison == "exact" and Version(installed_ver) != Version(required_version):
launch_errors.append(f"{module} is not the required {required_version} version.")
elif required_comparison == "exact" and Version(installed_ver) != Version(
required_version
):
launch_errors.append(
f"{module} is not the required {required_version} version."
)
print(f"[!] {module} version {installed_ver} installed.")
else:
@ -204,7 +236,7 @@ def check_versions():
def print_requirement_installation_error(err):
print("# Requirement installation exception:")
for line in err.split('\n'):
for line in err.split("\n"):
line = line.strip()
if line:
print(line)
@ -213,27 +245,45 @@ def print_requirement_installation_error(err):
def print_xformers_installation_error(err):
torch_ver = importlib_metadata.version("torch")
print()
print("#######################################################################################################")
print("# XFORMERS ISSUE DETECTED #")
print("#######################################################################################################")
print(
"#######################################################################################################"
)
print(
"# XFORMERS ISSUE DETECTED #"
)
print(
"#######################################################################################################"
)
print("#")
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.21 built with torch {torch_ver})")
print("# xformers will not be available for Dreambooth. Consider upgrading to Torch 2.")
print(
f"# Dreambooth could not find a compatible version of xformers (>= 0.0.21 built with torch {torch_ver})"
)
print(
"# xformers will not be available for Dreambooth. Consider upgrading to Torch 2."
)
print("#")
print("# Xformers installation exception:")
for line in err.split('\n'):
for line in err.split("\n"):
line = line.strip()
if line:
print(line)
print("#")
print("#######################################################################################################")
print(
"#######################################################################################################"
)
def print_launch_errors(launch_errors):
print()
print("#######################################################################################################")
print("# LIBRARY ISSUE DETECTED #")
print("#######################################################################################################")
print(
"#######################################################################################################"
)
print(
"# LIBRARY ISSUE DETECTED #"
)
print(
"#######################################################################################################"
)
print("#")
print("# " + "\n# ".join(launch_errors))
print("#")
@ -242,20 +292,26 @@ def print_launch_errors(launch_errors):
print("# TROUBLESHOOTING")
print("# 1. Fully restart your project (not just the webpage)")
print("# 2. Update your A1111 project and extensions")
print("# 3. Dreambooth requirements should have installed automatically, but you can manually install them")
print(
"# 3. Dreambooth requirements should have installed automatically, but you can manually install them"
)
print("# by running the following 4 commands from the A1111 project root:")
print("cd venv/Scripts")
print("activate")
print("cd ../..")
print("pip install -r ./extensions/sd_dreambooth_extension/requirements.txt")
print("#######################################################################################################")
print(
"#######################################################################################################"
)
def check_torch_unsafe_load():
    """Point ``safe.load`` and ``torch.load`` at ``safe.unsafe_torch_load``.

    This is a best-effort patch: if ``modules.safe`` or ``torch`` cannot be
    imported, or either assignment fails, the error is silently ignored.

    NOTE(review): this presumably replaces the host application's restricted
    unpickler with an unrestricted one for every later ``torch.load`` call.
    Confirm that is intended, because loading untrusted checkpoints through
    an unrestricted loader is unsafe.
    """
    try:
        from modules import safe

        safe.load = safe.unsafe_torch_load
        import torch

        torch.load = safe.unsafe_torch_load
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        pass

View File

@ -521,7 +521,7 @@ def on_ui_tabs():
label="Max Resolution",
step=64,
minimum=128,
value=512,
value=1024,
maximum=2048,
elem_id="max_res",
)
@ -570,7 +570,7 @@ def on_ui_tabs():
label="Offset Noise",
minimum=-1,
maximum=1,
step=0.01,
step=0.00001,
value=0,
)
db_freeze_clip_normalization = gr.Checkbox(
@ -612,8 +612,8 @@ def on_ui_tabs():
db_min_snr_gamma = gr.Slider(
label="Min SNR Gamma",
minimum=0,
maximum=10,
step=0.1,
maximum=20,
step=0.01,
visible=True,
)
db_pad_tokens = gr.Checkbox(