hypernetwork-tuning
aria1th 2023-01-16 05:37:19 +09:00
parent 7631f51b42
commit dceaa667c2
4 changed files with 711 additions and 12 deletions

File 1 of 4

@@ -1,12 +1,11 @@
import csv
import datetime
import gc
import glob
import html
import json
import os
import sys
import time
import traceback
import inspect
from collections import defaultdict, deque
import torch
@@ -21,9 +20,21 @@ from .textual_inversion import validate_train_inputs, write_loss
from ..hypernetwork import Hypernetwork, load_hypernetwork
from . import sd_hijack_checkpoint
from ..hnutil import optim_to
from ..ui import create_hypernetwork_load
from ..scheduler import CosineAnnealingWarmUpRestarts
from .dataset import PersonalizedBase, PersonalizedDataLoader
def get_training_option(filename):
    if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
        filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename)
    elif not os.path.exists(filename):
        return None
    print(f"Loading setting from {filename}!")
    # json.load expects a file object, not a path string.
    with open(filename) as file:
        obj = json.load(file)
    return obj
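For illustration, a hedged sketch of a training-options file that get_training_option can load; the keys mirror the dump[...] reads below, and every value is a hypothetical example rather than a recommendation:

# Hypothetical example: write a training-options file where the loader above
# will find it (the working directory, or hypernetwork_dir).
import json

example_options = {
    "hypernetwork_learn_rate": 5e-6,
    "batch_size": 2,
    "gradient_step": 4,
    "training_width": 512,
    "training_height": 512,
    "steps": 10000,
    "shuffle_tags": True,
    "tag_drop_out": 0.1,
    "save_when_converge": True,
    "create_when_converge": True,
    "latent_sampling_method": "once",
    "template_file": "textual_inversion_templates/hypernetwork.txt",
    "use_beta_scheduler": True,
    "beta_repeat_epoch": 4000,
    "epoch_mult": 1,
    "warmup": 10,
    "min_lr": 1e-7,
    "gamma_rate": 1,
    "use_beta_adamW_checkbox": False,
    "adamw_weight_decay": 0.01,
    "adamw_beta_1": 0.9,
    "adamw_beta_2": 0.99,
    "adamw_eps": 1e-8,
    "show_gradient_clip_checkbox": False,
    "gradient_clip_opt": "None",
    "optional_gradient_clip_value": 10.0,
    "optional_gradient_norm_type": 2,
}
with open("example_train_.json", "w") as f:
    json.dump(example_options, f, indent=2)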
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
@@ -33,10 +44,40 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, gamma_rate=1, save_when_converge=False, create_when_converge=False,
move_optimizer=True,
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99, adamw_eps=1e-8,
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, optional_gradient_norm_type=2,
load_training_options=''):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
if load_training_options != '':
dump: dict = get_training_option(load_training_options)
if dump:
learn_rate = dump['hypernetwork_learn_rate']
batch_size = dump['batch_size']
gradient_step = dump['gradient_step']
training_width = dump['training_width']
training_height = dump['training_height']
steps = dump['steps']
shuffle_tags = dump['shuffle_tags']
tag_drop_out = dump['tag_drop_out']
save_when_converge = dump['save_when_converge']
create_when_converge = dump['create_when_converge']
latent_sampling_method = dump['latent_sampling_method']
template_file = dump['template_file']
use_beta_scheduler = dump['use_beta_scheduler']
beta_repeat_epoch = dump['beta_repeat_epoch']
epoch_mult = dump['epoch_mult']
warmup = dump['warmup']
min_lr = dump['min_lr']
gamma_rate = dump['gamma_rate']
use_adamw_parameter = dump['use_beta_adamW_checkbox']
adamw_weight_decay = dump['adamw_weight_decay']
adamw_beta_1 = dump['adamw_beta_1']
adamw_beta_2 = dump['adamw_beta_2']
adamw_eps = dump['adamw_eps']
use_grad_opts = dump['show_gradient_clip_checkbox']
gradient_clip_opt = dump['gradient_clip_opt']
optional_gradient_clip_value = dump['optional_gradient_clip_value']
optional_gradient_norm_type = dump['optional_gradient_norm_type']
try:
if use_adamw_parameter:
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
@@ -410,3 +451,458 @@ Last saved image: {html.escape(last_saved_image)}<br/>
shared.sd_model.first_stage_model.to(devices.device)
return hypernetwork, filename
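As an aside, the beta-scheduler wiring shared by both training entry points reduces to roughly the following; a minimal sketch using the CosineAnnealingWarmUpRestarts imported above, with keyword arguments taken from the call sites in this file and purely illustrative values:

# Minimal sketch; values are examples, not tuned defaults.
import torch

net = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
sched = CosineAnnealingWarmUpRestarts(
    optimizer=opt,
    first_cycle_steps=4000,  # beta_repeat_epoch: steps per cosine cycle
    cycle_mult=1,            # epoch_mult: cycle-length multiplier after each restart
    max_lr=1e-3,             # peak LR reached after warmup
    warmup_steps=10,         # warmup: linear ramp at the start of each cycle
    min_lr=1e-7,             # floor LR at the end of each cycle
    gamma=1,                 # gamma_rate: per-cycle decay applied to max_lr
)
for step in range(20):
    opt.step()
    sched.step(step)  # the training loop below passes hypernetwork.step here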
def internal_clean_training(hypernetwork_name, data_root, log_directory,
create_image_every, save_hypernetwork_every,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
move_optimizer=True,
load_hypernetworks_option='', load_training_options=''):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
if load_hypernetworks_option != '':
dump_hyper: dict = get_training_option(load_hypernetworks_option)
if not dump_hyper:
raise RuntimeError(f"Cannot load from {load_hypernetworks_option}!")
hypernetwork_name = hypernetwork_name + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
enable_sizes = dump_hyper['enable_sizes']
overwrite_old = dump_hyper['overwrite_old']
layer_structure = dump_hyper['layer_structure']
activation_func = dump_hyper['activation_func']
weight_init = dump_hyper['weight_init']
add_layer_norm = dump_hyper['add_layer_norm']
use_dropout = dump_hyper['use_dropout']
dropout_structure = dump_hyper['dropout_structure']
optional_info = dump_hyper['optional_info']
weight_init_seed = dump_hyper['weight_init_seed']
normal_std = dump_hyper['normal_std']
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std)
else:
load_hypernetwork(hypernetwork_name)
hypernetwork_name = hypernetwork_name + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt'))
load_hypernetwork(hypernetwork_name)
if load_training_options != '':
dump: dict = get_training_option(load_training_options)
if dump:
learn_rate = dump['hypernetwork_learn_rate']
batch_size = dump['batch_size']
gradient_step = dump['gradient_step']
training_width = dump['training_width']
training_height = dump['training_height']
steps = dump['steps']
shuffle_tags = dump['shuffle_tags']
tag_drop_out = dump['tag_drop_out']
save_when_converge = dump['save_when_converge']
create_when_converge = dump['create_when_converge']
latent_sampling_method = dump['latent_sampling_method']
template_file = dump['template_file']
use_beta_scheduler = dump['use_beta_scheduler']
beta_repeat_epoch = dump['beta_repeat_epoch']
epoch_mult = dump['epoch_mult']
warmup = dump['warmup']
min_lr = dump['min_lr']
gamma_rate = dump['gamma_rate']
use_adamw_parameter = dump['use_beta_adamW_checkbox']
adamw_weight_decay = dump['adamw_weight_decay']
adamw_beta_1 = dump['adamw_beta_1']
adamw_beta_2 = dump['adamw_beta_2']
adamw_eps = dump['adamw_eps']
use_grad_opts = dump['show_gradient_clip_checkbox']
gradient_clip_opt = dump['gradient_clip_opt']
optional_gradient_clip_value = dump['optional_gradient_clip_value']
optional_gradient_norm_type = dump['optional_gradient_norm_type']
else:
raise RuntimeError(f"Cannot load training options from {load_training_options}!")
else:
raise RuntimeError("A training-options file must be specified for tuning runs!")
try:
if use_adamw_parameter:
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
assert 0 <= adamw_weight_decay, "Weight decay parameter must be greater than or equal to zero!"
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, adamw_eps])), "AdamW parameters must lie in [0, 1]!"
adamW_kwarg_dict = {
'weight_decay' : adamw_weight_decay,
'betas' : (adamw_beta_1, adamw_beta_2),
'eps' : adamw_eps
}
print('Using custom AdamW parameters')
else:
adamW_kwarg_dict = {
'weight_decay' : 0.01,
'betas' : (0.9, 0.99),
'eps' : 1e-8
}
if use_beta_scheduler:
print("Using Beta Scheduler")
beta_repeat_epoch = int(beta_repeat_epoch)
assert beta_repeat_epoch > 0, f"Cycle length {beta_repeat_epoch} must be positive!"
min_lr = float(min_lr)
assert min_lr < 1, f"Minimum LR {min_lr} must be less than 1!"
gamma_rate = float(gamma_rate)
print(f"Using learn rate decay (per cycle) of {gamma_rate}")
assert 0 <= gamma_rate <= 1, f"Gamma rate {gamma_rate} must lie in [0, 1]!"
epoch_mult = float(epoch_mult)
assert 1 <= epoch_mult, "Cannot use an epoch multiplier smaller than 1!"
warmup = int(warmup)
assert warmup >= 1, "Warmup steps must be larger than 0!"
print(f"Save when converges : {save_when_converge}")
print(f"Generate image when converges : {create_when_converge}")
else:
beta_repeat_epoch = 4000
epoch_mult=1
warmup=10
min_lr=1e-7
gamma_rate=1
save_when_converge = False
create_when_converge = False
except ValueError as e:
raise RuntimeError(f"Cannot use advanced LR scheduler settings: {e}")
if use_grad_opts and gradient_clip_opt != "None":
try:
optional_gradient_clip_value = float(optional_gradient_clip_value)
except ValueError:
raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})")
if gradient_clip_opt == "Norm":
try:
grad_norm = int(optional_gradient_norm_type)
except ValueError:
raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type}!")
assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}"
print(f"Using gradient clipping by Norm, norm type {grad_norm}, norm limit {optional_gradient_clip_value}")
def gradient_clipping(arg1):
torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, grad_norm)
return
else:
print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}")
def gradient_clipping(arg1):
torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value)
return
else:
def gradient_clipping(arg1):
return
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0 or save_when_converge:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
os.makedirs(hypernetwork_dir, exist_ok=True)
else:
hypernetwork_dir = None
if create_image_every > 0 or create_when_converge:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
if shared.opts.training_enable_tensorboard:
print("Tensorboard logging enabled")
tensorboard_writer = tensorboard_setup(log_directory)
else:
tensorboard_writer = None
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
detach_grad = shared.opts.disable_ema # test code that removes EMA
if detach_grad:
print("Disabling training for staged models!")
shared.sd_model.cond_stage_model.requires_grad_(False)
shared.sd_model.first_stage_model.requires_grad_(False)
torch.cuda.empty_cache()
pin_memory = shared.opts.pin_memory
ds = PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=hypernetwork_name, model=shared.sd_model,
cond_model=shared.sd_model.cond_stage_model,
device=devices.device, template_file=template_file,
include_cond=True, batch_size=batch_size,
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
tag_drop_out=tag_drop_out,
latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
batch_size=ds.batch_size, pin_memory=pin_memory)
old_parallel_processing_allowed = shared.parallel_processing_allowed
if unload:
shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
weights = hypernetwork.weights(True)
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
if use_adamw_parameter:
if hypernetwork.optimizer_name != 'AdamW':
raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
else:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
optim_to(optimizer, devices.device)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
if use_beta_scheduler:
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
scheduler_beta.last_epoch = hypernetwork.step - 1
else:
scheduler_beta = None
for pg in optimizer.param_groups:
pg['lr'] = scheduler.learn_rate
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 # internal
# size = len(ds.indexes)
loss_dict = defaultdict(lambda: deque(maxlen=1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps - initial_step) * gradient_step):
if scheduler.finished or hypernetwork.step > steps:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
if use_beta_scheduler:
scheduler_beta.step(hypernetwork.step)
else:
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device,
non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0]
for filenames in batch.filename:
loss_dict[filenames].append(loss.item())
loss /= gradient_step
del x
del c
_loss_step += loss.item()
scaler.scale(loss).backward()
batch.latent_sample.to(devices.cpu)
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
gradient_clipping(weights)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
# scaler.unscale_(optimizer)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
try:
scaler.step(optimizer)
except AssertionError:
optimizer.param_groups[0]['capturable'] = True
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = hypernetwork.step + 1
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}"
pbar.set_description(description)
if hypernetwork_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or (save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)):
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
{
"loss": f"{loss_step:.7f}",
"learn_rate": optimizer.param_groups[0]['lr']
})
if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
if images_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (create_image_every > 0 and steps_done % create_image_every == 0)):
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
rng_state = torch.get_rng_state()
cuda_rng_state = None
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all()
hypernetwork.eval()
if move_optimizer:
optim_to(optimizer, devices.cpu)
gc.collect()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
torch.set_rng_state(rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state_all(cuda_rng_state)
hypernetwork.train()
if move_optimizer:
optim_to(optimizer, devices.device)
if image is not None:
shared.state.assign_current_image(image)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
shared.opts.samples_format,
processed.infotexts[0], p=p,
forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
except Exception:
print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval()
shared.parallel_processing_allowed = old_parallel_processing_allowed
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
return hypernetwork, filename
def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory,
create_image_every, save_hypernetwork_every, preview_from_txt2img, preview_prompt,
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
preview_width, preview_height,
move_optimizer=True,
optional_new_hypernetwork_name='', load_hypernetworks_options='', load_training_options=''):
load_hypernetworks_options = load_hypernetworks_options.split(',')
load_training_options = load_training_options.split(',')
for load_hypernetworks_option in load_hypernetworks_options:
for load_training_option in load_training_options:
hypernetwork, filename = internal_clean_training(optional_new_hypernetwork_name if optional_new_hypernetwork_name != '' else hypernetwork_name, data_root, log_directory,
create_image_every, save_hypernetwork_every,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
move_optimizer,
load_hypernetworks_option, load_training_option)
if shared.state.interrupted:
return hypernetwork, filename
return hypernetwork, filename
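To make the loop structure above concrete: the two comma-separated fields expand into a full cross product, so every hypernetwork-creation file is trained against every training-options file. A small illustration with hypothetical file names:

# 2 hypernetwork configs x 2 training configs -> 4 sequential runs.
hypernetwork_files = "wide_hypernetwork_.json,deep_hypernetwork_.json".split(',')
training_files = "fast_train_.json,long_train_.json".split(',')
runs = [(h, t) for h in hypernetwork_files for t in training_files]
for hn_file, train_file in runs:
    print(f"would call internal_clean_training with {hn_file} and {train_file}")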

File 2 of 4

@@ -1,12 +1,16 @@
import html
import json
import os
import random
from modules import shared, sd_hijack, devices
from modules.call_queue import wrap_gradio_call
from modules.hypernetworks.ui import keys
from modules.paths import script_path
from modules.ui import create_refresh_button, gr_show
from webui import wrap_gradio_gpu_call
from .textual_inversion import train_embedding as train_embedding_external
from .hypernetwork import train_hypernetwork as train_hypernetwork_external, train_hypernetwork_tuning
import gradio as gr
@@ -35,6 +39,51 @@ Hypernetwork saved to {html.escape(filename)}
sd_hijack.apply_optimizations()
def train_hypernetwork_ui_tuning(*args):
initial_hypernetwork = shared.loaded_hypernetwork
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
try:
sd_hijack.undo_optimizations()
hypernetwork, filename = train_hypernetwork_tuning(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
return res, ""
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
def save_training_setting(*args):
    save_file_name, hypernetwork_learn_rate, batch_size, gradient_step, training_width, \
    training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, \
    template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
    gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
    adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
    gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type = args
    # Snapshot the unpacked settings before opening the file, so the file handle
    # itself is never swept into the serialized dict.
    settings = {key: value for key, value in locals().items() if key not in ('args', 'save_file_name')}
    filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_.json'
    with open(filename, 'w') as file:
        json.dump(settings, file)
    print(f"File saved as {filename}")
def save_hypernetwork_setting(*args):
    save_file_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std = args
    # Snapshot the settings before opening the file, so the file handle is not serialized.
    settings = {key: value for key, value in locals().items() if key not in ('args', 'save_file_name')}
    filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_hypernetwork_.json'
    with open(filename, 'w') as file:
        json.dump(settings, file)
    print(f"File saved as {filename}")
def on_train_gamma_tab(params=None):
dummy_component = gr.Label(visible=False)
with gr.Tab(label="Train_Gamma") as train_gamma:
@@ -140,8 +189,45 @@ def on_train_gamma_tab(params=None):
train_embedding = gr.Button(value="Train Embedding", variant='primary')
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
save_training_option = gr.Button(value="Save training setting")
save_file_name = gr.Textbox(label="File name to save setting as", value="")
load_training_option = gr.Textbox(label="Load training option from saved json file. This will override the settings above.", value="")
# A full path to a .json file or a simple file name is recommended.
save_training_option.click(
fn = wrap_gradio_call(save_training_setting),
inputs=[
save_file_name,
hypernetwork_learn_rate,
batch_size,
gradient_step,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
template_file,
use_beta_scheduler,
beta_repeat_epoch,
epoch_mult,
warmup,
min_lr,
gamma_rate,
use_beta_adamW_checkbox,
save_converge_opt,
generate_converge_opt,
adamw_weight_decay,
adamw_beta_1,
adamw_beta_2,
adamw_eps,
show_gradient_clip_checkbox,
gradient_clip_opt,
optional_gradient_clip_value,
optional_gradient_norm_type],
outputs=[
]
)
train_embedding.click(
fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
@@ -229,7 +315,8 @@ def on_train_gamma_tab(params=None):
show_gradient_clip_checkbox,
gradient_clip_opt,
optional_gradient_clip_value,
optional_gradient_norm_type,
load_training_option
],
outputs=[
@@ -244,3 +331,64 @@ def on_train_gamma_tab(params=None):
outputs=[],
)
return [(train_gamma, "Train Gamma", "train_gamma")]
def on_train_tuning(params=None):
dummy_component = gr.Label(visible=False)
with gr.Tab(label="Train_Tuning") as train_tuning:
gr.HTML(
value="<p style='margin-bottom: 0.7em'>Train Hypernetwork; you must specify a directory <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
"refresh_train_hypernetwork_name")
optional_new_hypernetwork_name = gr.Textbox(label="Hypernetwork name to create, leave it empty to use selected", value="")
load_hypernetworks_option = gr.Textbox(
label="Load Hypernetwork creation option from saved json file. Filenames cannot contain ',' and multiple files should be separated by ','.", value="")
load_training_options = gr.Textbox(
label="Load training option(s) from saved json file. Filenames cannot contain ',' and multiple files should be separated by ','.", value="")
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview (hypernetwork)", value=True)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
value="textual_inversion")
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable',
value=500, precision=0)
save_model_every = gr.Number(
label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0)
preview_from_txt2img = gr.Checkbox(
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
ti_output = gr.Text(elem_id="ti_output4", value="", show_label=False)
ti_outcome = gr.HTML(elem_id="ti_error4", value="")
train_hypernetwork.click(
fn=wrap_gradio_gpu_call(train_hypernetwork_ui_tuning, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
train_hypernetwork_name,
dataset_directory,
log_directory,
create_image_every,
save_model_every,
preview_from_txt2img,
*params.txt2img_preview_params,
move_optim_when_generate,
optional_new_hypernetwork_name,
load_hypernetworks_option,
load_training_options
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
return [(train_tuning, "Train Tuning", "train_tuning")]

File 3 of 4

@@ -2,9 +2,44 @@ import html
import os
from modules import shared, sd_hijack, devices
from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
import gradio as gr
def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
weight_init_seed=None, normal_std=0.01):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
assert name, "Name cannot be empty!"
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if isinstance(layer_structure, str):
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
if dropout_structure and isinstance(dropout_structure, str):
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
normal_std = float(normal_std)
assert normal_std > 0, "Normal standard deviation must be greater than 0!"
hypernet = Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure),
optional_info=optional_info,
generation_seed=weight_init_seed if weight_init_seed != -1 else None,
normal_std=normal_std
)
hypernet.save(fn)
load_hypernetwork(name)
return hypernet
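For reference, a hypothetical call to the new helper; this assumes a running webui session, since it touches shared.cmd_opts.hypernetwork_dir and the loaded-hypernetwork state:

# Hypothetical usage; argument names mirror the signature above.
hypernet = create_hypernetwork_load(
    name="example-hn",
    enable_sizes=["768", "320", "640", "1280"],
    overwrite_old=True,
    layer_structure="1, 2, 1",
    activation_func="relu",
    weight_init="Normal",
    add_layer_norm=False,
    use_dropout=False,
    dropout_structure=None,
    optional_info="created from a saved setting file",
    weight_init_seed=-1,
    normal_std=0.01,
)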
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
weight_init_seed=None, normal_std=0.01):

File 4 of 4

@@ -1,5 +1,6 @@
import os
from modules.call_queue import wrap_gradio_call
from modules.hypernetworks.ui import keys
import modules.scripts as scripts
from modules import script_callbacks, shared, sd_hijack
@@ -108,8 +109,8 @@ def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
def create_extension_tab(params=None):
with gr.Tab(label="Create Beta hypernetwork") as create_beta:
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1024", "1280"],
choices=["768", "320", "640", "1024", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure",
placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear",
@@ -143,8 +144,27 @@ def create_extension_tab(params=None):
with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
save_setting = gr.Button(value="Save hypernetwork setting to file")
setting_name = gr.Textbox(label="Setting file name", value="")
ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False) ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
ti_outcome = gr.HTML(elem_id="ti_error2", value="") ti_outcome = gr.HTML(elem_id="ti_error2", value="")
save_setting.click(
fn=wrap_gradio_call(external_patch_ui.save_hypernetwork_setting),
inputs=[
setting_name,
new_hypernetwork_sizes,
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout,
new_hypernetwork_dropout_structure,
optional_info,
generation_seed,
normal_std],
outputs=[]
)
create_hypernetwork.click(
fn=ui.create_hypernetwork,
inputs=[