maybe
parent
7631f51b42
commit
dceaa667c2
|
|
@ -1,12 +1,11 @@
|
||||||
import csv
|
|
||||||
import datetime
|
import datetime
|
||||||
import gc
|
import gc
|
||||||
import glob
|
|
||||||
import html
|
import html
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import inspect
|
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -21,9 +20,21 @@ from .textual_inversion import validate_train_inputs, write_loss
|
||||||
from ..hypernetwork import Hypernetwork, load_hypernetwork
|
from ..hypernetwork import Hypernetwork, load_hypernetwork
|
||||||
from . import sd_hijack_checkpoint
|
from . import sd_hijack_checkpoint
|
||||||
from ..hnutil import optim_to
|
from ..hnutil import optim_to
|
||||||
|
from ..ui import create_hypernetwork_load
|
||||||
from ..scheduler import CosineAnnealingWarmUpRestarts
|
from ..scheduler import CosineAnnealingWarmUpRestarts
|
||||||
from .dataset import PersonalizedBase,PersonalizedDataLoader
|
from .dataset import PersonalizedBase,PersonalizedDataLoader
|
||||||
|
|
||||||
|
def get_training_option(filename):
|
||||||
|
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename)
|
||||||
|
elif os.path.exists(filename):
|
||||||
|
filename = filename
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
print(f"Loading setting from {filename}!")
|
||||||
|
obj = json.load(filename)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
|
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
|
||||||
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
|
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
|
||||||
|
|
@ -33,10 +44,40 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||||
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1,warmup =10, min_lr=1e-7, gamma_rate=1, save_when_converge=False, create_when_converge=False,
|
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1,warmup =10, min_lr=1e-7, gamma_rate=1, save_when_converge=False, create_when_converge=False,
|
||||||
move_optimizer=True,
|
move_optimizer=True,
|
||||||
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99,adamw_eps=1e-8,
|
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99,adamw_eps=1e-8,
|
||||||
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, optional_gradient_norm_type=2):
|
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, optional_gradient_norm_type=2,
|
||||||
|
load_training_options = ''):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
if load_training_options != '':
|
||||||
|
dump: dict = get_training_option(load_training_options)
|
||||||
|
if dump and dump is not None:
|
||||||
|
learn_rate = dump['hypernetwork_learn_rate']
|
||||||
|
batch_size = dump['batch_size']
|
||||||
|
gradient_step = dump['gradient_step']
|
||||||
|
training_width = dump['training_width']
|
||||||
|
training_height = dump['training_height']
|
||||||
|
steps = dump['steps']
|
||||||
|
shuffle_tags = dump['shuffle_tags']
|
||||||
|
tag_drop_out = dump['tag_drop_out']
|
||||||
|
save_when_converge = dump['save_when_converge']
|
||||||
|
create_when_converge = dump['create_when_converge']
|
||||||
|
latent_sampling_method = dump['latent_sampling_method']
|
||||||
|
template_file = dump['template_file']
|
||||||
|
use_beta_scheduler = dump['use_beta_scheduler']
|
||||||
|
beta_repeat_epoch = dump['beta_repeat_epoch']
|
||||||
|
epoch_mult = dump['epoch_mult']
|
||||||
|
warmup = dump['warmup']
|
||||||
|
min_lr = dump['min_lr']
|
||||||
|
gamma_rate = dump['gamma_rate']
|
||||||
|
use_adamw_parameter = dump['use_beta_adamW_checkbox']
|
||||||
|
adamw_weight_decay = dump['adamw_weight_decay']
|
||||||
|
adamw_beta_1 = dump['adamw_beta_1']
|
||||||
|
adamw_beta_2 = dump['adamw_beta_2']
|
||||||
|
adamw_eps = dump['adamw_eps']
|
||||||
|
use_grad_opts = dump['show_gradient_clip_checkbox']
|
||||||
|
gradient_clip_opt = dump['gradient_clip_opt']
|
||||||
|
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||||
|
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||||
try:
|
try:
|
||||||
if use_adamw_parameter:
|
if use_adamw_parameter:
|
||||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
|
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
|
||||||
|
|
@ -410,3 +451,458 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
|
||||||
|
def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
||||||
|
create_image_every, save_hypernetwork_every,
|
||||||
|
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||||
|
move_optimizer=True,
|
||||||
|
load_hypernetworks_option='', load_training_options=''):
|
||||||
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
|
from modules import images
|
||||||
|
if load_hypernetworks_option != '':
|
||||||
|
timeStr = time.asctime()[:19]
|
||||||
|
dump_hyper: dict = get_training_option(load_hypernetworks_option)
|
||||||
|
hypernetwork_name = hypernetwork_name + timeStr
|
||||||
|
enable_sizes = dump_hyper['enable_sizes']
|
||||||
|
overwrite_old = dump_hyper['overwrite_old']
|
||||||
|
layer_structure = dump_hyper['layer_structure']
|
||||||
|
activation_func = dump_hyper['activation_func']
|
||||||
|
weight_init = dump_hyper['weight_init']
|
||||||
|
add_layer_norm = dump_hyper['add_layer_norm']
|
||||||
|
use_dropout = dump_hyper['use_dropout']
|
||||||
|
dropout_structure = dump_hyper['dropout_structure']
|
||||||
|
optional_info = dump_hyper['optional_info']
|
||||||
|
weight_init_seed = dump_hyper['weight_init_seed']
|
||||||
|
normal_std = dump_hyper['normal_std']
|
||||||
|
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std)
|
||||||
|
else:
|
||||||
|
load_hypernetwork(hypernetwork_name)
|
||||||
|
hypernetwork_name = hypernetwork_name + time.asctime()[:19]
|
||||||
|
shared.loaded_hypernetwork.save(hypernetwork_name)
|
||||||
|
load_hypernetwork(hypernetwork_name)
|
||||||
|
if load_training_options != '':
|
||||||
|
dump: dict = get_training_option(load_training_options)
|
||||||
|
if dump and dump is not None:
|
||||||
|
learn_rate = dump['hypernetwork_learn_rate']
|
||||||
|
batch_size = dump['batch_size']
|
||||||
|
gradient_step = dump['gradient_step']
|
||||||
|
training_width = dump['training_width']
|
||||||
|
training_height = dump['training_height']
|
||||||
|
steps = dump['steps']
|
||||||
|
shuffle_tags = dump['shuffle_tags']
|
||||||
|
tag_drop_out = dump['tag_drop_out']
|
||||||
|
save_when_converge = dump['save_when_converge']
|
||||||
|
create_when_converge = dump['create_when_converge']
|
||||||
|
latent_sampling_method = dump['latent_sampling_method']
|
||||||
|
template_file = dump['template_file']
|
||||||
|
use_beta_scheduler = dump['use_beta_scheduler']
|
||||||
|
beta_repeat_epoch = dump['beta_repeat_epoch']
|
||||||
|
epoch_mult = dump['epoch_mult']
|
||||||
|
warmup = dump['warmup']
|
||||||
|
min_lr = dump['min_lr']
|
||||||
|
gamma_rate = dump['gamma_rate']
|
||||||
|
use_adamw_parameter = dump['use_beta_adamW_checkbox']
|
||||||
|
adamw_weight_decay = dump['adamw_weight_decay']
|
||||||
|
adamw_beta_1 = dump['adamw_beta_1']
|
||||||
|
adamw_beta_2 = dump['adamw_beta_2']
|
||||||
|
adamw_eps = dump['adamw_eps']
|
||||||
|
use_grad_opts = dump['show_gradient_clip_checkbox']
|
||||||
|
gradient_clip_opt = dump['gradient_clip_opt']
|
||||||
|
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||||
|
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Cannot load from {load_training_options}!")
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Cannot load from {load_training_options}!")
|
||||||
|
try:
|
||||||
|
if use_adamw_parameter:
|
||||||
|
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
|
||||||
|
assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!"
|
||||||
|
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, adamw_eps])), "Cannot use negative or >1 number for adamW parameters!"
|
||||||
|
adamW_kwarg_dict = {
|
||||||
|
'weight_decay' : adamw_weight_decay,
|
||||||
|
'betas' : (adamw_beta_1, adamw_beta_2),
|
||||||
|
'eps' : adamw_eps
|
||||||
|
}
|
||||||
|
print('Using custom AdamW parameters')
|
||||||
|
else:
|
||||||
|
adamW_kwarg_dict = {
|
||||||
|
'weight_decay' : 0.01,
|
||||||
|
'betas' : (0.9, 0.99),
|
||||||
|
'eps' : 1e-8
|
||||||
|
}
|
||||||
|
if use_beta_scheduler:
|
||||||
|
print("Using Beta Scheduler")
|
||||||
|
beta_repeat_epoch = int(beta_repeat_epoch)
|
||||||
|
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
|
||||||
|
min_lr = float(min_lr)
|
||||||
|
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
|
||||||
|
gamma_rate = float(gamma_rate)
|
||||||
|
print(f"Using learn rate decay(per cycle) of {gamma_rate}")
|
||||||
|
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
|
||||||
|
epoch_mult = float(epoch_mult)
|
||||||
|
assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
|
||||||
|
warmup = int(warmup)
|
||||||
|
assert warmup >= 1, "Warmup epoch should be larger than 0!"
|
||||||
|
print(f"Save when converges : {save_when_converge}")
|
||||||
|
print(f"Generate image when converges : {create_when_converge}")
|
||||||
|
else:
|
||||||
|
beta_repeat_epoch = 4000
|
||||||
|
epoch_mult=1
|
||||||
|
warmup=10
|
||||||
|
min_lr=1e-7
|
||||||
|
gamma_rate=1
|
||||||
|
save_when_converge = False
|
||||||
|
create_when_converge = False
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError("Cannot use advanced LR scheduler settings!")
|
||||||
|
if use_grad_opts and gradient_clip_opt != "None":
|
||||||
|
try:
|
||||||
|
optional_gradient_clip_value = float(optional_gradient_clip_value)
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})")
|
||||||
|
if gradient_clip_opt == "Norm":
|
||||||
|
try:
|
||||||
|
grad_norm = int(optional_gradient_norm_type)
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})")
|
||||||
|
assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}"
|
||||||
|
print(f"Using gradient clipping by Norm, norm type {optional_gradient_norm_type}, norm limit {optional_gradient_clip_value}")
|
||||||
|
def gradient_clipping(arg1):
|
||||||
|
torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}")
|
||||||
|
def gradient_clipping(arg1):
|
||||||
|
torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
def gradient_clipping(arg1):
|
||||||
|
return
|
||||||
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
|
create_image_every = create_image_every or 0
|
||||||
|
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
|
||||||
|
template_file, steps, save_hypernetwork_every, create_image_every,
|
||||||
|
log_directory, name="hypernetwork")
|
||||||
|
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
|
||||||
|
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
|
||||||
|
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
|
||||||
|
|
||||||
|
shared.state.job = "train-hypernetwork"
|
||||||
|
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
|
if save_hypernetwork_every > 0 or save_when_converge:
|
||||||
|
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||||
|
os.makedirs(hypernetwork_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
hypernetwork_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0 or create_when_converge:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
hypernetwork = shared.loaded_hypernetwork
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
initial_step = hypernetwork.step or 0
|
||||||
|
if initial_step >= steps:
|
||||||
|
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||||
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
print("Tensorboard logging enabled")
|
||||||
|
tensorboard_writer = tensorboard_setup(log_directory)
|
||||||
|
else:
|
||||||
|
tensorboard_writer = None
|
||||||
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
detach_grad = shared.opts.disable_ema # test code that removes EMA
|
||||||
|
if detach_grad:
|
||||||
|
print("Disabling training for staged models!")
|
||||||
|
shared.sd_model.cond_stage_model.requires_grad_(False)
|
||||||
|
shared.sd_model.first_stage_model.requires_grad_(False)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
|
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||||
|
height=training_height,
|
||||||
|
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||||
|
placeholder_token=hypernetwork_name, model=shared.sd_model,
|
||||||
|
cond_model=shared.sd_model.cond_stage_model,
|
||||||
|
device=devices.device, template_file=template_file,
|
||||||
|
include_cond=True, batch_size=batch_size,
|
||||||
|
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
|
||||||
|
tag_drop_out=tag_drop_out,
|
||||||
|
latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
|
||||||
|
batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.parallel_processing_allowed = False
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
|
weights = hypernetwork.weights(True)
|
||||||
|
|
||||||
|
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||||
|
if hypernetwork.optimizer_name in optimizer_dict:
|
||||||
|
if use_adamw_parameter:
|
||||||
|
if hypernetwork.optimizer_name != 'AdamW':
|
||||||
|
raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
|
||||||
|
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
|
||||||
|
else:
|
||||||
|
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||||
|
optimizer_name = hypernetwork.optimizer_name
|
||||||
|
else:
|
||||||
|
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||||
|
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
|
||||||
|
optimizer_name = 'AdamW'
|
||||||
|
|
||||||
|
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||||
|
try:
|
||||||
|
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||||
|
optim_to(optimizer, devices.device)
|
||||||
|
except RuntimeError as e:
|
||||||
|
print("Cannot resume from saved optimizer!")
|
||||||
|
print(e)
|
||||||
|
if use_beta_scheduler:
|
||||||
|
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
|
||||||
|
scheduler_beta.last_epoch =hypernetwork.step-1
|
||||||
|
else:
|
||||||
|
scheduler_beta = None
|
||||||
|
for pg in optimizer.param_groups:
|
||||||
|
pg['lr'] = scheduler.learn_rate
|
||||||
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
batch_size = ds.batch_size
|
||||||
|
gradient_step = ds.gradient_step
|
||||||
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
|
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
|
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
|
loss_step = 0
|
||||||
|
_loss_step = 0 # internal
|
||||||
|
# size = len(ds.indexes)
|
||||||
|
loss_dict = defaultdict(lambda: deque(maxlen=1024))
|
||||||
|
# losses = torch.zeros((size,))
|
||||||
|
# previous_mean_losses = [0]
|
||||||
|
# previous_mean_loss = 0
|
||||||
|
# print("Mean loss of {} elements".format(size))
|
||||||
|
|
||||||
|
steps_without_grad = 0
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
forced_filename = "<none>"
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
|
try:
|
||||||
|
for i in range((steps - initial_step) * gradient_step):
|
||||||
|
if scheduler.finished or hypernetwork.step > steps:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
for j, batch in enumerate(dl):
|
||||||
|
# works as a drop_last=True for gradient accumulation
|
||||||
|
if j == max_steps_per_epoch:
|
||||||
|
break
|
||||||
|
if use_beta_scheduler:
|
||||||
|
scheduler_beta.step(hypernetwork.step)
|
||||||
|
else:
|
||||||
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
|
if tag_drop_out != 0 or shuffle_tags:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device,
|
||||||
|
non_blocking=pin_memory)
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
else:
|
||||||
|
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||||
|
loss = shared.sd_model(x, c)[0]
|
||||||
|
for filenames in batch.filename:
|
||||||
|
loss_dict[filenames].append(loss.item())
|
||||||
|
loss /= gradient_step
|
||||||
|
del x
|
||||||
|
del c
|
||||||
|
|
||||||
|
_loss_step += loss.item()
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
batch.latent_sample.to(devices.cpu)
|
||||||
|
# go back until we reach gradient accumulation steps
|
||||||
|
if (j + 1) % gradient_step != 0:
|
||||||
|
continue
|
||||||
|
gradient_clipping(weights)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||||
|
# scaler.unscale_(optimizer)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
try:
|
||||||
|
scaler.step(optimizer)
|
||||||
|
except AssertionError:
|
||||||
|
optimizer.param_groups[0]['capturable'] = True
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
hypernetwork.step += 1
|
||||||
|
pbar.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
loss_step = _loss_step
|
||||||
|
_loss_step = 0
|
||||||
|
|
||||||
|
steps_done = hypernetwork.step + 1
|
||||||
|
|
||||||
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
|
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||||
|
pbar.set_description(description)
|
||||||
|
if hypernetwork_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or (save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)):
|
||||||
|
# Before saving, change name to match current checkpoint.
|
||||||
|
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||||
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
|
||||||
|
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
|
||||||
|
{
|
||||||
|
"loss": f"{loss_step:.7f}",
|
||||||
|
"learn_rate": optimizer.param_groups[0]['lr']
|
||||||
|
})
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
epoch_num = hypernetwork.step // len(ds)
|
||||||
|
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||||
|
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
|
||||||
|
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||||
|
if images_dir is not None and (use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (create_image_every > 0 and steps_done % create_image_every == 0):
|
||||||
|
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
cuda_rng_state = None
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||||
|
hypernetwork.eval()
|
||||||
|
if move_optimizer:
|
||||||
|
optim_to(optimizer, devices.cpu)
|
||||||
|
gc.collect()
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = batch.cond_text[0]
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
|
preview_text = p.prompt
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
|
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
torch.set_rng_state(rng_state)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.set_rng_state_all(cuda_rng_state)
|
||||||
|
hypernetwork.train()
|
||||||
|
if move_optimizer:
|
||||||
|
optim_to(optimizer, devices.device)
|
||||||
|
if image is not None:
|
||||||
|
shared.state.assign_current_image(image)
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||||
|
shared.opts.samples_format,
|
||||||
|
processed.infotexts[0], p=p,
|
||||||
|
forced_filename=forced_filename,
|
||||||
|
save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
shared.state.job_no = hypernetwork.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {loss_step:.7f}<br/>
|
||||||
|
Step: {steps_done}<br/>
|
||||||
|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
except Exception:
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
pbar.leave = False
|
||||||
|
pbar.close()
|
||||||
|
hypernetwork.eval()
|
||||||
|
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||||
|
report_statistics(loss_dict)
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||||
|
del optimizer
|
||||||
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory,
|
||||||
|
create_image_every, save_hypernetwork_every, preview_from_txt2img, preview_prompt,
|
||||||
|
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
|
||||||
|
preview_width, preview_height,
|
||||||
|
move_optimizer=True,
|
||||||
|
optional_new_hypernetwork_name='', load_hypernetworks_options='', load_training_options=''):
|
||||||
|
load_hypernetworks_options = load_hypernetworks_options.split(',')
|
||||||
|
load_training_options = load_training_options.split(',')
|
||||||
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
|
for load_hypernetworks_option in load_hypernetworks_options:
|
||||||
|
for load_training_option in load_training_options:
|
||||||
|
internal_clean_training(hypernetwork_name if load_hypernetworks_option != '' else optional_new_hypernetwork_name, data_root, log_directory,
|
||||||
|
create_image_every, save_hypernetwork_every,
|
||||||
|
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||||
|
move_optimizer,
|
||||||
|
load_hypernetworks_option, load_training_option)
|
||||||
|
if shared.state.interrupted:
|
||||||
|
return
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
import html
|
import html
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
from modules import shared, sd_hijack, devices
|
from modules import shared, sd_hijack, devices
|
||||||
|
from modules.call_queue import wrap_gradio_call
|
||||||
|
from modules.hypernetworks.ui import keys
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui import create_refresh_button, gr_show
|
from modules.ui import create_refresh_button, gr_show
|
||||||
from webui import wrap_gradio_gpu_call
|
from webui import wrap_gradio_gpu_call
|
||||||
from .textual_inversion import train_embedding as train_embedding_external
|
from .textual_inversion import train_embedding as train_embedding_external
|
||||||
from .hypernetwork import train_hypernetwork as train_hypernetwork_external
|
from .hypernetwork import train_hypernetwork as train_hypernetwork_external, train_hypernetwork_tuning
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,6 +39,51 @@ Hypernetwork saved to {html.escape(filename)}
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork_ui_tuning(*args):
|
||||||
|
|
||||||
|
initial_hypernetwork = shared.loaded_hypernetwork
|
||||||
|
|
||||||
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||||
|
|
||||||
|
try:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
hypernetwork, filename = train_hypernetwork_tuning(*args)
|
||||||
|
|
||||||
|
res = f"""
|
||||||
|
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
||||||
|
Hypernetwork saved to {html.escape(filename)}
|
||||||
|
"""
|
||||||
|
return res, ""
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
shared.loaded_hypernetwork = initial_hypernetwork
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
|
||||||
|
def save_training_setting(*args):
|
||||||
|
save_file_name, hypernetwork_learn_rate, batch_size, gradient_step, training_width, \
|
||||||
|
training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, \
|
||||||
|
template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
|
||||||
|
gamma_rate, use_beta_adamW_checkbox, save_converge_opt, generate_converge_opt, \
|
||||||
|
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
|
||||||
|
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type = args
|
||||||
|
|
||||||
|
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
|
||||||
|
with open(filename, 'w') as file:
|
||||||
|
json.dump(locals(), file)
|
||||||
|
print(f"File saved as {filename}")
|
||||||
|
|
||||||
|
def save_hypernetwork_setting(*args):
|
||||||
|
save_file_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std = args
|
||||||
|
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_hypernetwork_' + '.json'
|
||||||
|
with open(filename, 'w') as file:
|
||||||
|
json.dump(locals(), file)
|
||||||
|
print(f"File saved as {filename}")
|
||||||
|
|
||||||
def on_train_gamma_tab(params=None):
|
def on_train_gamma_tab(params=None):
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
with gr.Tab(label="Train_Gamma") as train_gamma:
|
with gr.Tab(label="Train_Gamma") as train_gamma:
|
||||||
|
|
@ -140,8 +189,45 @@ def on_train_gamma_tab(params=None):
|
||||||
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||||
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
|
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
|
||||||
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
|
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
|
||||||
|
save_training_option = gr.Button(label="Save training setting")
|
||||||
|
save_file_name = gr.Textbox(label="File name to save setting as", value="")
|
||||||
|
load_training_option = gr.Textbox(label="Load training option from saved json file. This will override settings above", value="")
|
||||||
|
#Full path to .json or simple names are recommended.
|
||||||
|
save_training_option.click(
|
||||||
|
fn = wrap_gradio_call(save_training_setting),
|
||||||
|
inputs=[
|
||||||
|
save_file_name,
|
||||||
|
hypernetwork_learn_rate,
|
||||||
|
batch_size,
|
||||||
|
gradient_step,
|
||||||
|
training_width,
|
||||||
|
training_height,
|
||||||
|
steps,
|
||||||
|
shuffle_tags,
|
||||||
|
tag_drop_out,
|
||||||
|
latent_sampling_method,
|
||||||
|
template_file,
|
||||||
|
use_beta_scheduler,
|
||||||
|
beta_repeat_epoch,
|
||||||
|
epoch_mult,
|
||||||
|
warmup,
|
||||||
|
min_lr,
|
||||||
|
gamma_rate,
|
||||||
|
use_beta_adamW_checkbox,
|
||||||
|
save_converge_opt,
|
||||||
|
generate_converge_opt,
|
||||||
|
adamw_weight_decay,
|
||||||
|
adamw_beta_1,
|
||||||
|
adamw_beta_2,
|
||||||
|
adamw_eps,
|
||||||
|
show_gradient_clip_checkbox,
|
||||||
|
gradient_clip_opt,
|
||||||
|
optional_gradient_clip_value,
|
||||||
|
optional_gradient_norm_type],
|
||||||
|
outputs=[
|
||||||
|
|
||||||
|
]
|
||||||
|
)
|
||||||
train_embedding.click(
|
train_embedding.click(
|
||||||
fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]),
|
fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]),
|
||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
|
|
@ -229,7 +315,8 @@ def on_train_gamma_tab(params=None):
|
||||||
show_gradient_clip_checkbox,
|
show_gradient_clip_checkbox,
|
||||||
gradient_clip_opt,
|
gradient_clip_opt,
|
||||||
optional_gradient_clip_value,
|
optional_gradient_clip_value,
|
||||||
optional_gradient_norm_type
|
optional_gradient_norm_type,
|
||||||
|
load_training_option
|
||||||
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
|
@ -244,3 +331,64 @@ def on_train_gamma_tab(params=None):
|
||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
return [(train_gamma, "Train Gamma", "train_gamma")]
|
return [(train_gamma, "Train Gamma", "train_gamma")]
|
||||||
|
|
||||||
|
def on_train_tuning(params=None):
|
||||||
|
dummy_component = gr.Label(visible=False)
|
||||||
|
with gr.Tab(label="Train_Tuning") as train_tuning:
|
||||||
|
gr.HTML(
|
||||||
|
value="<p style='margin-bottom: 0.7em'>Train Hypernetwork; you must specify a directory <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
|
with gr.Row():
|
||||||
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
|
||||||
|
choices=[x for x in shared.hypernetworks.keys()])
|
||||||
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
|
||||||
|
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
|
||||||
|
"refresh_train_hypernetwork_name")
|
||||||
|
optional_new_hypernetwork_name = gr.Textbox(label="Hypernetwork name to create, leave it empty to use selected", value="")
|
||||||
|
load_hypernetworks_option = gr.Textbox(
|
||||||
|
label="Load Hypernetwork creation option from saved json file. filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||||
|
load_training_options = gr.Textbox(
|
||||||
|
label="Load training option(s) from saved json file. filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||||
|
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)", value=True)
|
||||||
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
|
||||||
|
value="textual_inversion")
|
||||||
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable',
|
||||||
|
value=500, precision=0)
|
||||||
|
save_model_every = gr.Number(
|
||||||
|
label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
preview_from_txt2img = gr.Checkbox(
|
||||||
|
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||||
|
with gr.Row():
|
||||||
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||||
|
ti_output = gr.Text(elem_id="ti_output4", value="", show_label=False)
|
||||||
|
ti_outcome = gr.HTML(elem_id="ti_error4", value="")
|
||||||
|
train_hypernetwork.click(
|
||||||
|
fn=wrap_gradio_gpu_call(train_hypernetwork_ui_tuning, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
dummy_component,
|
||||||
|
train_hypernetwork_name,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
create_image_every,
|
||||||
|
save_model_every,
|
||||||
|
preview_from_txt2img,
|
||||||
|
*params.txt2img_preview_params,
|
||||||
|
move_optim_when_generate,
|
||||||
|
optional_new_hypernetwork_name,
|
||||||
|
load_hypernetworks_option,
|
||||||
|
load_training_options
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt_training.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
return [(train_tuning, "Train Tuning", "train_tuning")]
|
||||||
|
|
@ -2,9 +2,44 @@ import html
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from modules import shared, sd_hijack, devices
|
from modules import shared, sd_hijack, devices
|
||||||
from .hypernetwork import Hypernetwork, train_hypernetwork
|
from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
||||||
|
weight_init_seed=None, normal_std=0.01):
|
||||||
|
# Remove illegal characters from name.
|
||||||
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
assert name, "Name cannot be empty!"
|
||||||
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
if type(layer_structure) == str:
|
||||||
|
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||||
|
|
||||||
|
if dropout_structure and type(dropout_structure) == str:
|
||||||
|
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||||
|
normal_std = float(normal_std)
|
||||||
|
assert normal_std > 0, "Normal Standard Deviation should be bigger than 0!"
|
||||||
|
hypernet = Hypernetwork(
|
||||||
|
name=name,
|
||||||
|
enable_sizes=[int(x) for x in enable_sizes],
|
||||||
|
layer_structure=layer_structure,
|
||||||
|
activation_func=activation_func,
|
||||||
|
weight_init=weight_init,
|
||||||
|
add_layer_norm=add_layer_norm,
|
||||||
|
use_dropout=use_dropout,
|
||||||
|
dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure),
|
||||||
|
optional_info=optional_info,
|
||||||
|
generation_seed=weight_init_seed if weight_init_seed != -1 else None,
|
||||||
|
normal_std=normal_std
|
||||||
|
)
|
||||||
|
hypernet.save(fn)
|
||||||
|
|
||||||
|
load_hypernetwork(name)
|
||||||
|
|
||||||
|
return hypernet
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
||||||
weight_init_seed=None, normal_std=0.01):
|
weight_init_seed=None, normal_std=0.01):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from modules.call_queue import wrap_gradio_call
|
||||||
from modules.hypernetworks.ui import keys
|
from modules.hypernetworks.ui import keys
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
from modules import script_callbacks, shared, sd_hijack
|
from modules import script_callbacks, shared, sd_hijack
|
||||||
|
|
@ -108,8 +109,8 @@ def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
|
||||||
def create_extension_tab(params=None):
|
def create_extension_tab(params=None):
|
||||||
with gr.Tab(label="Create Beta hypernetwork") as create_beta:
|
with gr.Tab(label="Create Beta hypernetwork") as create_beta:
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"],
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1024", "1280"],
|
||||||
choices=["768", "320", "640", "1280"])
|
choices=["768", "320", "640", "1024", "1280"])
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure",
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure",
|
||||||
placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear",
|
new_hypernetwork_activation_func = gr.Dropdown(value="linear",
|
||||||
|
|
@ -143,8 +144,27 @@ def create_extension_tab(params=None):
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
|
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
|
||||||
|
save_setting = gr.Button(value="Save hypernetwork setting to file")
|
||||||
|
setting_name = gr.Textbox(label="Setting file name", value="")
|
||||||
ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
|
ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
|
||||||
ti_outcome = gr.HTML(elem_id="ti_error2", value="")
|
ti_outcome = gr.HTML(elem_id="ti_error2", value="")
|
||||||
|
|
||||||
|
save_setting.click(
|
||||||
|
fn=wrap_gradio_call(external_patch_ui.save_hypernetwork_setting),
|
||||||
|
inputs=[
|
||||||
|
new_hypernetwork_sizes,
|
||||||
|
overwrite_old_hypernetwork,
|
||||||
|
new_hypernetwork_layer_structure,
|
||||||
|
new_hypernetwork_activation_func,
|
||||||
|
new_hypernetwork_initialization_option,
|
||||||
|
new_hypernetwork_add_layer_norm,
|
||||||
|
new_hypernetwork_use_dropout,
|
||||||
|
new_hypernetwork_dropout_structure,
|
||||||
|
optional_info,
|
||||||
|
generation_seed if generation_seed.visible else None,
|
||||||
|
normal_std if normal_std.visible else 0.01],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
create_hypernetwork.click(
|
create_hypernetwork.click(
|
||||||
fn=ui.create_hypernetwork,
|
fn=ui.create_hypernetwork,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue