1460 lines
59 KiB
Python
1460 lines
59 KiB
Python
# Borrowed heavily from https://github.com/bmaltais/kohya_ss/blob/master/train_db.py and
|
|
# https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth
|
|
# With some custom bits sprinkled in and some stuff from OG diffusers as well.
|
|
|
|
import itertools
|
|
import logging
|
|
import math
|
|
import os
|
|
import time
|
|
import traceback
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
|
|
import importlib_metadata
|
|
import torch
|
|
import torch.backends.cuda
|
|
import torch.backends.cudnn
|
|
import torch.utils.checkpoint
|
|
import tomesd
|
|
from accelerate import Accelerator
|
|
from accelerate.utils.random import set_seed as set_seed2
|
|
from diffusers import (
|
|
AutoencoderKL,
|
|
DiffusionPipeline,
|
|
UNet2DConditionModel,
|
|
DEISMultistepScheduler,
|
|
UniPCMultistepScheduler
|
|
)
|
|
from diffusers.utils import logging as dl, is_xformers_available
|
|
from packaging import version
|
|
from tensorflow.python.framework.random_seed import set_seed as set_seed1
|
|
from torch.cuda.profiler import profile
|
|
from torch.utils.data import Dataset
|
|
from transformers import AutoTokenizer
|
|
|
|
from dreambooth import shared
|
|
from dreambooth.dataclasses.prompt_data import PromptData
|
|
from dreambooth.dataclasses.train_result import TrainResult
|
|
from dreambooth.dataset.bucket_sampler import BucketSampler
|
|
from dreambooth.dataset.sample_dataset import SampleDataset
|
|
from dreambooth.deis_velocity import get_velocity
|
|
from dreambooth.diff_to_sd import compile_checkpoint, copy_diffusion_model
|
|
from dreambooth.memory import find_executable_batch_size
|
|
from dreambooth.optimization import UniversalScheduler, get_optimizer, get_noise_scheduler
|
|
from dreambooth.shared import status
|
|
from dreambooth.utils.gen_utils import generate_classifiers, generate_dataset
|
|
from dreambooth.utils.image_utils import db_save_image, get_scheduler_class
|
|
from dreambooth.utils.model_utils import (
|
|
unload_system_models,
|
|
import_model_class_from_model_name_or_path,
|
|
disable_safe_unpickle,
|
|
enable_safe_unpickle,
|
|
xformerify,
|
|
torch2ify,
|
|
)
|
|
from dreambooth.utils.text_utils import encode_hidden_state
|
|
from dreambooth.utils.utils import cleanup, printm, verify_locon_installed
|
|
from dreambooth.webhook import send_training_update
|
|
from dreambooth.xattention import optim_to
|
|
from helpers.ema_model import EMAModel
|
|
from helpers.log_parser import LogParser
|
|
from helpers.mytqdm import mytqdm
|
|
from lora_diffusion.extra_networks import save_extra_networks
|
|
from lora_diffusion.lora import (
|
|
save_lora_weight,
|
|
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
|
get_target_module,
|
|
)
|
|
|
|
# Module-level logger; both the logger and its handler are set to DEBUG so
# every message from this module is emitted.
logger = logging.getLogger(__name__)
# define a Handler which writes DEBUG messages or higher to the sys.stderr
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logger.addHandler(console)
logger.setLevel(logging.DEBUG)
# Restrict the diffusers library's own logging to errors only.
dl.set_verbosity_error()

# Results of the most recent sample-generation pass (saved sample images and
# the prompts they were generated from). Rebound through `global` statements
# inside the training loop.
last_samples = []
last_prompts = []
|
|
|
|
|
|
|
|
# Attach the local `get_velocity` implementation to the DEIS and UniPC
# schedulers when the installed diffusers release is 0.14.0 or older.
# (Presumably those releases lack the method - the else-branch message refers
# to an upstream PR that was pending at the time.)
try:
    diff_version = importlib_metadata.version("diffusers")
    version_string = diff_version.split(".")
    # Parse with `packaging` instead of int()-ing the split components, so
    # that suffixed components ("0rc1") and two-part versions ("0.15") do not
    # raise.
    parsed_version = version.parse(diff_version)
    major_version = parsed_version.major
    minor_version = parsed_version.minor
    patch_version = parsed_version.micro
    # Compare the whole release tuple: the previous check ignored the major
    # component, so e.g. 1.0.0 (minor 0 < 14) would have been patched.
    if (major_version, minor_version, patch_version) <= (0, 14, 0):
        print(
            "The version of diffusers is less than or equal to 0.14.0. Performing monkey-patch..."
        )
        DEISMultistepScheduler.get_velocity = get_velocity
        UniPCMultistepScheduler.get_velocity = get_velocity
    else:
        print(
            "The version of diffusers is greater than 0.14.0, hopefully they merged the PR by now"
        )
except Exception:
    # Best effort: a failed version lookup must not prevent the module from
    # importing, but KeyboardInterrupt/SystemExit are no longer swallowed.
    print("Exception monkey-patching DEIS scheduler.")
|
|
|
|
# Optional settings read from the `core.handlers` package. When that package
# cannot be imported or queried, exporting diffusers models stays disabled.
# NOTE(review): `core` is presumably only available in the standalone
# application - confirm.
export_diffusers = False
diffusers_dir = ""
try:
    from core.handlers.config import ConfigHandler
    from core.handlers.models import ModelHandler

    ch = ConfigHandler()
    mh = ModelHandler()
    export_diffusers = ch.get_item("export_diffusers", "dreambooth", True)
    diffusers_dir = os.path.join(mh.models_path, "diffusers")
except Exception:
    # Still best effort, but no longer swallows KeyboardInterrupt/SystemExit.
    # Reset both values so a failure half-way through cannot leave
    # export_diffusers enabled with an empty target directory.
    export_diffusers = False
    diffusers_dir = ""
|
|
|
|
def dadapt(optimizer):
    """Tell whether *optimizer* names one of the two D-Adaptation optimizers.

    @param optimizer: Optimizer name to check.
    @return: True for "AdamW Dadaptation" or "Adan Dadaptation", else False.
    """
    return optimizer in ("AdamW Dadaptation", "Adan Dadaptation")
|
|
|
|
def set_seed(deterministic: bool):
    """Switch cuDNN determinism on or off and, when on, apply a fixed seed.

    @param deterministic: When truthy, cuDNN is put in deterministic mode and
        seed 0 is passed to both imported seed setters (`set_seed1` from
        tensorflow, `set_seed2` from accelerate). Otherwise only the cuDNN
        flag is cleared and no seeding happens.
    """
    if not deterministic:
        torch.backends.cudnn.deterministic = False
        return

    torch.backends.cudnn.deterministic = True
    fixed_seed = 0
    # Same call order as before: tensorflow's setter first, then accelerate's.
    for apply_seed in (set_seed1, set_seed2):
        apply_seed(fixed_seed)
|
|
|
|
|
|
def current_prior_loss(args, current_epoch):
    """Return the prior loss weight to use for *current_epoch*.

    With `args.prior_loss_scale` unset the configured `prior_loss_weight` is
    returned unchanged. Otherwise the weight moves linearly from
    `prior_loss_weight` (epoch 0) to `prior_loss_weight_min` (reached at
    `prior_loss_target` epochs and kept from then on).

    Side effect: falsy `prior_loss_target` / `prior_loss_weight_min` values
    are replaced on *args* with the defaults 150 and 0.1.

    @param args: Config object carrying the prior_loss_* attributes.
    @param current_epoch: Epoch the weight is requested for.
    @return: The prior loss weight.
    """
    if not args.prior_loss_scale:
        return args.prior_loss_weight

    target_epoch = args.prior_loss_target
    if not target_epoch:
        target_epoch = args.prior_loss_target = 150

    floor_weight = args.prior_loss_weight_min
    if not floor_weight:
        floor_weight = args.prior_loss_weight_min = 0.1

    # Past the target epoch the weight stays at its minimum.
    if current_epoch >= target_epoch:
        return floor_weight

    progress = current_epoch / target_epoch
    prior = args.prior_loss_weight * (1 - progress) + floor_weight * progress
    printm(f"Prior: {prior}")
    return prior
|
|
|
|
|
|
def stop_profiler(profiler):
    """Stop *profiler* if one was supplied; never raise on a failing stop.

    @param profiler: Object exposing a `stop()` method, or None (no-op).
    @return: None
    """
    if profiler is None:
        return
    try:
        print("Stopping profiler.")
        profiler.stop()
    except Exception as ex:
        # Stopping is best effort and must not abort the caller, but the
        # failure is reported instead of being silently dropped. Unlike the
        # previous bare `except`, KeyboardInterrupt/SystemExit propagate.
        print(f"Exception stopping profiler: {ex}")
|
|
|
|
|
|
def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|
"""
|
|
@param class_gen_method: Image Generation Library.
|
|
@return: TrainResult
|
|
"""
|
|
args = shared.db_model_config
|
|
logging_dir = Path(args.model_dir, "logging")
|
|
log_parser = LogParser()
|
|
|
|
result = TrainResult
|
|
result.config = args
|
|
|
|
enable_tomesd = args.enable_tomesd
|
|
enable_tomesd = True
|
|
|
|
set_seed(args.deterministic)
|
|
|
|
@find_executable_batch_size(
|
|
starting_batch_size=args.train_batch_size,
|
|
starting_grad_size=args.gradient_accumulation_steps,
|
|
logging_dir=logging_dir,
|
|
)
|
|
def inner_loop(train_batch_size: int, gradient_accumulation_steps: int, profiler: profile):
|
|
|
|
text_encoder = None
|
|
global last_samples
|
|
global last_prompts
|
|
stop_text_percentage = args.stop_text_encoder
|
|
if not args.train_unet:
|
|
stop_text_percentage = 1
|
|
n_workers = 0
|
|
args.max_token_length = int(args.max_token_length)
|
|
if not args.pad_tokens and args.max_token_length > 75:
|
|
print("Cannot raise token length limit above 75 when pad_tokens=False")
|
|
|
|
verify_locon_installed(args)
|
|
|
|
precision = args.mixed_precision if not shared.force_cpu else "no"
|
|
|
|
weight_dtype = torch.float32
|
|
if precision == "fp16":
|
|
weight_dtype = torch.float16
|
|
elif precision == "bf16":
|
|
weight_dtype = torch.bfloat16
|
|
|
|
try:
|
|
accelerator = Accelerator(
|
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
|
mixed_precision=precision,
|
|
log_with="tensorboard",
|
|
project_dir=logging_dir,
|
|
cpu=shared.force_cpu,
|
|
)
|
|
|
|
run_name = "dreambooth.events"
|
|
max_log_size = 250 * 1024 # specify the maximum log size
|
|
|
|
except Exception as e:
|
|
if "AcceleratorState" in str(e):
|
|
msg = "Change in precision detected, please restart the webUI entirely to use new precision."
|
|
else:
|
|
msg = f"Exception initializing accelerator: {e}"
|
|
print(msg)
|
|
result.msg = msg
|
|
result.config = args
|
|
stop_profiler(profiler)
|
|
return result
|
|
# Currently, it's not possible to do gradient accumulation when training two models with
|
|
# accelerate.accumulate This will be enabled soon in accelerate. For now, we don't allow gradient
|
|
# accumulation when training two models.
|
|
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
|
if (
|
|
stop_text_percentage != 0
|
|
and gradient_accumulation_steps > 1
|
|
and accelerator.num_processes > 1
|
|
):
|
|
msg = (
|
|
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
|
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future. Text "
|
|
"encoder training will be disabled."
|
|
)
|
|
print(msg)
|
|
status.textinfo = msg
|
|
stop_text_percentage = 0
|
|
count, instance_prompts, class_prompts = generate_classifiers(
|
|
args, class_gen_method=class_gen_method, accelerator=accelerator, ui=False
|
|
)
|
|
if status.interrupted:
|
|
result.msg = "Training interrupted."
|
|
stop_profiler(profiler)
|
|
return result
|
|
|
|
if class_gen_method == "Native Diffusers" and count > 0:
|
|
unload_system_models()
|
|
|
|
def create_vae():
    """Load the VAE, freeze it and move it to the accelerator device.

    The weights come from `args.pretrained_vae_name_or_path` when set
    (loaded from the path root), otherwise from the "vae" subfolder of the
    pretrained model path.

    @return: The loaded AutoencoderKL with gradients disabled, placed on
        `accelerator.device` in `weight_dtype`.
    """
    vae_path = (
        args.pretrained_vae_name_or_path
        if args.pretrained_vae_name_or_path
        else args.get_pretrained_model_name_or_path()
    )
    disable_safe_unpickle()
    try:
        new_vae = AutoencoderKL.from_pretrained(
            vae_path,
            subfolder=None if args.pretrained_vae_name_or_path else "vae",
            revision=args.revision,
        )
    finally:
        # Re-enable safe unpickling even when loading fails; previously an
        # exception here left it disabled.
        enable_safe_unpickle()
    new_vae.requires_grad_(False)
    new_vae.to(accelerator.device, dtype=weight_dtype)
    return new_vae
|
|
|
|
disable_safe_unpickle()
|
|
# Load the tokenizer
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
os.path.join(args.get_pretrained_model_name_or_path(), "tokenizer"),
|
|
revision=args.revision,
|
|
use_fast=False,
|
|
)
|
|
|
|
# import correct text encoder class
|
|
text_encoder_cls = import_model_class_from_model_name_or_path(
|
|
args.get_pretrained_model_name_or_path(), args.revision
|
|
)
|
|
|
|
# Load models and create wrapper for stable diffusion
|
|
text_encoder = text_encoder_cls.from_pretrained(
|
|
args.get_pretrained_model_name_or_path(),
|
|
subfolder="text_encoder",
|
|
revision=args.revision,
|
|
torch_dtype=torch.float32,
|
|
)
|
|
printm("Created tenc")
|
|
vae = create_vae()
|
|
printm("Created vae")
|
|
|
|
unet = UNet2DConditionModel.from_pretrained(
|
|
args.get_pretrained_model_name_or_path(),
|
|
subfolder="unet",
|
|
revision=args.revision,
|
|
torch_dtype=torch.float32,
|
|
)
|
|
unet = torch2ify(unet)
|
|
|
|
# Check that all trainable models are in full precision
|
|
low_precision_error_string = (
|
|
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
|
|
" doing mixed precision training. copy of the weights should still be float32."
|
|
)
|
|
if args.attention == "xformers" and not shared.force_cpu:
|
|
if is_xformers_available():
|
|
import xformers
|
|
|
|
xformers_version = version.parse(xformers.__version__)
|
|
if xformers_version == version.parse("0.0.16"):
|
|
logger.warning(
|
|
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
"xformers is not available. Make sure it is installed correctly"
|
|
)
|
|
xformerify(unet)
|
|
xformerify(vae)
|
|
|
|
if accelerator.unwrap_model(unet).dtype != torch.float32:
|
|
print(
|
|
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
|
|
)
|
|
|
|
if (
|
|
args.stop_text_encoder != 0
|
|
and accelerator.unwrap_model(text_encoder).dtype != torch.float32
|
|
):
|
|
print(
|
|
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
|
|
f" {low_precision_error_string}"
|
|
)
|
|
|
|
# Enable TF32 for faster training on Ampere GPUs,
|
|
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
|
try:
|
|
# Apparently, some versions of torch don't have a cuda_version flag? IDK, but it breaks my runpod.
|
|
if (
|
|
torch.cuda.is_available()
|
|
and float(torch.cuda_version) >= 11.0
|
|
and args.tf32_enable
|
|
):
|
|
print("Attempting to enable TF32.")
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
except:
|
|
pass
|
|
|
|
if args.gradient_checkpointing:
|
|
if args.train_unet:
|
|
unet.enable_gradient_checkpointing()
|
|
if stop_text_percentage != 0:
|
|
text_encoder.gradient_checkpointing_enable()
|
|
if args.use_lora:
|
|
text_encoder.text_model.embeddings.requires_grad_(True)
|
|
else:
|
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
|
|
|
ema_model = None
|
|
if args.use_ema:
|
|
if os.path.exists(
|
|
os.path.join(
|
|
args.get_pretrained_model_name_or_path(),
|
|
"ema_unet",
|
|
"diffusion_pytorch_model.safetensors",
|
|
)
|
|
):
|
|
ema_unet = UNet2DConditionModel.from_pretrained(
|
|
args.get_pretrained_model_name_or_path(),
|
|
subfolder="ema_unet",
|
|
revision=args.revision,
|
|
torch_dtype=torch.float32,
|
|
)
|
|
if args.attention == "xformers" and not shared.force_cpu:
|
|
xformerify(ema_unet)
|
|
|
|
ema_model = EMAModel(
|
|
ema_unet, device=accelerator.device, dtype=weight_dtype
|
|
)
|
|
del ema_unet
|
|
else:
|
|
ema_model = EMAModel(
|
|
unet, device=accelerator.device, dtype=weight_dtype
|
|
)
|
|
|
|
if args.use_lora or not args.train_unet:
|
|
unet.requires_grad_(False)
|
|
|
|
unet_lora_params = None
|
|
text_encoder_lora_params = None
|
|
lora_path = None
|
|
lora_txt = None
|
|
|
|
if args.use_lora:
|
|
if args.lora_model_name:
|
|
lora_path = os.path.join(args.model_dir, "loras", args.lora_model_name)
|
|
lora_txt = lora_path.replace(".pt", "_txt.pt")
|
|
|
|
if not os.path.exists(lora_path) or not os.path.isfile(lora_path):
|
|
lora_path = None
|
|
lora_txt = None
|
|
|
|
injectable_lora = get_target_module("injection", args.use_lora_extended)
|
|
target_module = get_target_module("module", args.use_lora_extended)
|
|
|
|
unet_lora_params, _ = injectable_lora(
|
|
unet,
|
|
r=args.lora_unet_rank,
|
|
loras=lora_path,
|
|
target_replace_module=target_module,
|
|
)
|
|
|
|
if stop_text_percentage != 0:
|
|
text_encoder.requires_grad_(False)
|
|
inject_trainable_txt_lora = get_target_module("injection", False)
|
|
text_encoder_lora_params, _ = inject_trainable_txt_lora(
|
|
text_encoder,
|
|
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
|
r=args.lora_txt_rank,
|
|
loras=lora_txt,
|
|
)
|
|
printm("Lora loaded")
|
|
cleanup()
|
|
printm("Cleaned")
|
|
|
|
args.learning_rate = args.lora_learning_rate
|
|
if stop_text_percentage != 0:
|
|
params_to_optimize = [
|
|
{
|
|
"params": itertools.chain(*unet_lora_params),
|
|
"lr": args.lora_learning_rate,
|
|
},
|
|
{
|
|
"params": itertools.chain(*text_encoder_lora_params),
|
|
"lr": args.lora_txt_learning_rate,
|
|
},
|
|
]
|
|
else:
|
|
params_to_optimize = itertools.chain(*unet_lora_params)
|
|
|
|
elif stop_text_percentage != 0:
|
|
if args.train_unet:
|
|
params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
|
|
else:
|
|
params_to_optimize = itertools.chain(text_encoder.parameters())
|
|
else:
|
|
params_to_optimize = unet.parameters()
|
|
|
|
optimizer = get_optimizer(args, params_to_optimize)
|
|
optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay
|
|
optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm
|
|
noise_scheduler = get_noise_scheduler(args)
|
|
|
|
def cleanup_memory():
|
|
try:
|
|
if unet:
|
|
del unet
|
|
if text_encoder:
|
|
del text_encoder
|
|
if tokenizer:
|
|
del tokenizer
|
|
if optimizer:
|
|
del optimizer
|
|
if train_dataloader:
|
|
del train_dataloader
|
|
if train_dataset:
|
|
del train_dataset
|
|
if lr_scheduler:
|
|
del lr_scheduler
|
|
if vae:
|
|
del vae
|
|
if unet_lora_params:
|
|
del unet_lora_params
|
|
except:
|
|
pass
|
|
cleanup(True)
|
|
|
|
if args.cache_latents:
|
|
vae.to(accelerator.device, dtype=weight_dtype)
|
|
vae.requires_grad_(False)
|
|
vae.eval()
|
|
|
|
if status.interrupted:
|
|
result.msg = "Training interrupted."
|
|
stop_profiler(profiler)
|
|
return result
|
|
|
|
printm("Loading dataset...")
|
|
train_dataset = generate_dataset(
|
|
model_name=args.model_name,
|
|
instance_prompts=instance_prompts,
|
|
class_prompts=class_prompts,
|
|
batch_size=train_batch_size,
|
|
tokenizer=tokenizer,
|
|
vae=vae if args.cache_latents else None,
|
|
debug=False,
|
|
model_dir=args.model_dir,
|
|
)
|
|
|
|
printm("Dataset loaded.")
|
|
|
|
if args.cache_latents:
|
|
printm("Unloading vae.")
|
|
del vae
|
|
# Preserve reference to vae for later checks
|
|
vae = None
|
|
cleanup()
|
|
if status.interrupted:
|
|
result.msg = "Training interrupted."
|
|
stop_profiler(profiler)
|
|
return result
|
|
|
|
if train_dataset.__len__ == 0:
|
|
msg = "Please provide a directory with actual images in it."
|
|
print(msg)
|
|
status.textinfo = msg
|
|
cleanup_memory()
|
|
result.msg = msg
|
|
result.config = args
|
|
stop_profiler(profiler)
|
|
return result
|
|
|
|
def collate_fn(examples):
    """Combine the per-sample dicts of one batch into a single batch dict.

    @param examples: Sequence of dicts with the keys "input_ids", "image"
        and "is_class".
    @return: Dict with "input_ids" (concatenated along dim 0), "images"
        (stacked), "types" (the is_class flags) and "loss_avg" (mean
        per-sample loss weight).
    """
    token_ids = [sample["input_ids"] for sample in examples]
    images = [sample["image"] for sample in examples]
    class_flags = [sample["is_class"] for sample in examples]

    # Samples flagged is_class are weighted with the current prior loss
    # weight (taken from the enclosing scope); all others count as 1.0.
    sample_weights = [
        current_prior_loss_weight if is_class else 1.0 for is_class in class_flags
    ]
    mean_weight = sum(sample_weights) / len(sample_weights)

    images = torch.stack(images)
    if not args.cache_latents:
        images = images.to(memory_format=torch.contiguous_format).float()

    return {
        "input_ids": torch.cat(token_ids, dim=0),
        "images": images,
        "types": class_flags,
        "loss_avg": mean_weight,
    }
|
|
|
|
sampler = BucketSampler(train_dataset, train_batch_size)
|
|
|
|
train_dataloader = torch.utils.data.DataLoader(
|
|
train_dataset,
|
|
batch_size=1,
|
|
batch_sampler=sampler,
|
|
collate_fn=collate_fn,
|
|
num_workers=n_workers,
|
|
)
|
|
|
|
max_train_steps = args.num_train_epochs * len(train_dataset)
|
|
|
|
# This is separate, because optimizer.step is only called once per "step" in training, so it's not
|
|
# affected by batch size
|
|
sched_train_steps = args.num_train_epochs * train_dataset.num_train_images
|
|
|
|
lr_scale_pos = args.lr_scale_pos
|
|
if class_prompts:
|
|
lr_scale_pos *= 2
|
|
|
|
lr_scheduler = UniversalScheduler(
|
|
name=args.lr_scheduler,
|
|
optimizer=optimizer,
|
|
num_warmup_steps=args.lr_warmup_steps,
|
|
total_training_steps=sched_train_steps,
|
|
min_lr=args.learning_rate_min,
|
|
total_epochs=args.num_train_epochs,
|
|
num_cycles=args.lr_cycles,
|
|
power=args.lr_power,
|
|
factor=args.lr_factor,
|
|
scale_pos=lr_scale_pos,
|
|
unet_lr=args.lora_learning_rate,
|
|
tenc_lr=args.lora_txt_learning_rate,
|
|
)
|
|
|
|
# create ema, fix OOM
|
|
if args.use_ema:
|
|
if stop_text_percentage != 0:
|
|
(
|
|
ema_model.model,
|
|
unet,
|
|
text_encoder,
|
|
optimizer,
|
|
train_dataloader,
|
|
lr_scheduler,
|
|
) = accelerator.prepare(
|
|
ema_model.model,
|
|
unet,
|
|
text_encoder,
|
|
optimizer,
|
|
train_dataloader,
|
|
lr_scheduler,
|
|
)
|
|
else:
|
|
(
|
|
ema_model.model,
|
|
unet,
|
|
optimizer,
|
|
train_dataloader,
|
|
lr_scheduler,
|
|
) = accelerator.prepare(
|
|
ema_model.model, unet, optimizer, train_dataloader, lr_scheduler
|
|
)
|
|
else:
|
|
if stop_text_percentage != 0:
|
|
(
|
|
unet,
|
|
text_encoder,
|
|
optimizer,
|
|
train_dataloader,
|
|
lr_scheduler,
|
|
) = accelerator.prepare(
|
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
|
)
|
|
else:
|
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
unet, optimizer, train_dataloader, lr_scheduler
|
|
)
|
|
|
|
if not args.cache_latents and vae is not None:
|
|
vae.to(accelerator.device, dtype=weight_dtype)
|
|
|
|
if stop_text_percentage == 0:
|
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
|
# Afterwards we recalculate our number of training epochs
|
|
# We need to initialize the trackers we use, and also store our configuration.
|
|
# The trackers will initialize automatically on the main process.
|
|
if accelerator.is_main_process:
|
|
accelerator.init_trackers("dreambooth")
|
|
|
|
# Train!
|
|
total_batch_size = (
|
|
train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
|
)
|
|
max_train_epochs = args.num_train_epochs
|
|
# we calculate our number of tenc training epochs
|
|
text_encoder_epochs = round(max_train_epochs * stop_text_percentage)
|
|
global_step = 0
|
|
global_epoch = 0
|
|
session_epoch = 0
|
|
first_epoch = 0
|
|
resume_step = 0
|
|
last_model_save = 0
|
|
last_image_save = 0
|
|
resume_from_checkpoint = False
|
|
new_hotness = os.path.join(
|
|
args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}"
|
|
)
|
|
if os.path.exists(new_hotness):
|
|
accelerator.print(f"Resuming from checkpoint {new_hotness}")
|
|
|
|
try:
|
|
import modules.shared
|
|
no_safe = modules.shared.cmd_opts.disable_safe_unpickle
|
|
modules.shared.cmd_opts.disable_safe_unpickle = True
|
|
except:
|
|
no_safe = False
|
|
|
|
try:
|
|
import modules.shared
|
|
accelerator.load_state(new_hotness)
|
|
modules.shared.cmd_opts.disable_safe_unpickle = no_safe
|
|
global_step = resume_step = args.revision
|
|
resume_from_checkpoint = True
|
|
first_epoch = args.epoch
|
|
global_epoch = first_epoch
|
|
except Exception as lex:
|
|
print(f"Exception loading checkpoint: {lex}")
|
|
|
|
#if shared.in_progress:
|
|
# print(" ***** OOM detected. Resuming from last step *****")
|
|
# max_train_steps = max_train_steps - shared.in_progress_step
|
|
# max_train_epochs = max_train_epochs - shared.in_progress_epoch
|
|
# session_epoch = shared.in_progress_epoch
|
|
# text_encoder_epochs = (shared.in_progress_epoch/max_train_epochs)*text_encoder_epochs
|
|
#else:
|
|
# shared.in_progress = True
|
|
|
|
print(" ***** Running training *****")
|
|
if shared.force_cpu:
|
|
print(f" TRAINING WITH CPU ONLY")
|
|
print(f" Num batches each epoch = {len(train_dataset) // train_batch_size}")
|
|
print(f" Num Epochs = {max_train_epochs}")
|
|
print(f" Batch Size Per Device = {train_batch_size}")
|
|
print(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
|
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
|
print(f" Text Encoder Epochs: {text_encoder_epochs}")
|
|
print(f" Total optimization steps = {sched_train_steps}")
|
|
print(f" Total training steps = {max_train_steps}")
|
|
print(f" Resuming from checkpoint: {resume_from_checkpoint}")
|
|
print(f" First resume epoch: {first_epoch}")
|
|
print(f" First resume step: {resume_step}")
|
|
print(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
|
|
print(f" Gradient Checkpointing: {args.gradient_checkpointing}")
|
|
print(f" EMA: {args.use_ema}")
|
|
print(f" UNET: {args.train_unet}")
|
|
print(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
|
|
print(f" LR: {args.learning_rate}")
|
|
if args.use_lora_extended:
|
|
print(f" LoRA Extended: {args.use_lora_extended}")
|
|
if args.use_lora and stop_text_percentage > 0:
|
|
print(f" LoRA Text Encoder LR: {args.lora_txt_learning_rate}")
|
|
print(f" V2: {args.v2}")
|
|
|
|
os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1)
|
|
|
|
def check_save(is_epoch_check: bool = False) -> bool:
    """Decide what needs saving at this point and call `save_weights` if anything does.

    Reads `session_epoch`, `max_train_epochs`, `global_step`, `args`,
    `status` and `shared.status` from the surrounding scope, and updates the
    enclosing `last_model_save` / `last_image_save` bookkeeping.

    @param is_epoch_check: When True, the manual `do_save_samples` /
        `do_save_model` requests on `shared.status` are honoured and reset.
    @return: Whether a model save was requested on this call.
    """
    nonlocal last_model_save
    nonlocal last_image_save
    save_model_interval = args.save_embedding_every
    save_image_interval = args.save_preview_every
    save_completed = session_epoch >= max_train_epochs
    save_canceled = status.interrupted
    save_image = False
    save_model = False
    if not save_canceled and not save_completed:
        # Check to see if the number of epochs since last save is gt the interval
        # (an interval of 0 or less disables the periodic save)
        if 0 < save_model_interval <= session_epoch - last_model_save:
            save_model = True
            last_model_save = session_epoch

        # Repeat for sample images
        if 0 < save_image_interval <= session_epoch - last_image_save:
            save_image = True
            last_image_save = session_epoch

    else:
        print("\nSave completed/canceled.")
        # Only save on completion/cancel when at least one step has run.
        if global_step > 0:
            save_image = True
            save_model = True

    save_snapshot = False
    save_lora = False
    save_checkpoint = False

    if is_epoch_check:
        # Manual save requests are one-shot flags: consume and reset them.
        if shared.status.do_save_samples:
            save_image = True
            shared.status.do_save_samples = False

        if shared.status.do_save_model:
            save_model = True
            shared.status.do_save_model = False

    if save_model:
        # Pick the lora/snapshot/checkpoint switches that match the current
        # phase: the *_cancel, *_after or *_during settings on args.
        if save_canceled:
            if global_step > 0:
                print("Canceled, enabling saves.")
                save_lora = args.save_lora_cancel
                save_snapshot = args.save_state_cancel
                save_checkpoint = args.save_ckpt_cancel
        elif save_completed:
            if global_step > 0:
                print("Completed, enabling saves.")
                save_lora = args.save_lora_after
                save_snapshot = args.save_state_after
                save_checkpoint = args.save_ckpt_after
        else:
            save_lora = args.save_lora_during
            save_snapshot = args.save_state_during
            save_checkpoint = args.save_ckpt_during

    if (
        save_checkpoint
        or save_snapshot
        or save_lora
        or save_image
        or save_model
    ):
        save_weights(
            save_image,
            save_model,
            save_snapshot,
            save_checkpoint,
            save_lora,
        )

    return save_model
|
|
|
|
def save_weights(
|
|
save_image, save_model, save_snapshot, save_checkpoint, save_lora
|
|
):
|
|
global last_samples
|
|
global last_prompts
|
|
nonlocal vae
|
|
|
|
printm(" Saving weights.")
|
|
pbar = mytqdm(
|
|
range(4),
|
|
desc="Saving weights",
|
|
disable=not accelerator.is_local_main_process,
|
|
position=1
|
|
)
|
|
pbar.set_postfix(refresh=True)
|
|
|
|
# Create the pipeline using the trained modules and save it.
|
|
if accelerator.is_main_process:
|
|
printm("Pre-cleanup.")
|
|
|
|
# Save random states so sample generation doesn't impact training.
|
|
if shared.device.type == 'cuda':
|
|
torch_rng_state = torch.get_rng_state()
|
|
cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
|
|
cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")
|
|
|
|
optim_to(profiler, optimizer)
|
|
|
|
if profiler is not None:
|
|
cleanup()
|
|
|
|
if vae is None:
|
|
printm("Loading vae.")
|
|
vae = create_vae()
|
|
|
|
printm("Creating pipeline.")
|
|
|
|
s_pipeline = DiffusionPipeline.from_pretrained(
|
|
args.get_pretrained_model_name_or_path(),
|
|
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
|
|
text_encoder=accelerator.unwrap_model(
|
|
text_encoder, keep_fp32_wrapper=True
|
|
),
|
|
vae=vae,
|
|
torch_dtype=weight_dtype,
|
|
revision=args.revision,
|
|
safety_checker=None,
|
|
requires_safety_checker=None,
|
|
)
|
|
|
|
scheduler_class = get_scheduler_class(args.scheduler)
|
|
if args.attention == "xformers" and not shared.force_cpu:
|
|
xformerify(s_pipeline)
|
|
|
|
s_pipeline.scheduler = scheduler_class.from_config(
|
|
s_pipeline.scheduler.config
|
|
)
|
|
if "UniPC" in args.scheduler:
|
|
s_pipeline.scheduler.config.solver_type = "bh2"
|
|
|
|
s_pipeline = s_pipeline.to(accelerator.device)
|
|
|
|
printm("Patching model with tomesd.")
|
|
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
|
|
|
with accelerator.autocast(), torch.inference_mode():
|
|
if save_model:
|
|
tomesd.remove_patch(s_pipeline)
|
|
# We are saving weights, we need to ensure revision is saved
|
|
args.save()
|
|
try:
|
|
out_file = None
|
|
# Loras resume from pt
|
|
if not args.use_lora:
|
|
if save_snapshot:
|
|
pbar.set_description("Saving Snapshot")
|
|
status.textinfo = (
|
|
f"Saving snapshot at step {args.revision}..."
|
|
)
|
|
accelerator.save_state(
|
|
os.path.join(
|
|
args.model_dir,
|
|
"checkpoints",
|
|
f"checkpoint-{args.revision}",
|
|
)
|
|
)
|
|
pbar.update()
|
|
|
|
# We should save this regardless, because it's our fallback if no snapshot exists.
|
|
status.textinfo = (
|
|
f"Saving diffusion model at step {args.revision}..."
|
|
)
|
|
pbar.set_description("Saving diffusion model")
|
|
s_pipeline.save_pretrained(
|
|
os.path.join(args.model_dir, "working"),
|
|
safe_serialization=True,
|
|
)
|
|
if ema_model is not None:
|
|
ema_model.save_pretrained(
|
|
os.path.join(
|
|
args.get_pretrained_model_name_or_path(),
|
|
"ema_unet",
|
|
),
|
|
safe_serialization=True,
|
|
)
|
|
pbar.update()
|
|
|
|
printm("Patching model with tomesd.")
|
|
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
|
|
|
elif save_lora:
|
|
tomesd.remove_patch(s_pipeline)
|
|
pbar.set_description("Saving Lora Weights...")
|
|
# setup directory
|
|
loras_dir = os.path.join(args.model_dir, "loras")
|
|
os.makedirs(loras_dir, exist_ok=True)
|
|
# setup pt path
|
|
if args.custom_model_name == "":
|
|
lora_model_name = args.model_name
|
|
else:
|
|
lora_model_name = args.custom_model_name
|
|
lora_file_prefix = f"{lora_model_name}_{args.revision}"
|
|
out_file = os.path.join(
|
|
loras_dir, f"{lora_file_prefix}.pt"
|
|
)
|
|
# create pt
|
|
tgt_module = get_target_module(
|
|
"module", args.use_lora_extended
|
|
)
|
|
save_lora_weight(s_pipeline.unet, out_file, tgt_module)
|
|
|
|
modelmap = {"unet": (s_pipeline.unet, tgt_module)}
|
|
# save text_encoder
|
|
if stop_text_percentage != 0:
|
|
out_txt = out_file.replace(".pt", "_txt.pt")
|
|
modelmap["text_encoder"] = (
|
|
s_pipeline.text_encoder,
|
|
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
|
)
|
|
save_lora_weight(
|
|
s_pipeline.text_encoder,
|
|
out_txt,
|
|
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
|
)
|
|
pbar.update()
|
|
# save extra_net
|
|
if args.save_lora_for_extra_net:
|
|
os.makedirs(
|
|
shared.ui_lora_models_path, exist_ok=True
|
|
)
|
|
out_safe = os.path.join(
|
|
shared.ui_lora_models_path,
|
|
f"{lora_file_prefix}.safetensors",
|
|
)
|
|
save_extra_networks(modelmap, out_safe)
|
|
# package pt into checkpoint
|
|
if save_checkpoint:
|
|
pbar.set_description("Compiling Checkpoint")
|
|
snap_rev = str(args.revision) if save_snapshot else ""
|
|
if export_diffusers:
|
|
copy_diffusion_model(args.model_name, diffusers_dir)
|
|
else:
|
|
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file,
|
|
log=False, snap_rev=snap_rev, pbar=pbar)
|
|
printm("Restored, moved to acc.device.")
|
|
|
|
printm("Patching model with tomesd.")
|
|
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
|
|
|
except Exception as ex:
|
|
print(f"Exception saving checkpoint/model: {ex}")
|
|
traceback.print_exc()
|
|
pass
|
|
tomesd.remove_patch(s_pipeline)
|
|
save_dir = args.model_dir
|
|
if save_image:
|
|
samples = []
|
|
sample_prompts = []
|
|
last_samples = []
|
|
last_prompts = []
|
|
status.textinfo = (
|
|
f"Saving preview image(s) at step {args.revision}..."
|
|
)
|
|
try:
|
|
s_pipeline.set_progress_bar_config(disable=True)
|
|
sample_dir = os.path.join(save_dir, "samples")
|
|
os.makedirs(sample_dir, exist_ok=True)
|
|
with accelerator.autocast(), torch.inference_mode():
|
|
sd = SampleDataset(args)
|
|
prompts = sd.prompts
|
|
concepts = args.concepts()
|
|
if args.sanity_prompt:
|
|
epd = PromptData(
|
|
prompt=args.sanity_prompt,
|
|
seed=args.sanity_seed,
|
|
negative_prompt=concepts[
|
|
0
|
|
].save_sample_negative_prompt,
|
|
resolution=(args.resolution, args.resolution),
|
|
)
|
|
prompts.append(epd)
|
|
pbar.set_description("Generating Samples")
|
|
|
|
prompt_lengths = len(prompts)
|
|
if args.disable_logging:
|
|
pbar.reset(prompt_lengths)
|
|
else:
|
|
pbar.reset(prompt_lengths + 2)
|
|
|
|
ci = 0
|
|
for c in prompts:
|
|
c.out_dir = os.path.join(args.model_dir, "samples")
|
|
generator = torch.manual_seed(int(c.seed))
|
|
s_image = s_pipeline(
|
|
c.prompt,
|
|
num_inference_steps=c.steps,
|
|
guidance_scale=c.scale,
|
|
negative_prompt=c.negative_prompt,
|
|
height=c.resolution[1],
|
|
width=c.resolution[0],
|
|
generator=generator,
|
|
).images[0]
|
|
sample_prompts.append(c.prompt)
|
|
image_name = db_save_image(
|
|
s_image,
|
|
c,
|
|
custom_name=f"sample_{args.revision}-{ci}",
|
|
)
|
|
shared.status.current_image = image_name
|
|
shared.status.sample_prompts = [c.prompt]
|
|
samples.append(image_name)
|
|
pbar.update()
|
|
ci += 1
|
|
for sample in samples:
|
|
last_samples.append(sample)
|
|
for prompt in sample_prompts:
|
|
last_prompts.append(prompt)
|
|
del samples
|
|
del prompts
|
|
printm("Patching model with tomesd.")
|
|
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
|
except Exception as em:
|
|
print(f"Exception saving sample: {em}")
|
|
traceback.print_exc()
|
|
pass
|
|
printm("Starting cleanup.")
|
|
tomesd.remove_patch(s_pipeline)
|
|
del s_pipeline
|
|
if save_image:
|
|
if "generator" in locals():
|
|
del generator
|
|
|
|
# Unless logging is disabled, parse the training logs and fold the resulting
# images and names into the preview lists. Failures here are reported but
# never abort the save.
if not args.disable_logging:
    try:
        printm("Parse logs.")
        log_images, log_names = log_parser.parse_logs(
            model_name=args.model_name
        )
        pbar.update()
        last_samples.extend(log_images)
        last_prompts.extend(log_names)

        # Drop the parsed results once they have been copied over.
        del log_images, log_names
    except Exception as log_err:
        traceback.print_exc()
        print(f"Exception parsing logz: {log_err}")
|
|
|
|
send_training_update(
|
|
last_samples,
|
|
args.model_name,
|
|
last_prompts,
|
|
global_step,
|
|
args.revision
|
|
)
|
|
|
|
status.sample_prompts = last_prompts
|
|
status.current_image = last_samples
|
|
pbar.update()
|
|
|
|
if args.cache_latents:
|
|
printm("Unloading vae.")
|
|
del vae
|
|
# Preserve the reference again
|
|
vae = None
|
|
|
|
status.current_image = last_samples
|
|
printm("Cleanup.")
|
|
|
|
optim_to(profiler, optimizer, accelerator.device)
|
|
|
|
# Restore all random states to avoid having sampling impact training.
|
|
if shared.device.type == 'cuda':
|
|
torch.set_rng_state(torch_rng_state)
|
|
torch.cuda.set_rng_state(cuda_cpu_rng_state, device="cpu")
|
|
torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda")
|
|
|
|
cleanup()
|
|
printm("Completed saving weights.")
|
|
|
|
# Only show the progress bar once on each machine.
|
|
progress_bar = mytqdm(
|
|
range(global_step, max_train_steps),
|
|
disable=not accelerator.is_local_main_process,
|
|
position=0,
|
|
)
|
|
progress_bar.set_description("Steps")
|
|
progress_bar.set_postfix(refresh=True)
|
|
args.revision = (
|
|
args.revision if isinstance(args.revision, int) else
|
|
int(args.revision) if str(args.revision).strip() else
|
|
0
|
|
)
|
|
lifetime_step = args.revision
|
|
lifetime_epoch = args.epoch
|
|
status.job_count = max_train_steps
|
|
status.job_no = global_step
|
|
training_complete = False
|
|
msg = ""
|
|
|
|
last_tenc = 0 < text_encoder_epochs
|
|
if stop_text_percentage == 0:
|
|
last_tenc = False
|
|
|
|
for epoch in range(first_epoch, max_train_epochs):
|
|
if training_complete:
|
|
print("Training complete, breaking epoch.")
|
|
break
|
|
|
|
if args.train_unet:
|
|
unet.train()
|
|
|
|
train_tenc = epoch < text_encoder_epochs
|
|
if stop_text_percentage == 0:
|
|
train_tenc = False
|
|
|
|
if args.freeze_clip_normalization:
|
|
text_encoder.eval()
|
|
else:
|
|
text_encoder.train(train_tenc)
|
|
|
|
if not args.use_lora:
|
|
text_encoder.requires_grad_(train_tenc)
|
|
elif train_tenc:
|
|
text_encoder.text_model.embeddings.requires_grad_(True)
|
|
|
|
if last_tenc != train_tenc:
|
|
last_tenc = train_tenc
|
|
cleanup()
|
|
|
|
loss_total = 0
|
|
|
|
current_prior_loss_weight = current_prior_loss(
|
|
args, current_epoch=global_epoch
|
|
)
|
|
for step, batch in enumerate(train_dataloader):
|
|
# Skip steps until we reach the resumed step
|
|
if (
|
|
resume_from_checkpoint
|
|
and epoch == first_epoch
|
|
and step < resume_step
|
|
):
|
|
progress_bar.update(train_batch_size)
|
|
progress_bar.reset()
|
|
status.job_count = max_train_steps
|
|
status.job_no += train_batch_size
|
|
continue
|
|
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
|
|
# Convert images to latent space
|
|
with torch.no_grad():
|
|
if args.cache_latents:
|
|
latents = batch["images"].to(accelerator.device)
|
|
else:
|
|
latents = vae.encode(
|
|
batch["images"].to(dtype=weight_dtype)
|
|
).latent_dist.sample()
|
|
latents = latents * 0.18215
|
|
|
|
# Sample noise that we'll add to the latents
|
|
if args.offset_noise < 0:
|
|
noise = torch.randn_like(latents, device=latents.device)
|
|
else:
|
|
noise = torch.randn_like(
|
|
latents, device=latents.device
|
|
) + args.offset_noise * torch.randn(
|
|
latents.shape[0],
|
|
latents.shape[1],
|
|
1,
|
|
1,
|
|
device=latents.device,
|
|
)
|
|
b_size = latents.shape[0]
|
|
|
|
# Sample a random timestep for each image
|
|
timesteps = torch.randint(
|
|
0,
|
|
noise_scheduler.config.num_train_timesteps,
|
|
(b_size,),
|
|
device=latents.device,
|
|
)
|
|
timesteps = timesteps.long()
|
|
|
|
# Add noise to the latents according to the noise magnitude at each timestep
|
|
# (this is the forward diffusion process)
|
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
|
pad_tokens = args.pad_tokens if train_tenc else False
|
|
encoder_hidden_states = encode_hidden_state(
|
|
text_encoder,
|
|
batch["input_ids"],
|
|
pad_tokens,
|
|
b_size,
|
|
args.max_token_length,
|
|
tokenizer.model_max_length,
|
|
args.clip_skip,
|
|
)
|
|
|
|
# Predict the noise residual
|
|
if args.use_ema and args.ema_predict:
|
|
noise_pred = ema_model(
|
|
noisy_latents, timesteps, encoder_hidden_states
|
|
).sample
|
|
else:
|
|
noise_pred = unet(
|
|
noisy_latents, timesteps, encoder_hidden_states
|
|
).sample
|
|
|
|
# Get the target for loss depending on the prediction type
|
|
if noise_scheduler.config.prediction_type == "v_prediction":
|
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
|
else:
|
|
target = noise
|
|
|
|
if not args.split_loss:
|
|
loss = instance_loss = torch.nn.functional.mse_loss(
|
|
noise_pred.float(), target.float(), reduction="mean"
|
|
)
|
|
loss *= batch["loss_avg"]
|
|
|
|
else:
|
|
model_pred_chunks = torch.split(noise_pred, 1, dim=0)
|
|
target_pred_chunks = torch.split(target, 1, dim=0)
|
|
instance_chunks = []
|
|
prior_chunks = []
|
|
instance_pred_chunks = []
|
|
prior_pred_chunks = []
|
|
|
|
# Iterate over the list of boolean values in batch["types"]
|
|
for i, is_prior in enumerate(batch["types"]):
|
|
# If is_prior is False, append the corresponding chunk to instance_chunks
|
|
if not is_prior:
|
|
instance_chunks.append(model_pred_chunks[i])
|
|
instance_pred_chunks.append(target_pred_chunks[i])
|
|
# If is_prior is True, append the corresponding chunk to prior_chunks
|
|
else:
|
|
prior_chunks.append(model_pred_chunks[i])
|
|
prior_pred_chunks.append(target_pred_chunks[i])
|
|
|
|
# initialize with 0 in case we are having batch = 1
|
|
instance_loss = torch.tensor(0)
|
|
prior_loss = torch.tensor(0)
|
|
|
|
# Stack the chunks in instance_chunks into model_pred (and their targets into target) for the instance loss
|
|
if len(instance_chunks):
|
|
model_pred = torch.stack(instance_chunks, dim=0)
|
|
target = torch.stack(instance_pred_chunks, dim=0)
|
|
instance_loss = torch.nn.functional.mse_loss(
|
|
model_pred.float(), target.float(), reduction="mean"
|
|
)
|
|
|
|
if len(prior_pred_chunks):
|
|
model_pred_prior = torch.stack(prior_chunks, dim=0)
|
|
target_prior = torch.stack(prior_pred_chunks, dim=0)
|
|
prior_loss = torch.nn.functional.mse_loss(
|
|
model_pred_prior.float(),
|
|
target_prior.float(),
|
|
reduction="mean",
|
|
)
|
|
|
|
if len(instance_chunks) and len(prior_chunks):
|
|
# Add the prior loss to the instance loss.
|
|
loss = instance_loss + current_prior_loss_weight * prior_loss
|
|
elif len(instance_chunks):
|
|
loss = instance_loss
|
|
else:
|
|
loss = prior_loss * current_prior_loss_weight
|
|
|
|
accelerator.backward(loss)
|
|
|
|
if accelerator.sync_gradients and not args.use_lora:
|
|
if train_tenc:
|
|
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
|
|
else:
|
|
params_to_clip = unet.parameters()
|
|
accelerator.clip_grad_norm_(params_to_clip, 1)
|
|
|
|
optimizer.step()
|
|
lr_scheduler.step(train_batch_size)
|
|
if args.use_ema and ema_model is not None:
|
|
ema_model.step(unet)
|
|
if profiler is not None:
|
|
profiler.step()
|
|
|
|
optimizer.zero_grad(set_to_none=args.gradient_set_to_none)
|
|
|
|
#Track current step and epoch for OOM resume
|
|
#shared.in_progress_epoch = global_epoch
|
|
#shared.in_progress_steps = global_step
|
|
|
|
allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
|
|
cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
|
|
last_lr = lr_scheduler.get_last_lr()[0]
|
|
|
|
global_step += train_batch_size
|
|
args.revision += train_batch_size
|
|
status.job_no += train_batch_size
|
|
|
|
del noise_pred
|
|
del latents
|
|
del encoder_hidden_states
|
|
del noise
|
|
del timesteps
|
|
del noisy_latents
|
|
del target
|
|
|
|
if dadapt(args.optimizer):
|
|
dlr_unet = optimizer.param_groups[0]["d"]*optimizer.param_groups[0]["lr"]
|
|
dlr_tenc = optimizer.param_groups[1]["d"]*optimizer.param_groups[1]["lr"]
|
|
|
|
loss_step = loss.detach().item()
|
|
loss_total += loss_step
|
|
|
|
# Assemble the per-step metrics passed to the progress bar and the tracker.
# Keys are inserted in the order: lr, loss, [inst_loss, prior_loss], vram.
include_split_losses = args.split_loss
# When dadapt(args.optimizer) is true, report dlr_unet (d * lr of the first
# param group, computed earlier in this step); otherwise report the
# scheduler's last learning rate.
reported_lr = dlr_unet if dadapt(args.optimizer) else last_lr
logs = {
    "lr": float(reported_lr),
    "loss": float(loss_step),
}
if include_split_losses:
    # Instance and prior losses are only tracked separately in split mode.
    logs["inst_loss"] = float(instance_loss.detach().item())
    logs["prior_loss"] = float(prior_loss.detach().item())
logs["vram"] = float(cached)
|
|
|
|
|
|
if dadapt(args.optimizer):
|
|
status.textinfo2 = (
|
|
f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(dlr_unet))}, TENC DLR: {'{:.2E}'.format(Decimal(dlr_tenc))}, "
|
|
f"VRAM: {allocated}/{cached} GB"
|
|
)
|
|
else:
|
|
status.textinfo2 = (
|
|
f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, "
|
|
f"VRAM: {allocated}/{cached} GB"
|
|
)
|
|
progress_bar.update(train_batch_size)
|
|
progress_bar.set_postfix(**logs)
|
|
accelerator.log(logs, step=args.revision)
|
|
|
|
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
|
accelerator.log(logs, step=global_step)
|
|
|
|
status.job_count = max_train_steps
|
|
status.job_no = global_step
|
|
status.textinfo = (
|
|
f"Steps: {global_step}/{max_train_steps} (Current),"
|
|
f" {args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}"
|
|
)
|
|
|
|
if math.isnan(loss_step):
|
|
print("Loss is NaN, your model is dead. Cancelling training.")
|
|
status.interrupted = True
|
|
|
|
# Log completion message
|
|
if training_complete or status.interrupted:
|
|
shared.in_progress = False
|
|
shared.in_progress_step = 0
|
|
shared.in_progress_epoch - 0
|
|
print(" Training complete (step check).")
|
|
if status.interrupted:
|
|
state = "cancelled"
|
|
else:
|
|
state = "complete"
|
|
|
|
status.textinfo = (
|
|
f"Training {state} {global_step}/{max_train_steps}, {args.revision}"
|
|
f" total."
|
|
)
|
|
|
|
break
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
args.epoch += 1
|
|
global_epoch += 1
|
|
lifetime_epoch += 1
|
|
session_epoch += 1
|
|
lr_scheduler.step(is_epoch=True)
|
|
status.job_count = max_train_steps
|
|
status.job_no = global_step
|
|
|
|
check_save(True)
|
|
|
|
if args.num_train_epochs > 1:
|
|
training_complete = session_epoch >= max_train_epochs
|
|
|
|
if training_complete or status.interrupted:
|
|
print(" Training complete (step check).")
|
|
if status.interrupted:
|
|
state = "cancelled"
|
|
else:
|
|
state = "complete"
|
|
|
|
status.textinfo = (
|
|
f"Training {state} {global_step}/{max_train_steps}, {args.revision}"
|
|
f" total."
|
|
)
|
|
|
|
break
|
|
|
|
# Optional cool-down between epochs. This runs last in the epoch body, only
# after the completion checks have had their chance to break out of the loop.
pause_enabled = args.epoch_pause_frequency > 0 and args.epoch_pause_time > 0
if pause_enabled and not session_epoch % args.epoch_pause_frequency:
    print(
        f"Giving the GPU a break for {args.epoch_pause_time} seconds."
    )
    # Sleep one second at a time so a cancel request is noticed promptly.
    for _tick in range(args.epoch_pause_time):
        if not status.interrupted:
            time.sleep(1)
            continue
        # Cancelled during the pause: flag completion and clear the
        # in-progress markers before abandoning the wait.
        training_complete = True
        print("Training complete, interrupted.")
        shared.in_progress = False
        shared.in_progress_step = 0
        shared.in_progress_epoch = 0
        break
|
|
|
|
cleanup_memory()
|
|
accelerator.end_training()
|
|
result.msg = msg
|
|
result.config = args
|
|
result.samples = last_samples
|
|
stop_profiler(profiler)
|
|
return result
|
|
|
|
return inner_loop()
|