Merge pull request #37 from aria1th/latest-fix

Latest fix
beta-dadaptation
AngelBottomless 2023-01-21 21:55:46 +09:00 committed by GitHub
commit 3dbb5abae2
9 changed files with 196 additions and 458 deletions

View File

@@ -39,8 +39,12 @@ training_scheduler = Scheduler(cycle_step=-1, repeat=False)
 def get_current(value, step=None):
     if step is None:
-        if hasattr(shared.loaded_hypernetwork, 'step') and shared.loaded_hypernetwork.training and shared.loaded_hypernetwork.step is not None:
-            return training_scheduler(value, shared.loaded_hypernetwork.step)
+        if hasattr(shared, 'accessible_hypernetwork'):
+            hypernetwork = shared.accessible_hypernetwork
+        else:
+            return value
+        if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None:
+            return training_scheduler(value, hypernetwork.step)
         return value
     return max(1, training_scheduler(value, step))

View File

@@ -25,6 +25,18 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
 from ..ddpm_hijack import set_scheduler
 
+def set_accessible(obj):
+    setattr(shared, 'accessible_hypernetwork', obj)
+    if hasattr(shared, 'loaded_hypernetworks'):
+        shared.loaded_hypernetworks.clear()
+        shared.loaded_hypernetworks = [obj,]
+
+
+def remove_accessible():
+    delattr(shared, 'accessible_hypernetwork')
+    if hasattr(shared, 'loaded_hypernetworks'):
+        shared.loaded_hypernetworks.clear()
+
+
 def get_training_option(filename):
     print(filename)
     if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
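
The two helpers above, together with the patched get_current in the first file, form a publish/consume handshake: training code announces the active hypernetwork under a dedicated attribute instead of touching shared.loaded_hypernetwork, which newer webui builds no longer define. A minimal self-contained sketch of that handshake, with a SimpleNamespace standing in for modules.shared and an invented toy hypernetwork object:

from types import SimpleNamespace

shared = SimpleNamespace()  # stand-in for modules.shared

def set_accessible(obj):
    # Publisher side: training announces the active hypernetwork.
    setattr(shared, 'accessible_hypernetwork', obj)

def remove_accessible():
    # Withdraw the reference once training finishes.
    delattr(shared, 'accessible_hypernetwork')

def get_current_step():
    # Consumer side: probe with hasattr instead of assuming the attribute
    # (or the old shared.loaded_hypernetwork global) exists.
    if hasattr(shared, 'accessible_hypernetwork'):
        return shared.accessible_hypernetwork.step
    return None

hn = SimpleNamespace(step=42)  # toy hypernetwork object
set_accessible(hn)
assert get_current_step() == 42
remove_accessible()
assert get_current_step() is None
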
@@ -43,6 +55,40 @@ def get_training_option(filename):
     return obj
 
+def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, **adamW_kwarg_dict):
+    """Returns the hypernetwork object bound with its optimizer."""
+    hypernetwork = load_hypernetwork(hypernetwork_name)
+    assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
+    hypernetwork.to(devices.device)
+    if not isinstance(hypernetwork, Hypernetwork):
+        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
+    set_accessible(hypernetwork)
+    weights = hypernetwork.weights(True)
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
+    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    # Use the optimizer saved with the hypernetwork, or the one specified as a UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        if use_adamw_parameter:
+            if hypernetwork.optimizer_name != 'AdamW':
+                raise RuntimeError(f"Cannot use AdamW parameters for optimizer {hypernetwork.optimizer_name}!")
+            optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
+        else:
+            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
+        optimizer_name = 'AdamW'
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if the optimizer type can differ from the saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+            optim_to(optimizer, devices.device)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
+    return hypernetwork, optimizer, weights, optimizer_name
+
+
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
                        training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
                        create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
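
The prepare_training_hypernetwork factory above centralizes the load/validate/optimizer-restore sequence that both training loops previously inlined. The core of it is the optimizer fallback-and-resume logic; here is a self-contained illustration of just that logic, using a plain parameter list in place of hypernetwork weights and invented checkpoint values:

import torch

optimizer_dict = {'AdamW': torch.optim.AdamW, 'SGD': torch.optim.SGD}

weights = [torch.nn.Parameter(torch.zeros(4))]   # stand-in for hypernetwork.weights(True)
saved_name, saved_state = 'AdamW', None          # as read from a .pt checkpoint (invented)
learn_rate = 1e-5

if saved_name in optimizer_dict:
    optimizer = optimizer_dict[saved_name](params=weights, lr=learn_rate)
    optimizer_name = saved_name
else:
    print(f"Optimizer type {saved_name} is not defined!")
    optimizer = torch.optim.AdamW(params=weights, lr=learn_rate)
    optimizer_name = 'AdamW'

if saved_state:  # resume only when the checkpoint actually stored optimizer state
    try:
        optimizer.load_state_dict(saved_state)
    except RuntimeError as e:
        print("Cannot resume from saved optimizer!")
        print(e)
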
@@ -176,15 +222,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                           template_file, steps, save_hypernetwork_every, create_image_every,
                           log_directory, name="hypernetwork")
-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
-        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
 
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
+    tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0)
+    hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, **adamW_kwarg_dict)
 
     hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -204,7 +246,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     else:
         images_dir = None
 
-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()
 
     initial_step = hypernetwork.step or 0
@@ -251,29 +292,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
-    weights = hypernetwork.weights(True)
-
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if hypernetwork.optimizer_name in optimizer_dict:
-        if use_adamw_parameter:
-            if hypernetwork.optimizer_name != 'AdamW':
-                raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
-            optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
-        else:
-            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-        optimizer_name = hypernetwork.optimizer_name
-    else:
-        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
-        optimizer_name = 'AdamW'
-
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-            optim_to(optimizer, devices.device)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
     if use_beta_scheduler:
         scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                        cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@@ -421,7 +439,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             )
 
             if preview_from_txt2img:
-                p.prompt = preview_prompt
+                p.prompt = preview_prompt + hypernetwork.extra_name()
+                print(p.prompt)
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
                 p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
@@ -430,7 +449,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 p.width = preview_width
                 p.height = preview_height
             else:
-                p.prompt = batch.cond_text[0]
+                p.prompt = batch.cond_text[0] + hypernetwork.extra_name()
                 p.steps = 20
                 p.width = training_width
                 p.height = training_height
@@ -465,6 +484,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                                                   forced_filename=forced_filename,
                                                   save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
+                set_accessible(hypernetwork)
 
         shared.state.job_no = hypernetwork.step
@@ -487,6 +507,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
     set_scheduler(-1, False, False)
+    remove_accessible()
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
@@ -536,11 +557,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                                    dropout_structure, optional_info, weight_init_seed, normal_std,
                                    skip_connection)
     else:
-        load_hypernetwork(hypernetwork_name)
+        hypernetwork = load_hypernetwork(hypernetwork_name)
         hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
-        shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
+        hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
         shared.reload_hypernetworks()
-        load_hypernetwork(hypernetwork_name)
+        hypernetwork = load_hypernetwork(hypernetwork_name)
     if load_training_options != '':
         dump: dict = get_training_option(load_training_options)
         if dump and dump is not None:
@@ -660,11 +681,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                           template_file, steps, save_hypernetwork_every, create_image_every,
                           log_directory, name="hypernetwork")
-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
+    assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
+    hypernetwork.to(devices.device)
+    if not isinstance(hypernetwork, Hypernetwork):
         raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
+    set_accessible(hypernetwork)
 
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -687,7 +708,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     else:
         images_dir = None
 
-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()
 
     initial_step = hypernetwork.step or 0
@@ -711,7 +731,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     shared.sd_model.first_stage_model.requires_grad_(False)
     torch.cuda.empty_cache()
     pin_memory = shared.opts.pin_memory
-
     ds = PersonalizedBase(data_root=data_root, width=training_width,
                           height=training_height,
                           repeats=shared.opts.training_image_repeats_per_epoch,
@@ -755,10 +774,10 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
         try:
             optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+            optim_to(optimizer, devices.device)
         except RuntimeError as e:
             print("Cannot resume from saved optimizer!")
             print(e)
-    optim_to(optimizer, devices.device)
     if use_beta_scheduler:
         scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                        cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@@ -895,7 +914,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             hypernetwork.eval()
             if move_optimizer:
                 optim_to(optimizer, devices.cpu)
-            gc.collect()
             shared.sd_model.cond_stage_model.to(devices.device)
             shared.sd_model.first_stage_model.to(devices.device)
@@ -906,7 +924,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             )
 
             if preview_from_txt2img:
-                p.prompt = preview_prompt
+                p.prompt = preview_prompt + hypernetwork.extra_name()
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
                 p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
@@ -915,7 +933,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 p.width = preview_width
                 p.height = preview_height
             else:
-                p.prompt = batch.cond_text[0]
+                p.prompt = batch.cond_text[0] + hypernetwork.extra_name()
                 p.steps = 20
                 p.width = training_width
                 p.height = training_height
@@ -950,6 +968,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                                                   forced_filename=forced_filename,
                                                   save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
+                set_accessible(hypernetwork)
 
         shared.state.job_no = hypernetwork.step
@@ -970,6 +989,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     hypernetwork.eval()
     set_scheduler(-1, False, False)
     shared.parallel_processing_allowed = old_parallel_processing_allowed
+    remove_accessible()
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
     if shared.opts.training_enable_tensorboard:

View File

@@ -5,7 +5,6 @@ import random
 
 from modules import shared, sd_hijack, devices
 from modules.call_queue import wrap_gradio_call
-from modules.hypernetworks.ui import keys
 from modules.paths import script_path
 from modules.ui import create_refresh_button, gr_show
 from webui import wrap_gradio_gpu_call
@@ -15,8 +14,11 @@ import gradio as gr
 
 def train_hypernetwork_ui(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
+    initial_hypernetwork = None
+    if hasattr(shared, 'loaded_hypernetwork'):
+        initial_hypernetwork = shared.loaded_hypernetwork
+    else:
+        shared.loaded_hypernetworks = []
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
 
     try:
@@ -32,14 +34,21 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
+        if hasattr(shared, 'loaded_hypernetwork'):
+            shared.loaded_hypernetwork = initial_hypernetwork
+        else:
+            shared.loaded_hypernetworks = []
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()
 
 
 def train_hypernetwork_ui_tuning(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
+    initial_hypernetwork = None
+    if hasattr(shared, 'loaded_hypernetwork'):
+        initial_hypernetwork = shared.loaded_hypernetwork
+    else:
+        shared.loaded_hypernetworks = []
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -55,7 +64,10 @@ Training {'interrupted' if shared.state.interrupted else 'finished'}.
     except Exception:
         raise
     finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
+        if hasattr(shared, 'loaded_hypernetwork'):
+            shared.loaded_hypernetwork = initial_hypernetwork
+        else:
+            shared.loaded_hypernetworks = []
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()

View File

@@ -1,5 +1,14 @@
 import torch
+import modules.shared
+
+
+def find_self(self):
+    for k, v in modules.shared.hypernetworks.items():
+        if v == self:
+            return k
+    return None
+
 
 def optim_to(optim: torch.optim.Optimizer, device="cpu"):
     def inplace_move(obj: torch.Tensor, target):

View File

@@ -1,15 +1,10 @@
-import datetime
 import glob
-import html
 import inspect
 import os
 import sys
 import traceback
-from collections import defaultdict, deque
-from statistics import stdev, mean
 
 import torch
-import tqdm
 from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_
 
 import scripts.xy_grid
@@ -20,16 +15,10 @@ except (ImportError, ModuleNotFoundError):
     print("modules.hashes is not found, will use backup module from extension!")
     from .hashes_backup import sha256
 
-from .scheduler import CosineAnnealingWarmUpRestarts
 import modules.hypernetworks.hypernetwork
-from modules import devices, shared, sd_models, processing, sd_samplers, generation_parameters_copypaste
-from .hnutil import parse_dropout_structure, optim_to
-from modules.hypernetworks.hypernetwork import report_statistics, save_hypernetwork, stack_conds, optimizer_dict
-from modules.textual_inversion import textual_inversion
-from .dataset import PersonalizedBase
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules import devices, shared, sd_models, processing, generation_parameters_copypaste
+from .hnutil import parse_dropout_structure, find_self
+from .shared import version_flag
 
 def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
     w, b = layer.weight.data, layer.bias.data
@@ -217,10 +206,10 @@ class HypernetworkModule(torch.nn.Module):
             resnet_result = self.linear(x)
             residual = resnet_result - x
             if multiplier is None or not isinstance(multiplier, (int, float)):
-                multiplier = HypernetworkModule.multiplier
+                multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
             return x + multiplier * residual  # interpolate
         if multiplier is None or not isinstance(multiplier, (int, float)):
-            return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+            return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
         return x + self.linear(x) * multiplier
 
     def trainables(self, train=False):
@@ -317,6 +306,14 @@ class Hypernetwork:
         sha256v = sha256(self.filename, f'hypernet/{self.name}')
         return sha256v[0:10]
 
+    def extra_name(self):
+        if version_flag:
+            return ""
+        found = find_self(self)
+        if found is not None:
+            return f" <hypernet:{found}:1.0>"
+        return f" <hypernet:{self.name}:1.0>"
+
     def save(self, filename):
         state_dict = {}
         optimizer_saved_dict = {}
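
The extra_name method added above is what keeps training previews working on newer webui builds: rather than globally activating the hypernetwork, the extension appends webui's <hypernet:name:strength> prompt syntax to the preview prompt, so the sampler applies it per image. A small sketch of the tag it builds (the prompt is invented; in the real code, find_self(self) supplies the registry key):

def extra_name(name, registry_key=None):
    # Prefer the registry key when a reverse lookup found one; otherwise
    # fall back to the hypernetwork's own name, at full strength (1.0).
    if registry_key is not None:
        return f" <hypernet:{registry_key}:1.0>"
    return f" <hypernet:{name}:1.0>"

prompt = "a painting of a fox" + extra_name("anime_style")
assert prompt == "a painting of a fox <hypernet:anime_style:1.0>"
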
@@ -412,9 +409,18 @@ class Hypernetwork:
         self.eval()
 
     def to(self, device):
-        for values in self.layers.values():
-            values[0].to(device)
-            values[1].to(device)
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.to(device)
+        return self
+
+    def set_multiplier(self, multiplier):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.multiplier = multiplier
+        return self
 
     def __call__(self, context, *args, **kwargs):
         return self.forward(context, *args, **kwargs)
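
to now returns self so calls can be chained, and set_multiplier exists because newer webui builds scale each loaded hypernetwork individually (driven by the strength in the <hypernet:...> tag) rather than through the class-level HypernetworkModule.multiplier. A toy sketch of the per-instance propagation (the Layer stub and layer sizes are invented):

class Layer:
    multiplier = 1.0  # class-level default, as on HypernetworkModule

layers = {320: [Layer(), Layer()], 768: [Layer(), Layer()]}  # invented sizes

def set_multiplier(layers, multiplier):
    # Write the strength onto every module instance, so each loaded
    # hypernetwork can carry its own value independently.
    for k, layer_pair in layers.items():
        for layer in layer_pair:
            layer.multiplier = multiplier

set_multiplier(layers, 0.65)
assert all(l.multiplier == 0.65 for pair in layers.values() for l in pair)
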
@@ -436,9 +442,13 @@ def list_hypernetworks(path):
     res = {}
     for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
+        idx = 0
+        while name in res:
+            idx += 1
+            name = name + f"({idx})"
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name + f"({sd_models.model_hash(filename)})"] = filename
+            res[name] = filename
     for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
         name = os.path.splitext(os.path.basename(filename))[0]
         if name != "None":
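
Dropping the checkpoint-hash suffix from display names means two files with the same basename in different subdirectories would now collide, hence the numeric-suffix loop added above. A quick illustration of the keys it produces (file paths invented):

import os

res = {}
for filename in ["hn/style.pt", "hn/sub/style.pt", "hn/other.pt"]:  # invented paths
    name = os.path.splitext(os.path.basename(filename))[0]
    idx = 0
    while name in res:
        idx += 1
        name = name + f"({idx})"
    res[name] = filename

assert sorted(res) == ["other", "style", "style(1)"]
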
@@ -451,7 +461,10 @@ def find_closest_first(keyset, target):
             return keys
     return None
 
+
 def load_hypernetwork(filename):
+    hypernetwork = None
     path = shared.hypernetworks.get(filename, None)
     if path is None:
         filename = find_closest_first(shared.hypernetworks.keys(), filename)
@@ -462,8 +475,12 @@ def load_hypernetwork(filename):
         print(f"Loading hypernetwork {filename}")
         if path.endswith(".pt"):
             try:
-                shared.loaded_hypernetwork = Hypernetwork()
-                shared.loaded_hypernetwork.load(path)
+                hypernetwork = Hypernetwork()
+                hypernetwork.load(path)
+                if hasattr(shared, 'loaded_hypernetwork'):
+                    shared.loaded_hypernetwork = hypernetwork
+                else:
+                    return hypernetwork
             except Exception:
                 print(f"Error loading hypernetwork {path}", file=sys.stderr)
@@ -472,18 +489,23 @@ def load_hypernetwork(filename):
             # Load Hypernetwork processing
             try:
                 from .hypernetworks import load as load_hns
-                shared.loaded_hypernetwork = load_hns(path)
-                print(f"Loaded Hypernetwork Structure {path}")
+                if hasattr(shared, 'loaded_hypernetwork'):
+                    shared.loaded_hypernetwork = load_hns(path)
+                else:
+                    hypernetwork = load_hns(path)
+                print(f"Loaded Hypernetwork Structure {path}")
+                return hypernetwork
             except Exception:
                 print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
         else:
             print(f"Tried to load unknown file extension: {filename}")
     else:
-        if shared.loaded_hypernetwork is not None:
-            print(f"Unloading hypernetwork")
-        shared.loaded_hypernetwork = None
+        if hasattr(shared, 'loaded_hypernetwork'):
+            if shared.loaded_hypernetwork is not None:
+                print(f"Unloading hypernetwork")
+            shared.loaded_hypernetwork = None
+    return hypernetwork
 
 
 def apply_hypernetwork(hypernetwork, context, layer=None):
@@ -504,266 +526,32 @@ def apply_hypernetwork(hypernetwork, context, layer=None):
     return context_k, context_v
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
-                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
-                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
-                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
-                       use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, gamma_rate=1):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
-    try:
-        if use_beta_scheduler:
-            print("Using Beta Scheduler")
-            beta_repeat_epoch = int(beta_repeat_epoch)
-            assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
-            min_lr = float(min_lr)
-            assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
-            gamma_rate = float(gamma_rate)
-            print(f"Using learn rate decay(per cycle) of {gamma_rate}")
-            assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
-            epoch_mult = int(float(epoch_mult))
-            assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
-            warmup = int(warmup)
-            assert warmup >= 1, "Warmup epoch should be larger than 0!"
-        else:
-            beta_repeat_epoch = 4000
-            epoch_mult = 1
-            warmup = 10
-            min_lr = 1e-7
-            gamma_rate = 1
-    except ValueError:
-        raise RuntimeError("Cannot use advanced LR scheduler settings!")
-    save_hypernetwork_every = save_hypernetwork_every or 0
-    create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, 1, template_file, steps,
-                                            save_hypernetwork_every, create_image_every, log_directory,
-                                            name="hypernetwork")
-
-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
-        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
-
-    shared.state.textinfo = "Initializing hypernetwork training..."
-    shared.state.job_count = steps
-    losses_list = []
-    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
-    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
-    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
-    unload = shared.opts.unload_models_when_training
-
-    if save_hypernetwork_every > 0:
-        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
-        os.makedirs(hypernetwork_dir, exist_ok=True)
-    else:
-        hypernetwork_dir = None
-
-    if create_image_every > 0:
-        images_dir = os.path.join(log_directory, "images")
-        os.makedirs(images_dir, exist_ok=True)
-    else:
-        images_dir = None
-
-    hypernetwork = shared.loaded_hypernetwork
-    checkpoint = sd_models.select_checkpoint()
-
-    ititial_step = hypernetwork.step or 0
-    if ititial_step >= steps:
-        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
-        return hypernetwork, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
-    # dataset loading may take a while, so input validations and early returns should be done before this
-    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = PersonalizedBase(data_root=data_root, width=training_width,
-                              height=training_height,
-                              repeats=shared.opts.training_image_repeats_per_epoch,
-                              placeholder_token=hypernetwork_name,
-                              model=shared.sd_model, device=devices.device,
-                              template_file=template_file, include_cond=True,
-                              batch_size=batch_size)
-
-    if unload:
-        shared.sd_model.cond_stage_model.to(devices.cpu)
-        shared.sd_model.first_stage_model.to(devices.cpu)
-
-    size = len(ds.indexes)
-    loss_dict = defaultdict(lambda: deque(maxlen=1024))
-    losses = torch.zeros((size,))
-    previous_mean_losses = [0]
-    previous_mean_loss = 0
-    print("Mean loss of {} elements".format(size))
-
-    weights = hypernetwork.weights(True)
-
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if hypernetwork.optimizer_name in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-        optimizer_name = hypernetwork.optimizer_name
-    else:
-        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
-        optimizer_name = 'AdamW'
-
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
-
-    scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
-    scheduler_beta.last_epoch = hypernetwork.step - 1
-
-    steps_without_grad = 0
-
-    last_saved_file = "<none>"
-    last_saved_image = "<none>"
-    forced_filename = "<none>"
-
-    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entries in pbar:
-        hypernetwork.step = i + ititial_step
-        if use_beta_scheduler:
-            scheduler_beta.step(hypernetwork.step)
-        if len(loss_dict) > 0:
-            previous_mean_losses = [i[-1] for i in loss_dict.values()]
-            previous_mean_loss = mean(previous_mean_losses)
-        if not use_beta_scheduler:
-            scheduler.apply(optimizer, hypernetwork.step)
-        if i + ititial_step > steps:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss_infos = shared.sd_model(x, c)[1]
-            loss = loss_infos[
-                'val/loss_simple']  # + loss_infos['val/loss_vlb'] * 0.4  # its 'prior class preserving' loss
-            del x
-            del c
-
-            losses[hypernetwork.step % losses.shape[0]] = loss.item()
-            losses_list.append(loss.item())
-            for entry in entries:
-                loss_dict[entry.filename].append(loss.item())
-
-            optimizer.zero_grad()
-            weights[0].grad = None
-            loss.backward()
-
-            if weights[0].grad is None:
-                steps_without_grad += 1
-            else:
-                steps_without_grad = 0
-            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
-            optimizer.step()
-
-        steps_done = hypernetwork.step + 1
-
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
-            raise RuntimeError("Loss diverged.")
-
-        if len(previous_mean_losses) > 1:
-            std = stdev(previous_mean_losses)
-        else:
-            std = 0
-        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
-        pbar.set_description(dataset_loss_info)
-
-        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
-            # Before saving, change name to match current checkpoint.
-            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
-            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
-
-        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-            "loss": f"{previous_mean_loss:.7f}",
-            "learn_rate": optimizer.param_groups[0]['lr']
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{hypernetwork_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-
-            optimizer.zero_grad()
-            optim_to(optimizer, devices.cpu)
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
-
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images) > 0 else None
-
-            if unload:
-                shared.sd_model.cond_stage_model.to(devices.cpu)
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            if image is not None:
-                shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
-                                                                     shared.opts.samples_format, processed.infotexts[0],
-                                                                     p=p, forced_filename=forced_filename,
-                                                                     save_to_dirs=False)
-                last_saved_image += f", prompt: {preview_text}"
-
-            optim_to(optimizer, devices.device)
-
-        shared.state.job_no = hypernetwork.step
-
-        shared.state.textinfo = f"""
-<p>
-Loss: {previous_mean_loss:.7f}<br/>
-Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
-Last saved image: {html.escape(last_saved_image)}<br/>
-</p>
-"""
-
-    report_statistics(loss_dict)
-
-    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-    hypernetwork.optimizer_name = optimizer_name
-    if shared.opts.save_optimizer_state:
-        hypernetwork.optimizer_state_dict = optimizer.state_dict()
-    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-    del optimizer
-    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
-    hypernetwork.eval()
-
-    return hypernetwork, filename
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+    if hypernetwork is None:
+        return context_k, context_v
+    if isinstance(hypernetwork, Hypernetwork):
+        hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
+        if hypernetwork_layers is None:
+            return context_k, context_v
+        if layer is not None:
+            layer.hyper_k = hypernetwork_layers[0]
+            layer.hyper_v = hypernetwork_layers[1]
+        context_k = hypernetwork_layers[0](context_k)
+        context_v = hypernetwork_layers[1](context_v)
+        return context_k, context_v
+    context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
+    return context_k, context_v
 
 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
 
 def apply_hypernetwork_strength(p, x, xs):
     apply_strength(x)
 
 def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
     index = position_in_batch + iteration * p.batch_size
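
The replacement hook above tracks an upstream API change: newer webui builds call the hypernetwork with the key and value contexts as two separate tensors, where the old apply_hypernetwork took a single context and returned the (k, v) pair itself. A self-contained sketch of the new call shape, with torch.nn.Identity modules standing in for the real hypernetwork layer pair and an invented tensor shape:

import torch

# Identity modules stand in for the real hypernetwork layer pair.
hypernetwork_layers = (torch.nn.Identity(), torch.nn.Identity())

def apply_single(layers, context_k, context_v):
    # Key and value contexts arrive (and return) as separate tensors
    # under the newer webui hook signature.
    if layers is None:
        return context_k, context_v
    return layers[0](context_k), layers[1](context_v)

k = torch.zeros(2, 77, 768)  # (batch, tokens, cross-attention dim), invented shape
v = torch.zeros(2, 77, 768)
k2, v2 = apply_single(hypernetwork_layers, k, v)
assert k2.shape == k.shape and v2.shape == v.shape
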
@@ -778,9 +566,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'name') else shared.loaded_hypernetwork.name),
-        "Hypernet hash": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'filename') else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
-        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
         "Batch size": (None if p.batch_size < 2 else p.batch_size),
         "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -801,13 +586,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
 
 modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
 modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
-modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
-modules.hypernetworks.hypernetwork.apply_strength = apply_strength
+if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
+    modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
+else:
+    modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
+if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
+    modules.hypernetworks.hypernetwork.apply_strength = apply_strength
 modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
 modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule
-scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
+if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
+    scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
 # Fix calculating hash for multiple hns
 processing.create_infotext = create_infotext

View File

@@ -4,6 +4,8 @@ import os.path
 import torch
 
 from modules import devices, shared
+from .hnutil import find_self
+from .shared import version_flag
 
 lazy_load = False  # when this is enabled, HNs will be loaded when required.
@@ -79,6 +81,17 @@ class Forward:
     def __call__(self, *args, **kwargs):
         raise NotImplementedError
 
+    def set_multiplier(self, *args, **kwargs):
+        pass
+
+    def extra_name(self):
+        if version_flag:
+            return ""
+        found = find_self(self)
+        if found is not None:
+            return f" <hypernet:{found}:1.0>"
+        return f" <hypernet:{self.name}:1.0>"
+
     @staticmethod
     def parse(arg, name=None):
         arg = Forward.unpack(arg)

View File

@@ -2,11 +2,13 @@
 from modules.shared import cmd_opts, opts
 import modules.shared
 
+version_flag = hasattr(modules.shared, 'loaded_hypernetwork')
+
 
 def reload_hypernetworks():
     from .hypernetwork import list_hypernetworks, load_hypernetwork
     modules.shared.hypernetworks = list_hypernetworks(cmd_opts.hypernetwork_dir)
-    load_hypernetwork(opts.sd_hypernetwork)
+    if hasattr(modules.shared, 'loaded_hypernetwork'):
+        load_hypernetwork(opts.sd_hypernetwork)
 
 try:
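
version_flag, added above, is the extension's single point of version detection: legacy webui builds expose a loaded_hypernetwork singleton on modules.shared, current ones keep a loaded_hypernetworks list instead, so one hasattr probe at import time selects the code path everywhere else. A toy sketch of the probe (the two module stand-ins are invented):

from types import SimpleNamespace

old_shared = SimpleNamespace(loaded_hypernetwork=None)  # legacy webui
new_shared = SimpleNamespace(loaded_hypernetworks=[])   # current webui

def detect_version_flag(shared_module):
    # True means "legacy single-hypernetwork webui".
    return hasattr(shared_module, 'loaded_hypernetwork')

assert detect_version_flag(old_shared) is True
assert detect_version_flag(new_shared) is False
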

View File

@@ -1,8 +1,8 @@
-import html
 import os
 
-from modules import shared, sd_hijack, devices
-from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
+from modules import shared
+from .hypernetwork import Hypernetwork, load_hypernetwork
 
 def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
                              weight_init_seed=None, normal_std=0.01, skip_connection=False):
@@ -36,8 +36,8 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
     )
     hypernet.save(fn)
     shared.reload_hypernetworks()
-    load_hypernetwork(fn)
+    hypernet = load_hypernetwork(name)
+    assert hypernet is not None, f"Cannot load from {name}!"
     return hypernet
@@ -76,27 +76,3 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     shared.reload_hypernetworks()
     return name, f"Created: {fn}", ""
-
-
-def train_hypernetwork_ui(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
-    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
-    try:
-        sd_hijack.undo_optimizations()
-        hypernetwork, filename = train_hypernetwork(*args)
-        res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
-Hypernetwork saved to {html.escape(filename)}
-"""
-        return res, ""
-    except Exception:
-        raise
-    finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
-        shared.sd_model.cond_stage_model.to(devices.device)
-        shared.sd_model.first_stage_model.to(devices.device)
-        sd_hijack.apply_optimizations()

View File

@@ -17,95 +17,6 @@ from webui import wrap_gradio_gpu_call
 setattr(shared.opts, 'pin_memory', False)
 
-def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
-    with gr.Tab(label="Train_Beta") as train_beta:
-        gr.HTML(
-            value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
-        with gr.Row():
-            train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
-                                                  choices=[x for x in shared.hypernetworks.keys()])
-            create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
-                                  lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
-                                  "refresh_train_hypernetwork_name")
-        with gr.Row():
-            hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
-                                                 placeholder="Hypernetwork Learning rate", value="0.00001")
-        use_beta_scheduler_checkbox = gr.Checkbox(
-            label='Show advanced learn rate scheduler options(for Hypernetworks)')
-        with gr.Row(visible=False) as beta_scheduler_options:
-            use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
-            beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
-            epoch_mult = gr.Textbox(label='Epoch multiplier per cycle', placeholder="Cycles length multiplier every cycle", value="1")
-            warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", value="1")
-            min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
-                                placeholder="restricts decay value, but does not restrict gamma rate decay",
-                                value="1e-7")
-            gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
-                                    placeholder="Value should be in (0-1]", value="1")
-        batch_size = gr.Number(label='Batch size', value=1, precision=0)
-        dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
-        log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
-                                   value="textual_inversion")
-        template_file = gr.Textbox(label='Prompt template file',
-                                   value=os.path.join(script_path, "textual_inversion_templates",
-                                                      "style_filewords.txt"))
-        training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-        training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
-        steps = gr.Number(label='Max steps', value=100000, precision=0)
-        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500,
-                                       precision=0)
-        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable',
-                                         value=500, precision=0)
-        preview_from_txt2img = gr.Checkbox(
-            label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
-        with gr.Row():
-            interrupt_training = gr.Button(value="Interrupt")
-            train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
-        ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
-        ti_outcome = gr.HTML(elem_id="ti_error2", value="")
-        use_beta_scheduler_checkbox.change(
-            fn=lambda show: gr_show(show),
-            inputs=[use_beta_scheduler_checkbox],
-            outputs=[beta_scheduler_options],
-        )
-        interrupt_training.click(
-            fn=lambda: shared.state.interrupt(),
-            inputs=[],
-            outputs=[],
-        )
-        train_hypernetwork.click(
-            fn=wrap_gradio_gpu_call(ui.train_hypernetwork_ui, extra_outputs=[gr.update()]),
-            _js="start_training_textual_inversion",
-            inputs=[
-                train_hypernetwork_name,
-                hypernetwork_learn_rate,
-                batch_size,
-                dataset_directory,
-                log_directory,
-                training_width,
-                training_height,
-                steps,
-                create_image_every,
-                save_embedding_every,
-                template_file,
-                preview_from_txt2img,
-                *params.txt2img_preview_params,
-                use_beta_scheduler,
-                beta_repeat_epoch,
-                epoch_mult,
-                warmup,
-                min_lr,
-                gamma_rate
-            ],
-            outputs=[
-                ti_output,
-                ti_outcome,
-            ]
-        )
-    return [(train_beta, "Train_beta", "train_beta")]
-
 def create_extension_tab(params=None):
     with gr.Tab(label="Create Beta hypernetwork") as create_beta:
         new_hypernetwork_name = gr.Textbox(label="Name")