import csv
import datetime
import gc
import glob
import html
import os
import sys
import traceback
import inspect

import torch
import tqdm

from modules import shared, sd_models, devices, processing, sd_samplers
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from .textual_inversion import validate_train_inputs, write_loss
from ..hypernetwork import Hypernetwork, load_hypernetwork
from . import sd_hijack_checkpoint
from ..hnutil import optim_to
from ..scheduler import CosineAnnealingWarmUpRestarts
from .dataset import PersonalizedBase, PersonalizedDataLoader


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
                       training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
                       create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
                       preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
                       preview_width, preview_height,
                       use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, gamma_rate=1, save_when_converge=False, create_when_converge=False):
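    """Train the loaded hypernetwork on the images under data_root.

    Optionally drives the learning rate with CosineAnnealingWarmUpRestarts
    (use_beta_scheduler) instead of the stock LearnRateScheduler, and can save
    checkpoints / preview images whenever a cosine cycle converges
    (save_when_converge / create_when_converge). Returns (hypernetwork, filename).
    """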
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    try:
        if use_beta_scheduler:
            print("Using Beta Scheduler")
            beta_repeat_epoch = int(beta_repeat_epoch)
            assert beta_repeat_epoch > 0, f"Cycle length must be positive, got {beta_repeat_epoch}!"
            min_lr = float(min_lr)
            assert min_lr < 1, f"Minimum learning rate must be below 1, got {min_lr}!"
            gamma_rate = float(gamma_rate)
            print(f"Using learn rate decay (per cycle) of {gamma_rate}")
            assert 0 <= gamma_rate <= 1, f"Gamma rate must be between 0 and 1, got {gamma_rate}!"
            epoch_mult = int(float(epoch_mult))
            assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
            warmup = int(warmup)
            assert warmup >= 1, "Warmup epoch should be larger than 0!"
        else:
            save_when_converge = False
            create_when_converge = False
    except ValueError as e:
        raise RuntimeError("Cannot use advanced LR scheduler settings!") from e
    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                          template_file, steps, save_hypernetwork_every, create_image_every,
                          log_directory, name="hypernetwork")

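    # Load the named hypernetwork into shared state; training below assumes a plain Hypernetwork object.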
    load_hypernetwork(hypernetwork_name)
    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0 or save_when_converge:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0 or create_when_converge:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()

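    # Resume from the step count stored in the hypernetwork; nothing to do if it already reached the target.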
    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

    pin_memory = shared.opts.pin_memory

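    # Build the training dataset; PersonalizedBase presumably pre-encodes image latents and
    # (with include_cond=True) text conditionings so the loop below can skip re-encoding.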
    ds = PersonalizedBase(data_root=data_root, width=training_width,
                          height=training_height,
                          repeats=shared.opts.training_image_repeats_per_epoch,
                          placeholder_token=hypernetwork_name, model=shared.sd_model,
                          cond_model=shared.sd_model.cond_stage_model,
                          device=devices.device, template_file=template_file,
                          include_cond=True, batch_size=batch_size,
                          gradient_step=gradient_step, shuffle_tags=shuffle_tags,
                          tag_drop_out=tag_drop_out,
                          latent_sampling_method=latent_sampling_method)

    latent_sampling_method = ds.latent_sampling_method

    dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
                                batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

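    # Collect the trainable hypernetwork parameters; the True flag presumably puts the layers into training mode.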
    weights = hypernetwork.weights(True)

    # Use the optimizer type recorded in the saved hypernetwork; fall back to AdamW if it is unknown.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)

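    # Cosine annealing with warm restarts: the LR decays over each cycle of beta_repeat_epoch steps,
    # and each restart presumably scales max_lr by gamma_rate. NOTE: `warmup` is validated above but
    # not forwarded to the scheduler here, so it may be handled elsewhere (or unused).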
    if use_beta_scheduler:
        scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                       cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
                                                       min_lr=min_lr, gamma=gamma_rate)
        scheduler_beta.last_epoch = initial_step - 1
    else:
        scheduler_beta = None
    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal

    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))

    steps_without_grad = 0

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for i in range((steps - initial_step) * gradient_step):
            if scheduler.finished or hypernetwork.step > steps:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                if use_beta_scheduler:
                    scheduler_beta.step(hypernetwork.step)
                else:
                    scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

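                # Mixed-precision forward pass; the GradScaler above scales the loss so fp16 gradients do not underflow.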
                with torch.autocast("cuda"):
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x
                    del c

                _loss_step += loss.item()
                scaler.scale(loss).backward()
                batch.latent_sample = batch.latent_sample.to(devices.cpu)
                del loss
                # back to the top of the loader loop until gradient_step sub-batches have been accumulated
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1

                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}] loss: {loss_step:.7f}")
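                # Save an intermediate checkpoint either every save_hypernetwork_every steps or, with the
                # beta scheduler, at the end of each cosine cycle (is_EOC) when save_when_converge is set.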
                if hypernetwork_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC() and save_when_converge)
                                                     or (save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)):
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

                write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
                           {
                               "loss": f"{loss_step:.7f}",
                               "learn_rate": optimizer.param_groups[0]['lr']
                           })

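                # Generate a preview image on the same schedule; the optimizer is parked on the CPU and the
                # VAE/text encoder are moved back to the GPU while sampling, then restored for training.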
                if images_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC() and create_when_converge)
                                               or (create_image_every > 0 and steps_done % create_image_every == 0)):
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)
                    hypernetwork.eval()
                    optim_to(optimizer, devices.cpu)
                    gc.collect()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
                    hypernetwork.train()
                    optim_to(optimizer, devices.device)
                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
                                                                             shared.opts.samples_format,
                                                                             processed.infotexts[0], p=p,
                                                                             forced_filename=forced_filename,
                                                                             save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = hypernetwork.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval()
        # report_statistics(loss_dict)

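        # Final save happens in `finally` so the hypernetwork (and optionally its optimizer state)
        # is written out even when training is interrupted or crashes.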
        filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
        hypernetwork.optimizer_name = optimizer_name
        if shared.opts.save_optimizer_state:
            hypernetwork.optimizer_state_dict = optimizer.state_dict()
        save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
        del optimizer
        hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)

    return hypernetwork, filename