Merge pull request #37 from aria1th/latest-fix

Latest fix
beta-dadaptation
AngelBottomless 2023-01-21 21:55:46 +09:00 committed by GitHub
commit 3dbb5abae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 458 deletions

View File

@ -39,8 +39,12 @@ training_scheduler = Scheduler(cycle_step=-1, repeat=False)
def get_current(value, step=None):
# Resolve the scheduled value for `value`.
# With an explicit `step` the result is max(1, training_scheduler(value, step)).
# Without one, the step is taken from a hypernetwork that is currently training;
# `value` is returned unchanged when no such hypernetwork is found.
if step is None:
# NOTE(review): this lookup reads shared.loaded_hypernetwork without a
# hasattr(shared, ...) guard, unlike the accessible_hypernetwork lookup below —
# confirm the attribute always exists where this line runs.
if hasattr(shared.loaded_hypernetwork, 'step') and shared.loaded_hypernetwork.training and shared.loaded_hypernetwork.step is not None:
return training_scheduler(value, shared.loaded_hypernetwork.step)
# Fall back to the object published under shared.accessible_hypernetwork, if any.
if hasattr(shared, 'accessible_hypernetwork'):
hypernetwork = shared.accessible_hypernetwork
else:
return value
# Only consult the scheduler when the object has a step, is training, and the step is set.
if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None:
return training_scheduler(value, hypernetwork.step)
return value
# Explicit step: the scheduled result is clamped to be at least 1.
return max(1, training_scheduler(value, step))

View File

@ -25,6 +25,18 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
from ..ddpm_hijack import set_scheduler
def set_accessible(obj):
# Publish `obj` as shared.accessible_hypernetwork.
setattr(shared, 'accessible_hypernetwork', obj)
# shared.loaded_hypernetworks is only touched when that attribute exists on `shared`.
if hasattr(shared, 'loaded_hypernetworks'):
# The existing list is emptied in place first, then the attribute is rebound
# to a new one-element list holding `obj`.
shared.loaded_hypernetworks.clear()
shared.loaded_hypernetworks = [obj,]
def remove_accessible():
    """Drop the hypernetwork reference published on `shared`.

    Removes ``shared.accessible_hypernetwork`` and empties
    ``shared.loaded_hypernetworks`` when those attributes exist.

    The call is idempotent: a bare ``delattr`` would raise AttributeError when
    nothing is published (cleanup after an early failure, or a second cleanup
    call) and could mask the error that actually ended training.
    """
    if hasattr(shared, 'accessible_hypernetwork'):
        delattr(shared, 'accessible_hypernetwork')
    if hasattr(shared, 'loaded_hypernetworks'):
        shared.loaded_hypernetworks.clear()
def get_training_option(filename):
print(filename)
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
@ -43,6 +55,40 @@ def get_training_option(filename):
return obj
def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, **adamW_kwarg_dict):
    """Load a hypernetwork and build the optimizer bound to its weights.

    Args:
        hypernetwork_name: Name passed to ``load_hypernetwork``.
        learn_rate: Learning rate handed to the optimizer constructor.
        use_adamw_parameter: When true, build ``torch.optim.AdamW`` with
            ``adamW_kwarg_dict``; only allowed if the saved optimizer name
            is ``'AdamW'``.
        **adamW_kwarg_dict: Extra keyword arguments for ``torch.optim.AdamW``.

    Returns:
        Tuple ``(hypernetwork, optimizer, weights, optimizer_name)``.

    Raises:
        AssertionError: If ``load_hypernetwork`` returned ``None``.
        RuntimeError: If the loaded object is not a ``Hypernetwork``, or AdamW
            parameters were requested for a different saved optimizer.
    """
    hypernetwork = load_hypernetwork(hypernetwork_name)
    # Validate before touching the object: calling .to() on a failed load (None)
    # would raise AttributeError and hide the messages below.
    assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
    if not isinstance(hypernetwork, Hypernetwork):
        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
    hypernetwork.to(devices.device)
    set_accessible(hypernetwork)
    weights = hypernetwork.weights(True)
    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        if use_adamw_parameter:
            if hypernetwork.optimizer_name != 'AdamW':
                raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
            optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
        else:
            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        # Unknown optimizer name: fall back to AdamW.
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
        optimizer_name = 'AdamW'
    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
            optim_to(optimizer, devices.device)
        except RuntimeError as e:
            # Best effort: training continues with a fresh optimizer state.
            print("Cannot resume from saved optimizer!")
            print(e)
    return hypernetwork, optimizer, weights, optimizer_name
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
@ -176,15 +222,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
load_hypernetwork(hypernetwork_name)
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0)
hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, **adamW_kwarg_dict)
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@ -204,7 +246,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
@ -251,29 +292,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
weights = hypernetwork.weights(True)
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
if use_adamw_parameter:
if hypernetwork.optimizer_name != 'AdamW':
raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
else:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
optim_to(optimizer, devices.device)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
if use_beta_scheduler:
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@ -421,7 +439,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.prompt = preview_prompt + hypernetwork.extra_name()
print(p.prompt)
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
@ -430,7 +449,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
p.steps = 20
p.width = training_width
p.height = training_height
@ -465,6 +484,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
set_accessible(hypernetwork)
shared.state.job_no = hypernetwork.step
@ -487,6 +507,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
set_scheduler(-1, False, False)
remove_accessible()
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
@ -536,11 +557,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
dropout_structure, optional_info, weight_init_seed, normal_std,
skip_connection)
else:
load_hypernetwork(hypernetwork_name)
hypernetwork = load_hypernetwork(hypernetwork_name)
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
shared.reload_hypernetworks()
load_hypernetwork(hypernetwork_name)
hypernetwork = load_hypernetwork(hypernetwork_name)
if load_training_options != '':
dump: dict = get_training_option(load_training_options)
if dump and dump is not None:
@ -660,11 +681,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
load_hypernetwork(hypernetwork_name)
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
hypernetwork.to(devices.device)
assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(hypernetwork, Hypernetwork):
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
set_accessible(hypernetwork)
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
@ -687,7 +708,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
@ -711,7 +731,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
shared.sd_model.first_stage_model.requires_grad_(False)
torch.cuda.empty_cache()
pin_memory = shared.opts.pin_memory
ds = PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
@ -755,10 +774,10 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
optim_to(optimizer, devices.device)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
optim_to(optimizer, devices.device)
if use_beta_scheduler:
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@ -895,7 +914,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
hypernetwork.eval()
if move_optimizer:
optim_to(optimizer, devices.cpu)
gc.collect()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
@ -906,7 +924,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.prompt = preview_prompt + hypernetwork.extra_name()
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
@ -915,7 +933,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
p.steps = 20
p.width = training_width
p.height = training_height
@ -950,6 +968,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
set_accessible(hypernetwork)
shared.state.job_no = hypernetwork.step
@ -970,6 +989,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
hypernetwork.eval()
set_scheduler(-1, False, False)
shared.parallel_processing_allowed = old_parallel_processing_allowed
remove_accessible()
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
if shared.opts.training_enable_tensorboard:

View File

@ -5,7 +5,6 @@ import random
from modules import shared, sd_hijack, devices
from modules.call_queue import wrap_gradio_call
from modules.hypernetworks.ui import keys
from modules.paths import script_path
from modules.ui import create_refresh_button, gr_show
from webui import wrap_gradio_gpu_call
@ -15,8 +14,11 @@ import gradio as gr
def train_hypernetwork_ui(*args):
initial_hypernetwork = shared.loaded_hypernetwork
initial_hypernetwork = None
if hasattr(shared, 'loaded_hypernetwork'):
initial_hypernetwork = shared.loaded_hypernetwork
else:
shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
try:
@ -32,14 +34,21 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = initial_hypernetwork
else:
shared.loaded_hypernetworks = []
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
def train_hypernetwork_ui_tuning(*args):
initial_hypernetwork = shared.loaded_hypernetwork
initial_hypernetwork = None
if hasattr(shared, 'loaded_hypernetwork'):
initial_hypernetwork = shared.loaded_hypernetwork
else:
shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@ -55,7 +64,10 @@ Training {'interrupted' if shared.state.interrupted else 'finished'}.
except Exception:
raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = initial_hypernetwork
else:
shared.loaded_hypernetworks = []
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()

View File

@ -1,5 +1,14 @@
import torch
import modules.shared
def find_self(self):
    """Return the first key of ``modules.shared.hypernetworks`` whose value equals `self`.

    Returns ``None`` when no value compares equal.
    """
    registry = modules.shared.hypernetworks
    return next((name for name, entry in registry.items() if entry == self), None)
def optim_to(optim:torch.optim.Optimizer, device="cpu"):
def inplace_move(obj: torch.Tensor, target):

View File

@ -1,15 +1,10 @@
import datetime
import glob
import html
import inspect
import os
import sys
import traceback
from collections import defaultdict, deque
from statistics import stdev, mean
import torch
import tqdm
from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_
import scripts.xy_grid
@ -20,16 +15,10 @@ except (ImportError, ModuleNotFoundError):
print("modules.hashes is not found, will use backup module from extension!")
from .hashes_backup import sha256
from .scheduler import CosineAnnealingWarmUpRestarts
import modules.hypernetworks.hypernetwork
from modules import devices, shared, sd_models, processing, sd_samplers, generation_parameters_copypaste
from .hnutil import parse_dropout_structure, optim_to
from modules.hypernetworks.hypernetwork import report_statistics, save_hypernetwork, stack_conds, optimizer_dict
from modules.textual_inversion import textual_inversion
from .dataset import PersonalizedBase
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules import devices, shared, sd_models, processing, generation_parameters_copypaste
from .hnutil import parse_dropout_structure, find_self
from .shared import version_flag
def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
w, b = layer.weight.data, layer.bias.data
@ -217,10 +206,10 @@ class HypernetworkModule(torch.nn.Module):
resnet_result = self.linear(x)
residual = resnet_result - x
if multiplier is None or not isinstance(multiplier, (int, float)):
multiplier = HypernetworkModule.multiplier
return x + multiplier * residual # interpolate
multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
return x + multiplier * residual # interpolate
if multiplier is None or not isinstance(multiplier, (int, float)):
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
return x + self.linear(x) * multiplier
def trainables(self, train=False):
@ -317,6 +306,14 @@ class Hypernetwork:
sha256v = sha256(self.filename, f'hypernet/{self.name}')
return sha256v[0:10]
def extra_name(self):
    """Return the `` <hypernet:NAME:1.0>`` suffix for this hypernetwork.

    Returns an empty string when ``version_flag`` is set. Otherwise NAME is the
    key reported by ``find_self``, falling back to ``self.name`` when that
    lookup returns ``None``.
    """
    if version_flag:
        return ""
    registered = find_self(self)
    label = self.name if registered is None else registered
    return f" <hypernet:{label}:1.0>"
def save(self, filename):
state_dict = {}
optimizer_saved_dict = {}
@ -412,9 +409,18 @@ class Hypernetwork:
self.eval()
def to(self, device):
# Call .to(device) on the layer objects stored in self.layers, then return self.
for values in self.layers.values():
values[0].to(device)
values[1].to(device)
# NOTE(review): this second pass calls .to(device) on every entry of each group,
# which includes the two entries handled by the loop above — the first loop
# looks redundant; confirm before removing.
for k, layers in self.layers.items():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
    """Assign `multiplier` to the ``multiplier`` attribute of every layer; return self."""
    for layer_group in self.layers.values():
        for module in layer_group:
            module.multiplier = multiplier
    return self
def __call__(self, context, *args, **kwargs):
    """Delegate the call to ``self.forward`` with the same arguments."""
    result = self.forward(context, *args, **kwargs)
    return result
@ -436,9 +442,13 @@ def list_hypernetworks(path):
res = {}
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
idx = 0
while name in res:
idx += 1
name = name + f"({idx})"
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name+ f"({sd_models.model_hash(filename)})"] = filename
res[name] = filename
for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
if name != "None":
@ -451,7 +461,10 @@ def find_closest_first(keyset, target):
return keys
return None
def load_hypernetwork(filename):
hypernetwork = None
path = shared.hypernetworks.get(filename, None)
if path is None:
filename = find_closest_first(shared.hypernetworks.keys(), filename)
@ -462,8 +475,12 @@ def load_hypernetwork(filename):
print(f"Loading hypernetwork {filename}")
if path.endswith(".pt"):
try:
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
hypernetwork = Hypernetwork()
hypernetwork.load(path)
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = hypernetwork
else:
return hypernetwork
except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr)
@ -472,18 +489,23 @@ def load_hypernetwork(filename):
# Load Hypernetwork processing
try:
from .hypernetworks import load as load_hns
shared.loaded_hypernetwork = load_hns(path)
print(f"Loaded Hypernetwork Structure {path}")
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = load_hns(path)
else:
hypernetwork = load_hns(path)
print(f"Loaded Hypernetwork Structure {path}")
return hypernetwork
except Exception:
print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
print(f"Tried to load unknown file extension: {filename}")
else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None
if hasattr(shared, 'loaded_hypernetwork'):
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None
return hypernetwork
def apply_hypernetwork(hypernetwork, context, layer=None):
@ -504,266 +526,32 @@ def apply_hypernetwork(hypernetwork, context, layer=None):
return context_k, context_v
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
training_height, steps, create_image_every, save_hypernetwork_every, template_file,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
use_beta_scheduler=False, beta_repeat_epoch=4000,epoch_mult=1, warmup =10, min_lr=1e-7, gamma_rate=1):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
try:
if use_beta_scheduler:
print("Using Beta Scheduler")
beta_repeat_epoch = int(beta_repeat_epoch)
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
min_lr = float(min_lr)
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
gamma_rate = float(gamma_rate)
print(f"Using learn rate decay(per cycle) of {gamma_rate}")
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
epoch_mult = int(float(epoch_mult))
assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
warmup = int(warmup)
assert warmup >= 1, "Warmup epoch should be larger than 0!"
else:
beta_repeat_epoch = 4000
epoch_mult=1
warmup=10
min_lr=1e-7
gamma_rate=1
except ValueError:
raise RuntimeError("Cannot use advanced LR scheduler settings!")
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, 1, template_file, steps,
save_hypernetwork_every, create_image_every, log_directory,
name="hypernetwork")
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
if hypernetwork is None:
return context_k, context_v
if isinstance(hypernetwork, Hypernetwork):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
load_hypernetwork(hypernetwork_name)
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
losses_list = []
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
context_k = hypernetwork_layers[0](context_k)
context_v = hypernetwork_layers[1](context_v)
return context_k, context_v
context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
return context_k, context_v
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
os.makedirs(hypernetwork_dir, exist_ok=True)
else:
hypernetwork_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=hypernetwork_name,
model=shared.sd_model, device=devices.device,
template_file=template_file, include_cond=True,
batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
loss_dict = defaultdict(lambda: deque(maxlen=1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights(True)
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
scheduler_beta.last_epoch =hypernetwork.step-1
steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
if use_beta_scheduler:
scheduler_beta.step(hypernetwork.step)
if len(loss_dict) > 0:
previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses)
if not use_beta_scheduler:
scheduler.apply(optimizer, hypernetwork.step)
if i + ititial_step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss_infos = shared.sd_model(x, c)[1]
loss = loss_infos[
'val/loss_simple'] # + loss_infos['val/loss_vlb'] * 0.4 #its 'prior class preserving' loss
del x
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
losses_list.append(loss.item())
for entry in entries:
loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
if weights[0].grad is None:
steps_without_grad += 1
else:
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
optimizer.step()
steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": optimizer.param_groups[0]['lr']
})
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
optim_to(optimizer, devices.cpu)
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
shared.opts.samples_format, processed.infotexts[0],
p=p, forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
optim_to(optimizer, devices.device)
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
hypernetwork.eval()
return hypernetwork, filename
def apply_strength(value=None):
    """Set the class-level ``HypernetworkModule.multiplier``.

    Uses `value` when it is not ``None``, otherwise
    ``shared.opts.sd_hypernetwork_strength``.
    """
    if value is None:
        value = shared.opts.sd_hypernetwork_strength
    HypernetworkModule.multiplier = value
def apply_hypernetwork_strength(p, x, xs):
    """Forward `x` to ``apply_strength``; `p` and `xs` are accepted but not used."""
    apply_strength(value=x)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
@ -778,9 +566,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'name') else shared.loaded_hypernetwork.name),
"Hypernet hash": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'filename') else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@ -801,13 +586,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
# Replace attributes of the webui's modules.hypernetworks.hypernetwork module
# with the implementations defined in this file.
# NOTE(review): the unconditional assignments to apply_hypernetwork /
# apply_strength / scripts.xy_grid.apply_hypernetwork_strength below are each
# followed by a hasattr-guarded assignment to the same attribute. This looks
# like both sides of a diff shown together - confirm which variant is live.
modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
modules.hypernetworks.hypernetwork.apply_strength = apply_strength
# Guarded variant: patch apply_hypernetwork only when the host module defines
# it; otherwise install apply_single_hypernetwork instead.
if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
else:
modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
# apply_strength is patched only when the host module already has it.
if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
modules.hypernetworks.hypernetwork.apply_strength = apply_strength
# The Hypernetwork and HypernetworkModule classes are always replaced.
modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule
scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
# Guarded variant: patch scripts.xy_grid only when it exposes this hook.
if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
# Fix calculating hash for multiple hns
# processing.create_infotext is replaced with the create_infotext defined in this file.
processing.create_infotext = create_infotext

View File

@ -4,6 +4,8 @@ import os.path
import torch
from modules import devices, shared
from .hnutil import find_self
from .shared import version_flag
lazy_load = False # when this is enabled, HNs will be loaded when required.
@ -79,6 +81,17 @@ class Forward:
# Call hook of this class: this implementation accepts any arguments and
# always raises NotImplementedError.
# NOTE(review): presumably meant to be overridden by subclasses - confirm.
def __call__(self, *args, **kwargs):
raise NotImplementedError
def set_multiplier(self, *args, **kwargs):
    """Do nothing: accept and ignore any positional or keyword arguments.

    Returns:
        None.
    """
    return None
def extra_name(self):
    """Return the ``<hypernet:...>`` tag text for this object.

    Returns:
        ``""`` when ``version_flag`` is truthy. Otherwise
        ``" <hypernet:NAME:1.0>"``, where NAME is the result of
        ``find_self(self)`` or, when that is ``None``, ``self.name``.
    """
    if version_flag:
        return ""
    label = find_self(self)
    if label is None:
        # find_self gave nothing: fall back to this object's own name.
        label = self.name
    return f" <hypernet:{label}:1.0>"
@staticmethod
def parse(arg, name=None):
arg = Forward.unpack(arg)

View File

@ -2,11 +2,13 @@
from modules.shared import cmd_opts, opts
import modules.shared
version_flag = hasattr(modules.shared, 'loaded_hypernetwork')
def reload_hypernetworks():
from .hypernetwork import list_hypernetworks, load_hypernetwork
modules.shared.hypernetworks = list_hypernetworks(cmd_opts.hypernetwork_dir)
load_hypernetwork(opts.sd_hypernetwork)
if hasattr(modules.shared, 'loaded_hypernetwork'):
load_hypernetwork(opts.sd_hypernetwork)
try:

View File

@ -1,8 +1,8 @@
import html
import os
from modules import shared, sd_hijack, devices
from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
from modules import shared
from .hypernetwork import Hypernetwork, load_hypernetwork
def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
weight_init_seed=None, normal_std=0.01, skip_connection=False):
@ -36,8 +36,8 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
)
hypernet.save(fn)
shared.reload_hypernetworks()
load_hypernetwork(fn)
hypernet = load_hypernetwork(name)
assert hypernet is not None, f"Cannot load from {name}!"
return hypernet
@ -76,27 +76,3 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
return name, f"Created: {fn}", ""
def train_hypernetwork_ui(*args):
    """Run ``train_hypernetwork(*args)`` and return a status pair.

    Args:
        *args: Forwarded unchanged to ``train_hypernetwork``.

    Returns:
        A ``(message, "")`` tuple; the message reports whether training was
        interrupted or finished, the step count, and the HTML-escaped path
        of the saved file.

    Raises:
        AssertionError: If ``shared.cmd_opts.lowvram`` is set.
        Exception: Anything raised during training propagates unchanged,
            after the cleanup in ``finally`` has run.
    """
    # Captured before the assert so it can be restored once training ends.
    previous_hypernetwork = shared.loaded_hypernetwork

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

    try:
        sd_hijack.undo_optimizations()
        hypernetwork, filename = train_hypernetwork(*args)
        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
        return res, ""
    finally:
        # Always restore the previously loaded hypernetwork, move the model
        # parts back to the device and re-apply optimizations.
        shared.loaded_hypernetwork = previous_hypernetwork
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()

View File

@ -17,95 +17,6 @@ from webui import wrap_gradio_gpu_call
setattr(shared.opts,'pin_memory', False)
# Build the "Train_Beta" Gradio tab for hypernetwork training.
# Returns a one-element list holding the tuple (train_beta, "Train_beta", "train_beta").
# NOTE(review): `params` defaults to None, but the inputs list below unpacks
# params.txt2img_preview_params unconditionally, so calling this without
# params would raise AttributeError - confirm callers always pass it.
def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
with gr.Tab(label="Train_Beta") as train_beta:
gr.HTML(
value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
# Hypernetwork selector plus a refresh button that reloads the list.
# The initial choices use dict order; the refresh callback returns them sorted.
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
"refresh_train_hypernetwork_name")
# Learning rate textbox and the checkbox that toggles the advanced scheduler row.
with gr.Row():
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
placeholder="Hypernetwork Learning rate", value="0.00001")
use_beta_scheduler_checkbox = gr.Checkbox(
label='Show advanced learn rate scheduler options(for Hypernetworks)')
# Advanced scheduler options; created hidden (visible=False).
with gr.Row(visible=False) as beta_scheduler_options:
use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
epoch_mult = gr.Textbox(label='Epoch multiplier per cycle', placeholder="Cycles length multiplier every cycle", value="1")
warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", value="1")
min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
placeholder="restricts decay value, but does not restrict gamma rate decay",
value="1e-7")
gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
placeholder="Value should be in (0-1]", value="1")
# Dataset, logging, image size and step/checkpoint interval inputs.
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file',
value=os.path.join(script_path, "textual_inversion_templates",
"style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500,
precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable',
value=500, precision=0)
preview_from_txt2img = gr.Checkbox(
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
# Action buttons and the two output components that receive the training result.
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
ti_outcome = gr.HTML(elem_id="ti_error2", value="")
# Toggling the checkbox passes its value through gr_show to the scheduler options row.
use_beta_scheduler_checkbox.change(
fn=lambda show: gr_show(show),
inputs=[use_beta_scheduler_checkbox],
outputs=[beta_scheduler_options],
)
# The Interrupt button calls shared.state.interrupt().
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
# The Train button calls ui.train_hypernetwork_ui through wrap_gradio_gpu_call.
# NOTE(review): the order of `inputs` presumably has to match the positional
# parameters expected by the training function - verify before reordering.
train_hypernetwork.click(
fn=wrap_gradio_gpu_call(ui.train_hypernetwork_ui, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
create_image_every,
save_embedding_every,
template_file,
preview_from_txt2img,
*params.txt2img_preview_params,
use_beta_scheduler,
beta_repeat_epoch,
epoch_mult,
warmup,
min_lr,
gamma_rate
],
outputs=[
ti_output,
ti_outcome,
]
)
return [(train_beta, "Train_beta", "train_beta")]
def create_extension_tab(params=None):
with gr.Tab(label="Create Beta hypernetwork") as create_beta:
new_hypernetwork_name = gr.Textbox(label="Name")