# Hypernetwork-MonkeyPatch-Ex.../patches/external_pr/hypernetwork.py
#
# 314 lines
# 16 KiB
# Python

import csv
import datetime
import gc
import glob
import html
import os
import sys
import traceback
import inspect
import torch
import tqdm
from modules import shared, sd_models, devices, processing, sd_samplers
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from .textual_inversion import validate_train_inputs, write_loss
from ..hypernetwork import Hypernetwork, load_hypernetwork
from . import sd_hijack_checkpoint
from ..hnutil import optim_to
from ..scheduler import CosineAnnealingWarmUpRestarts
from .dataset import PersonalizedBase,PersonalizedDataLoader
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
                       training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
                       create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
                       preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
                       preview_width, preview_height,
                       use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7,
                       gamma_rate=1, save_when_converge=False, create_when_converge=False):
    """Train the named hypernetwork with gradient accumulation and mixed precision.

    The hypernetwork is loaded by name, a dataset is built from ``data_root``,
    and an optimizer loop runs until ``steps`` is reached, the learn-rate
    schedule finishes, or ``shared.state.interrupted`` is set. Intermediate
    checkpoints and preview images are written below
    ``<log_directory>/<YYYY-MM-DD>/<hypernetwork_name>/``; the final
    hypernetwork is saved to ``shared.cmd_opts.hypernetwork_dir``.

    Args:
        hypernetwork_name: Name handed to ``load_hypernetwork``. Anything from
            the last ``(`` onwards is stripped before it is used in file names.
        learn_rate: Learn-rate specification handed to ``LearnRateScheduler``.
        batch_size, gradient_step: Forwarded to the dataset; the values the
            dataset reports back are the ones actually used by the loop.
        data_root, training_width, training_height, shuffle_tags, tag_drop_out,
        latent_sampling_method, template_file: Forwarded to ``PersonalizedBase``.
        log_directory: Base directory for checkpoints, previews and the loss log.
        steps: Maximum optimizer step count (also used as ``job_count``).
        create_image_every: Preview interval in steps; falsy disables it.
        save_hypernetwork_every: Checkpoint interval in steps; falsy disables it.
        preview_from_txt2img: When true the ``preview_*`` arguments configure the
            preview image; otherwise the first prompt of the current batch is
            rendered with 20 steps at the training resolution.
        use_beta_scheduler: Use ``CosineAnnealingWarmUpRestarts`` instead of
            ``LearnRateScheduler`` to drive the learn rate.
        beta_repeat_epoch, epoch_mult, min_lr, gamma_rate: Settings for the beta
            scheduler; only converted and validated when it is enabled.
        warmup: Validated when the beta scheduler is enabled (see NOTE below).
        save_when_converge, create_when_converge: Also save / preview whenever
            ``scheduler_beta.is_EOC()`` is true. Forced to ``False`` when the
            beta scheduler is disabled.

    Returns:
        Tuple ``(hypernetwork, filename)`` where ``filename`` is the path of the
        final ``.pt`` file.

    Raises:
        AssertionError: A beta-scheduler setting is out of range, or the
            hypernetwork could not be loaded.
        RuntimeError: A beta-scheduler setting could not be converted to a
            number, or the loaded object is not a ``Hypernetwork``.

    Exceptions raised inside the training loop are printed to stderr and not
    re-raised; the final save still runs afterwards.
    """
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    try:
        if use_beta_scheduler:
            print("Using Beta Scheduler")
            beta_repeat_epoch = int(beta_repeat_epoch)
            assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
            min_lr = float(min_lr)
            assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
            gamma_rate = float(gamma_rate)
            print(f"Using learn rate decay(per cycle) of {gamma_rate}")
            assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
            epoch_mult = int(float(epoch_mult))
            assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
            # NOTE(review): warmup is validated here but never handed to the
            # scheduler below -- confirm whether it should be passed through.
            warmup = int(warmup)
            assert warmup >= 1, "Warmup epoch should be larger than 0!"
        else:
            # The convergence triggers rely on the beta scheduler's cycle state.
            save_when_converge = False
            create_when_converge = False
    except ValueError as err:
        # Keep the original conversion error attached as the cause.
        raise RuntimeError("Cannot use advanced LR scheduler settings!") from err

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                          template_file, steps, save_hypernetwork_every, create_image_every,
                          log_directory, name="hypernetwork")
    load_hypernetwork(hypernetwork_name)
    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0 or save_when_converge:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0 or create_when_converge:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()

    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    pin_memory = shared.opts.pin_memory
    ds = PersonalizedBase(data_root=data_root, width=training_width,
                          height=training_height,
                          repeats=shared.opts.training_image_repeats_per_epoch,
                          placeholder_token=hypernetwork_name, model=shared.sd_model,
                          cond_model=shared.sd_model.cond_stage_model,
                          device=devices.device, template_file=template_file,
                          include_cond=True, batch_size=batch_size,
                          gradient_step=gradient_step, shuffle_tags=shuffle_tags,
                          tag_drop_out=tag_drop_out,
                          latent_sampling_method=latent_sampling_method)
    latent_sampling_method = ds.latent_sampling_method
    dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
                                batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    weights = hypernetwork.weights(True)

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)

    # Only build the beta scheduler when it is enabled: on the disabled path its
    # settings were never converted/validated, and every use below is already
    # guarded by `use_beta_scheduler`.
    # NOTE(review): presumably the scheduler touches the optimizer's learn rate
    # on construction, as torch LR schedulers do -- confirm against ..scheduler.
    if use_beta_scheduler:
        scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                       cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
                                                       min_lr=min_lr, gamma=gamma_rate)
        # initial_step is hypernetwork.step with a falsy value normalised to 0.
        scheduler_beta.last_epoch = initial_step - 1
    else:
        scheduler_beta = None

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    # Largest multiple of gradient_step that fits into the batches of one epoch.
    # NOTE(review): when an epoch has fewer than gradient_step batches this is 0,
    # the inner loop breaks immediately and no optimizer step is ever taken.
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal: loss accumulated over the current accumulation window

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for _ in range((steps - initial_step) * gradient_step):
            if scheduler.finished or hypernetwork.step > steps:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                if use_beta_scheduler:
                    scheduler_beta.step(hypernetwork.step)
                else:
                    scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        # Prompts change per batch, so conditioning is encoded on the fly.
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device,
                                                                                 non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    # Scale so the accumulated gradient matches one full-size batch.
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x
                    del c

                    _loss_step += loss.item()
                scaler.scale(loss).backward()
                # NOTE(review): the result of .to() is discarded; if latent_sample
                # is a plain tensor this call has no lasting effect -- confirm.
                batch.latent_sample.to(devices.cpu)
                del loss

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                # NOTE: computed after the increment above, so the step number used
                # in file names and intervals is one ahead of hypernetwork.step.
                steps_done = hypernetwork.step + 1

                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")

                if hypernetwork_dir is not None and (
                        (use_beta_scheduler and scheduler_beta.is_EOC() and save_when_converge)
                        or (save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)):
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

                write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
                           {
                               "loss": f"{loss_step:.7f}",
                               "learn_rate": optimizer.param_groups[0]['lr']
                           })

                # The images_dir guard covers both triggers, matching the save
                # condition above (the original relied on and/or precedence).
                if images_dir is not None and (
                        (use_beta_scheduler and scheduler_beta.is_EOC() and create_when_converge)
                        or (create_image_every > 0 and steps_done % create_image_every == 0)):
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)
                    hypernetwork.eval()
                    # Move optimizer state off the device while the preview is rendered.
                    optim_to(optimizer, devices.cpu)
                    gc.collect()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
                    hypernetwork.train()
                    optim_to(optimizer, devices.device)

                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
                                                                              shared.opts.samples_format,
                                                                              processed.infotexts[0], p=p,
                                                                              forced_filename=forced_filename,
                                                                              save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = hypernetwork.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        # Best effort: report the failure but still fall through to the final save.
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval()

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    # Put both sub-models back on the compute device for normal use.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)

    return hypernetwork, filename