commit
3dbb5abae2
|
|
@ -39,8 +39,12 @@ training_scheduler = Scheduler(cycle_step=-1, repeat=False)
|
|||
|
||||
def get_current(value, step=None):
    """Return ``value`` as adjusted by ``training_scheduler``.

    With an explicit ``step`` the scheduled value is clamped to a minimum of 1.
    Without one, the step is read from a hypernetwork on ``shared`` that is in
    training mode and has a non-None ``step``; if there is none, ``value`` is
    returned unchanged.
    """
    if step is not None:
        return max(1, training_scheduler(value, step))
    # Both attributes are optional on ``shared``. getattr() avoids the
    # AttributeError that a bare ``shared.loaded_hypernetwork`` lookup raises
    # when the attribute does not exist.
    for attr in ('loaded_hypernetwork', 'accessible_hypernetwork'):
        hypernetwork = getattr(shared, attr, None)
        if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None:
            return training_scheduler(value, hypernetwork.step)
    return value
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,18 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
|
|||
from ..ddpm_hijack import set_scheduler
|
||||
|
||||
|
||||
def set_accessible(obj):
    """Publish ``obj`` as ``shared.accessible_hypernetwork``.

    If ``shared`` has a ``loaded_hypernetworks`` list, that list is emptied in
    place and the attribute is then rebound to a new one-element list holding
    ``obj``.
    """
    shared.accessible_hypernetwork = obj
    if not hasattr(shared, 'loaded_hypernetworks'):
        return
    shared.loaded_hypernetworks.clear()
    shared.loaded_hypernetworks = [obj]
|
||||
|
||||
|
||||
def remove_accessible():
    """Remove ``shared.accessible_hypernetwork`` and empty ``shared.loaded_hypernetworks``.

    Safe to call when nothing is currently published (for example when
    cleanup runs twice): a missing attribute is skipped instead of raising
    AttributeError.
    """
    if hasattr(shared, 'accessible_hypernetwork'):
        delattr(shared, 'accessible_hypernetwork')
    if hasattr(shared, 'loaded_hypernetworks'):
        shared.loaded_hypernetworks.clear()
|
||||
|
||||
def get_training_option(filename):
|
||||
print(filename)
|
||||
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
|
||||
|
|
@ -43,6 +55,40 @@ def get_training_option(filename):
|
|||
return obj
|
||||
|
||||
|
||||
def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, **adamW_kwarg_dict):
    """Load a hypernetwork for training and build an optimizer over its weights.

    Args:
        hypernetwork_name: name handed to ``load_hypernetwork``.
        learn_rate: initial learning rate passed to the optimizer.
        use_adamw_parameter: when true, ``adamW_kwarg_dict`` is forwarded to
            ``torch.optim.AdamW``; only allowed if the hypernetwork's saved
            optimizer name is ``'AdamW'``.
        **adamW_kwarg_dict: extra keyword arguments for ``torch.optim.AdamW``.

    Returns:
        Tuple ``(hypernetwork, optimizer, weights, optimizer_name)``.

    Raises:
        AssertionError: if the hypernetwork could not be loaded.
        RuntimeError: if the loaded object is not a ``Hypernetwork``, or AdamW
            parameters were requested for a different optimizer type.
    """
    hypernetwork = load_hypernetwork(hypernetwork_name)
    # Validate before touching the object: calling .to() on a failed load would
    # raise an unrelated AttributeError and hide the real problem.
    if hypernetwork is None:
        raise AssertionError(f"Cannot load {hypernetwork_name}!")
    if not isinstance(hypernetwork, Hypernetwork):
        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
    hypernetwork.to(devices.device)
    set_accessible(hypernetwork)
    weights = hypernetwork.weights(True)

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        if use_adamw_parameter:
            if hypernetwork.optimizer_name != 'AdamW':
                raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
            optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
        else:
            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        # Unknown optimizer name: fall back to AdamW.
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
        optimizer_name = 'AdamW'

    # This condition must be changed if Optimizer type can be different from saved optimizer.
    if hypernetwork.optimizer_state_dict:
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
            optim_to(optimizer, devices.device)
        except (RuntimeError, ValueError) as e:
            # Resuming is best-effort. load_state_dict reports mismatched
            # parameter groups with ValueError, so that is tolerated as well.
            print("Cannot resume from saved optimizer!")
            print(e)
    return hypernetwork, optimizer, weights, optimizer_name
|
||||
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
|
||||
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
|
||||
create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
|
||||
|
|
@ -176,15 +222,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
|
||||
template_file, steps, save_hypernetwork_every, create_image_every,
|
||||
log_directory, name="hypernetwork")
|
||||
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
|
||||
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
|
||||
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0)
|
||||
hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, **adamW_kwarg_dict)
|
||||
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
|
@ -204,7 +246,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
|
|
@ -251,29 +292,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
weights = hypernetwork.weights(True)
|
||||
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
if use_adamw_parameter:
|
||||
if hypernetwork.optimizer_name != 'AdamW':
|
||||
raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
|
||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
|
||||
else:
|
||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
else:
|
||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
|
||||
optimizer_name = 'AdamW'
|
||||
|
||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
optim_to(optimizer, devices.device)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
if use_beta_scheduler:
|
||||
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
|
||||
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
|
||||
|
|
@ -421,7 +439,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.prompt = preview_prompt + hypernetwork.extra_name()
|
||||
print(p.prompt)
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
|
|
@ -430,7 +449,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
|
@ -465,6 +484,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
forced_filename=forced_filename,
|
||||
save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
set_accessible(hypernetwork)
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
|
|
@ -487,6 +507,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
set_scheduler(-1, False, False)
|
||||
remove_accessible()
|
||||
report_statistics(loss_dict)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
|
|
@ -536,11 +557,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
dropout_structure, optional_info, weight_init_seed, normal_std,
|
||||
skip_connection)
|
||||
else:
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
hypernetwork = load_hypernetwork(hypernetwork_name)
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
|
||||
shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
|
||||
hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
|
||||
shared.reload_hypernetworks()
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
hypernetwork = load_hypernetwork(hypernetwork_name)
|
||||
if load_training_options != '':
|
||||
dump: dict = get_training_option(load_training_options)
|
||||
if dump and dump is not None:
|
||||
|
|
@ -660,11 +681,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
|
||||
template_file, steps, save_hypernetwork_every, create_image_every,
|
||||
log_directory, name="hypernetwork")
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
|
||||
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
|
||||
hypernetwork.to(devices.device)
|
||||
assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
|
||||
if not isinstance(hypernetwork, Hypernetwork):
|
||||
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
|
||||
|
||||
set_accessible(hypernetwork)
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
|
|
@ -687,7 +708,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
|
|
@ -711,7 +731,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
shared.sd_model.first_stage_model.requires_grad_(False)
|
||||
torch.cuda.empty_cache()
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||
height=training_height,
|
||||
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||
|
|
@ -755,10 +774,10 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
optim_to(optimizer, devices.device)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
optim_to(optimizer, devices.device)
|
||||
if use_beta_scheduler:
|
||||
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
|
||||
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
|
||||
|
|
@ -895,7 +914,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
hypernetwork.eval()
|
||||
if move_optimizer:
|
||||
optim_to(optimizer, devices.cpu)
|
||||
gc.collect()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
|
|
@ -906,7 +924,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.prompt = preview_prompt + hypernetwork.extra_name()
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
|
|
@ -915,7 +933,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
|
@ -950,6 +968,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
forced_filename=forced_filename,
|
||||
save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
set_accessible(hypernetwork)
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
|
|
@ -970,6 +989,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
hypernetwork.eval()
|
||||
set_scheduler(-1, False, False)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
remove_accessible()
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import random
|
|||
|
||||
from modules import shared, sd_hijack, devices
|
||||
from modules.call_queue import wrap_gradio_call
|
||||
from modules.hypernetworks.ui import keys
|
||||
from modules.paths import script_path
|
||||
from modules.ui import create_refresh_button, gr_show
|
||||
from webui import wrap_gradio_gpu_call
|
||||
|
|
@ -15,8 +14,11 @@ import gradio as gr
|
|||
|
||||
|
||||
def train_hypernetwork_ui(*args):
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
|
||||
initial_hypernetwork = None
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
else:
|
||||
shared.loaded_hypernetworks = []
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
||||
try:
|
||||
|
|
@ -32,14 +34,21 @@ Hypernetwork saved to {html.escape(filename)}
|
|||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
else:
|
||||
shared.loaded_hypernetworks = []
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
|
||||
def train_hypernetwork_ui_tuning(*args):
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
initial_hypernetwork = None
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
else:
|
||||
shared.loaded_hypernetworks = []
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
||||
|
|
@ -55,7 +64,10 @@ Training {'interrupted' if shared.state.interrupted else 'finished'}.
|
|||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
else:
|
||||
shared.loaded_hypernetworks = []
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
sd_hijack.apply_optimizations()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,14 @@
|
|||
import torch
|
||||
|
||||
import modules.shared
|
||||
|
||||
|
||||
def find_self(self):
    """Return the key under which ``self`` is registered in ``modules.shared.hypernetworks``.

    An entry matches when its value equals ``self`` or equals ``self.filename``.
    The registry is filled with name -> file path entries (see
    ``list_hypernetworks``), so comparing against the object alone would never
    match a loaded hypernetwork.

    Returns:
        The matching key, or ``None`` when no entry matches.
    """
    filename = getattr(self, 'filename', None)
    for key, registered in modules.shared.hypernetworks.items():
        if registered == self or (filename is not None and registered == filename):
            return key
    return None
|
||||
|
||||
|
||||
def optim_to(optim:torch.optim.Optimizer, device="cpu"):
|
||||
def inplace_move(obj: torch.Tensor, target):
|
||||
|
|
|
|||
|
|
@ -1,15 +1,10 @@
|
|||
import datetime
|
||||
import glob
|
||||
import html
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import defaultdict, deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_
|
||||
|
||||
import scripts.xy_grid
|
||||
|
|
@ -20,16 +15,10 @@ except (ImportError, ModuleNotFoundError):
|
|||
print("modules.hashes is not found, will use backup module from extension!")
|
||||
from .hashes_backup import sha256
|
||||
|
||||
from .scheduler import CosineAnnealingWarmUpRestarts
|
||||
|
||||
import modules.hypernetworks.hypernetwork
|
||||
from modules import devices, shared, sd_models, processing, sd_samplers, generation_parameters_copypaste
|
||||
from .hnutil import parse_dropout_structure, optim_to
|
||||
from modules.hypernetworks.hypernetwork import report_statistics, save_hypernetwork, stack_conds, optimizer_dict
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from .dataset import PersonalizedBase
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
from modules import devices, shared, sd_models, processing, generation_parameters_copypaste
|
||||
from .hnutil import parse_dropout_structure, find_self
|
||||
from .shared import version_flag
|
||||
|
||||
def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
|
|
@ -217,10 +206,10 @@ class HypernetworkModule(torch.nn.Module):
|
|||
resnet_result = self.linear(x)
|
||||
residual = resnet_result - x
|
||||
if multiplier is None or not isinstance(multiplier, (int, float)):
|
||||
multiplier = HypernetworkModule.multiplier
|
||||
return x + multiplier * residual # interpolate
|
||||
multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
|
||||
return x + multiplier * residual # interpolate
|
||||
if multiplier is None or not isinstance(multiplier, (int, float)):
|
||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
||||
return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
|
||||
return x + self.linear(x) * multiplier
|
||||
|
||||
def trainables(self, train=False):
|
||||
|
|
@ -317,6 +306,14 @@ class Hypernetwork:
|
|||
sha256v = sha256(self.filename, f'hypernet/{self.name}')
|
||||
return sha256v[0:10]
|
||||
|
||||
def extra_name(self):
    """Return the ``<hypernet:...>`` prompt suffix for this hypernetwork.

    Returns an empty string when ``version_flag`` is truthy. Otherwise the
    tag uses the key reported by ``find_self`` and falls back to
    ``self.name`` when no key is found; the strength is fixed at ``1.0``.
    """
    if version_flag:
        return ""
    key = find_self(self)
    tag_name = self.name if key is None else key
    return f" <hypernet:{tag_name}:1.0>"
|
||||
|
||||
def save(self, filename):
|
||||
state_dict = {}
|
||||
optimizer_saved_dict = {}
|
||||
|
|
@ -412,9 +409,18 @@ class Hypernetwork:
|
|||
self.eval()
|
||||
|
||||
def to(self, device):
    """Move every layer held in ``self.layers`` to ``device``.

    Each layer is moved exactly once, in the iteration order of
    ``self.layers``.

    Returns:
        ``self``, so calls can be chained.
    """
    for layers in self.layers.values():
        for layer in layers:
            layer.to(device)

    return self
|
||||
|
||||
def set_multiplier(self, multiplier):
    """Assign ``multiplier`` to the ``multiplier`` attribute of every layer in ``self.layers``.

    Returns:
        ``self``, so calls can be chained.
    """
    # Only the layer groups are needed, so iterate values rather than items.
    for layers in self.layers.values():
        for layer in layers:
            layer.multiplier = multiplier

    return self
|
||||
|
||||
def __call__(self, context, *args, **kwargs):
|
||||
return self.forward(context, *args, **kwargs)
|
||||
|
|
@ -436,9 +442,13 @@ def list_hypernetworks(path):
|
|||
res = {}
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
idx = 0
|
||||
while name in res:
|
||||
idx += 1
|
||||
name = name + f"({idx})"
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name+ f"({sd_models.model_hash(filename)})"] = filename
|
||||
res[name] = filename
|
||||
for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
if name != "None":
|
||||
|
|
@ -451,7 +461,10 @@ def find_closest_first(keyset, target):
|
|||
return keys
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
hypernetwork = None
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
if path is None:
|
||||
filename = find_closest_first(shared.hypernetworks.keys(), filename)
|
||||
|
|
@ -462,8 +475,12 @@ def load_hypernetwork(filename):
|
|||
print(f"Loading hypernetwork {filename}")
|
||||
if path.endswith(".pt"):
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
hypernetwork = Hypernetwork()
|
||||
hypernetwork.load(path)
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
shared.loaded_hypernetwork = hypernetwork
|
||||
else:
|
||||
return hypernetwork
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
|
|
@ -472,18 +489,23 @@ def load_hypernetwork(filename):
|
|||
# Load Hypernetwork processing
|
||||
try:
|
||||
from .hypernetworks import load as load_hns
|
||||
shared.loaded_hypernetwork = load_hns(path)
|
||||
print(f"Loaded Hypernetwork Structure {path}")
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
shared.loaded_hypernetwork = load_hns(path)
|
||||
else:
|
||||
hypernetwork = load_hns(path)
|
||||
print(f"Loaded Hypernetwork Structure {path}")
|
||||
return hypernetwork
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
print(f"Tried to load unknown file extension: {filename}")
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print(f"Unloading hypernetwork")
|
||||
|
||||
shared.loaded_hypernetwork = None
|
||||
if hasattr(shared, 'loaded_hypernetwork'):
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print(f"Unloading hypernetwork")
|
||||
shared.loaded_hypernetwork = None
|
||||
return hypernetwork
|
||||
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
|
|
@ -504,266 +526,32 @@ def apply_hypernetwork(hypernetwork, context, layer=None):
|
|||
return context_k, context_v
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
|
||||
training_height, steps, create_image_every, save_hypernetwork_every, template_file,
|
||||
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
|
||||
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||
use_beta_scheduler=False, beta_repeat_epoch=4000,epoch_mult=1, warmup =10, min_lr=1e-7, gamma_rate=1):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
try:
|
||||
if use_beta_scheduler:
|
||||
print("Using Beta Scheduler")
|
||||
beta_repeat_epoch = int(beta_repeat_epoch)
|
||||
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
|
||||
min_lr = float(min_lr)
|
||||
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
|
||||
gamma_rate = float(gamma_rate)
|
||||
print(f"Using learn rate decay(per cycle) of {gamma_rate}")
|
||||
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
|
||||
epoch_mult = int(float(epoch_mult))
|
||||
assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
|
||||
warmup = int(warmup)
|
||||
assert warmup >= 1, "Warmup epoch should be larger than 0!"
|
||||
else:
|
||||
beta_repeat_epoch = 4000
|
||||
epoch_mult=1
|
||||
warmup=10
|
||||
min_lr=1e-7
|
||||
gamma_rate=1
|
||||
except ValueError:
|
||||
raise RuntimeError("Cannot use advanced LR scheduler settings!")
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, 1, template_file, steps,
|
||||
save_hypernetwork_every, create_image_every, log_directory,
|
||||
name="hypernetwork")
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    """Apply one hypernetwork to a key/value context pair.

    Args:
        hypernetwork: ``None``, a ``Hypernetwork`` instance, or another
            callable taking ``(context_k, context_v, layer=...)``.
        context_k: key context; ``context_k.shape[2]`` selects the layer pair
            of a ``Hypernetwork``.
        context_v: value context.
        layer: optional object that receives the selected modules as
            ``hyper_k`` / ``hyper_v``.

    Returns:
        Tuple ``(context_k, context_v)``. The inputs are returned unchanged
        when ``hypernetwork`` is ``None`` or has no layers registered for
        ``context_k.shape[2]``.
    """
    if hypernetwork is None:
        return context_k, context_v

    if not isinstance(hypernetwork, Hypernetwork):
        # Any other object is called directly and must yield a 2-item result.
        context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
        return context_k, context_v

    hypernetwork_layers = hypernetwork.layers.get(context_k.shape[2], None)
    if hypernetwork_layers is None:
        return context_k, context_v

    hyper_k, hyper_v = hypernetwork_layers[0], hypernetwork_layers[1]
    if layer is not None:
        layer.hyper_k = hyper_k
        layer.hyper_v = hyper_v

    return hyper_k(context_k), hyper_v(context_v)
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||
unload = shared.opts.unload_models_when_training
|
||||
|
||||
if save_hypernetwork_every > 0:
|
||||
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||
os.makedirs(hypernetwork_dir, exist_ok=True)
|
||||
else:
|
||||
hypernetwork_dir = None
|
||||
|
||||
if create_image_every > 0:
|
||||
images_dir = os.path.join(log_directory, "images")
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||
height=training_height,
|
||||
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||
placeholder_token=hypernetwork_name,
|
||||
model=shared.sd_model, device=devices.device,
|
||||
template_file=template_file, include_cond=True,
|
||||
batch_size=batch_size)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
size = len(ds.indexes)
|
||||
loss_dict = defaultdict(lambda: deque(maxlen=1024))
|
||||
losses = torch.zeros((size,))
|
||||
previous_mean_losses = [0]
|
||||
previous_mean_loss = 0
|
||||
print("Mean loss of {} elements".format(size))
|
||||
|
||||
weights = hypernetwork.weights(True)
|
||||
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
else:
|
||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||
optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = 'AdamW'
|
||||
|
||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
|
||||
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
|
||||
scheduler_beta.last_epoch =hypernetwork.step-1
|
||||
steps_without_grad = 0
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
if use_beta_scheduler:
|
||||
scheduler_beta.step(hypernetwork.step)
|
||||
if len(loss_dict) > 0:
|
||||
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
||||
previous_mean_loss = mean(previous_mean_losses)
|
||||
if not use_beta_scheduler:
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if i + ititial_step > steps:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss_infos = shared.sd_model(x, c)[1]
|
||||
loss = loss_infos[
|
||||
'val/loss_simple'] # + loss_infos['val/loss_vlb'] * 0.4 #its 'prior class preserving' loss
|
||||
del x
|
||||
del c
|
||||
|
||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||
losses_list.append(loss.item())
|
||||
for entry in entries:
|
||||
loss_dict[entry.filename].append(loss.item())
|
||||
optimizer.zero_grad()
|
||||
weights[0].grad = None
|
||||
loss.backward()
|
||||
|
||||
if weights[0].grad is None:
|
||||
steps_without_grad += 1
|
||||
else:
|
||||
steps_without_grad = 0
|
||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||
optimizer.step()
|
||||
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
|
||||
if len(previous_mean_losses) > 1:
|
||||
std = stdev(previous_mean_losses)
|
||||
else:
|
||||
std = 0
|
||||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||
pbar.set_description(dataset_loss_info)
|
||||
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{previous_mean_loss:.7f}",
|
||||
"learn_rate": optimizer.param_groups[0]['lr']
|
||||
})
|
||||
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
optimizer.zero_grad()
|
||||
optim_to(optimizer, devices.cpu)
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||
shared.opts.samples_format, processed.infotexts[0],
|
||||
p=p, forced_filename=forced_filename,
|
||||
save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
optim_to(optimizer, devices.device)
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {previous_mean_loss:.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
report_statistics(loss_dict)
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
del optimizer
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
hypernetwork.eval()
|
||||
return hypernetwork, filename
|
||||
|
||||
def apply_strength(value=None):
    """Set ``HypernetworkModule.multiplier``.

    Uses *value* when it is not ``None``; otherwise falls back to
    ``shared.opts.sd_hypernetwork_strength``.
    """
    if value is None:
        value = shared.opts.sd_hypernetwork_strength
    HypernetworkModule.multiplier = value
|
||||
|
||||
|
||||
def apply_hypernetwork_strength(p, x, xs):
    """Forward *x* to ``apply_strength``.

    ``p`` and ``xs`` are accepted but not used by this function.
    """
    apply_strength(x)
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
|
|
@ -778,9 +566,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'name') else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'filename') else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
|
|
@ -801,13 +586,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
||||
# Replace the host module's hypernetwork entry points with the versions
# defined in this file.
modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
# The hasattr checks must run before anything is assigned to the attribute
# they test: assigning first would make each check always true and leave the
# `apply_single_hypernetwork` branch unreachable.
if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
    modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
else:
    modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
    modules.hypernetworks.hypernetwork.apply_strength = apply_strength
modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule
# Only override the xy_grid hook when the host script already defines it.
if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
    scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength

# Fix calculating hash for multiple hns
processing.create_infotext = create_infotext
|
||||
|
|
@ -4,6 +4,8 @@ import os.path
|
|||
import torch
|
||||
|
||||
from modules import devices, shared
|
||||
from .hnutil import find_self
|
||||
from .shared import version_flag
|
||||
|
||||
lazy_load = False # when this is enabled, HNs will be loaded when required.
|
||||
|
||||
|
|
@ -79,6 +81,17 @@ class Forward:
|
|||
def __call__(self, *args, **kwargs):
    """Placeholder call hook; always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
||||
|
||||
def set_multiplier(self, *args, **kwargs):
    """Accept any arguments and do nothing; returns ``None``."""
    return None
|
||||
|
||||
def extra_name(self):
    """Return the ``" <hypernet:NAME:1.0>"`` suffix for this object, or ``""``.

    An empty string is returned when ``version_flag`` is truthy. Otherwise
    NAME is whatever ``find_self(self)`` returns, falling back to
    ``self.name`` when that lookup yields ``None``.
    """
    if version_flag:
        return ""
    label = find_self(self)
    if label is None:
        # self.name is only read when the lookup found nothing.
        label = self.name
    return f" <hypernet:{label}:1.0>"
|
||||
|
||||
@staticmethod
|
||||
def parse(arg, name=None):
|
||||
arg = Forward.unpack(arg)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,13 @@
|
|||
from modules.shared import cmd_opts, opts
|
||||
import modules.shared
|
||||
|
||||
version_flag = hasattr(modules.shared, 'loaded_hypernetwork')
|
||||
|
||||
def reload_hypernetworks():
    """Rescan the hypernetwork directory and reload the selected hypernetwork.

    ``modules.shared.hypernetworks`` is replaced with a fresh listing of
    ``cmd_opts.hypernetwork_dir``. The hypernetwork named by
    ``opts.sd_hypernetwork`` is then loaded once, and only when
    ``modules.shared`` has a ``loaded_hypernetwork`` attribute.
    """
    # Function-scope import — presumably to avoid a circular import with
    # .hypernetwork; confirm before moving it to module level.
    from .hypernetwork import list_hypernetworks, load_hypernetwork
    modules.shared.hypernetworks = list_hypernetworks(cmd_opts.hypernetwork_dir)
    # Load exactly once, behind the guard. An unconditional call ahead of
    # this check would load twice and make the check pointless.
    if hasattr(modules.shared, 'loaded_hypernetwork'):
        load_hypernetwork(opts.sd_hypernetwork)
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import html
|
||||
import os
|
||||
|
||||
from modules import shared, sd_hijack, devices
|
||||
from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
|
||||
from modules import shared
|
||||
from .hypernetwork import Hypernetwork, load_hypernetwork
|
||||
|
||||
|
||||
def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
||||
weight_init_seed=None, normal_std=0.01, skip_connection=False):
|
||||
|
|
@ -36,8 +36,8 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
|
|||
)
|
||||
hypernet.save(fn)
|
||||
shared.reload_hypernetworks()
|
||||
load_hypernetwork(fn)
|
||||
|
||||
hypernet = load_hypernetwork(name)
|
||||
assert hypernet is not None, f"Cannot load from {name}!"
|
||||
return hypernet
|
||||
|
||||
|
||||
|
|
@ -76,27 +76,3 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
shared.reload_hypernetworks()
|
||||
|
||||
return name, f"Created: {fn}", ""
|
||||
|
||||
def train_hypernetwork_ui(*args):
    """Run ``train_hypernetwork(*args)`` and return a status message pair.

    The object held in ``shared.loaded_hypernetwork`` on entry is restored
    afterwards, the cond-stage and first-stage models are moved to
    ``devices.device`` and ``sd_hijack.apply_optimizations()`` is called,
    whether training returned or raised. Exceptions propagate unchanged.

    Returns:
        ``(message, "")`` where *message* reports whether training was
        interrupted or finished, the step count and the saved file name.
    """
    previous = shared.loaded_hypernetwork

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

    try:
        sd_hijack.undo_optimizations()
        trained, saved_path = train_hypernetwork(*args)
        message = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {trained.step} steps.
Hypernetwork saved to {html.escape(saved_path)}
"""
        return message, ""
    finally:
        # Cleanup runs on success and on failure alike.
        shared.loaded_hypernetwork = previous
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()
|
||||
|
|
@ -17,95 +17,6 @@ from webui import wrap_gradio_gpu_call
|
|||
|
||||
setattr(shared.opts,'pin_memory', False)
|
||||
|
||||
|
||||
# Build the "Train_Beta" Gradio tab: hypernetwork picker, learning-rate and
# scheduler inputs, dataset/log/template settings, image size, step counts and
# the Interrupt / Train Hypernetwork buttons.
# Returns a one-element list holding (train_beta, "Train_beta", "train_beta").
# NOTE(review): `params` defaults to None, yet `params.txt2img_preview_params`
# is read when wiring the Train button below, so calling without it would fail.
# NOTE(review): indentation is missing in this copy, so which widgets belong
# to each gr.Row() cannot be read from here — confirm against the real file.
def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
|
||||
with gr.Tab(label="Train_Beta") as train_beta:
|
||||
gr.HTML(
|
||||
value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||
# Hypernetwork dropdown, followed by a refresh button that re-reads shared.hypernetworks.
with gr.Row():
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
|
||||
choices=[x for x in shared.hypernetworks.keys()])
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
|
||||
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
|
||||
"refresh_train_hypernetwork_name")
|
||||
with gr.Row():
|
||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
|
||||
placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||
use_beta_scheduler_checkbox = gr.Checkbox(
|
||||
label='Show advanced learn rate scheduler options(for Hypernetworks)')
|
||||
# Scheduler options row is created with visible=False; see the checkbox handler below.
with gr.Row(visible=False) as beta_scheduler_options:
|
||||
use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
|
||||
beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
|
||||
epoch_mult = gr.Textbox(label='Epoch multiplier per cycle', placeholder="Cycles length multiplier every cycle", value="1")
|
||||
warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", value="1")
|
||||
min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
|
||||
placeholder="restricts decay value, but does not restrict gamma rate decay",
|
||||
value="1e-7")
|
||||
gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
|
||||
placeholder="Value should be in (0-1]", value="1")
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
|
||||
value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file',
|
||||
value=os.path.join(script_path, "textual_inversion_templates",
|
||||
"style_filewords.txt"))
|
||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500,
|
||||
precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable',
|
||||
value=500, precision=0)
|
||||
preview_from_txt2img = gr.Checkbox(
|
||||
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
|
||||
ti_outcome = gr.HTML(elem_id="ti_error2", value="")
|
||||
# Route the checkbox value through gr_show into beta_scheduler_options.
use_beta_scheduler_checkbox.change(
|
||||
fn=lambda show: gr_show(show),
|
||||
inputs=[use_beta_scheduler_checkbox],
|
||||
outputs=[beta_scheduler_options],
|
||||
)
|
||||
# Interrupt button calls shared.state.interrupt(); it takes no inputs and updates no outputs.
interrupt_training.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
# Train button passes the inputs below, in this order, to ui.train_hypernetwork_ui
# (wrapped by wrap_gradio_gpu_call) and writes the result to ti_output / ti_outcome.
train_hypernetwork.click(
|
||||
fn=wrap_gradio_gpu_call(ui.train_hypernetwork_ui, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_hypernetwork_name,
|
||||
hypernetwork_learn_rate,
|
||||
batch_size,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
preview_from_txt2img,
|
||||
*params.txt2img_preview_params,
|
||||
use_beta_scheduler,
|
||||
beta_repeat_epoch,
|
||||
epoch_mult,
|
||||
warmup,
|
||||
min_lr,
|
||||
gamma_rate
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
return [(train_beta, "Train_beta", "train_beta")]
|
||||
|
||||
def create_extension_tab(params=None):
|
||||
with gr.Tab(label="Create Beta hypernetwork") as create_beta:
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
|
|
|
|||
Loading…
Reference in New Issue