commit 3dbb5abae2
@@ -39,8 +39,12 @@ training_scheduler = Scheduler(cycle_step=-1, repeat=False)
 def get_current(value, step=None):
     if step is None:
-        if hasattr(shared.loaded_hypernetwork, 'step') and shared.loaded_hypernetwork.training and shared.loaded_hypernetwork.step is not None:
-            return training_scheduler(value, shared.loaded_hypernetwork.step)
+        if hasattr(shared, 'accessible_hypernetwork'):
+            hypernetwork = shared.accessible_hypernetwork
+        else:
+            return value
+        if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None:
+            return training_scheduler(value, hypernetwork.step)
         return value
     return max(1, training_scheduler(value, step))

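Note (illustrative, not part of the commit): the scheduler hook above now reads the in-training hypernetwork through a single optional shared.accessible_hypernetwork attribute instead of shared.loaded_hypernetwork. A minimal sketch of the same lookup, with a stand-in shared namespace as an assumption:

    import types
    shared = types.SimpleNamespace()  # stand-in for modules.shared in this sketch

    def current_training_step_or_none():
        # mirrors the guard in get_current(): only report a step while a
        # hypernetwork is attached and actively training
        hn = getattr(shared, 'accessible_hypernetwork', None)
        if hn is not None and getattr(hn, 'training', False) and hn.step is not None:
            return hn.step
        return None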
@@ -25,6 +25,18 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
 from ..ddpm_hijack import set_scheduler


+def set_accessible(obj):
+    setattr(shared, 'accessible_hypernetwork', obj)
+    if hasattr(shared, 'loaded_hypernetworks'):
+        shared.loaded_hypernetworks.clear()
+        shared.loaded_hypernetworks = [obj,]
+
+
+def remove_accessible():
+    delattr(shared, 'accessible_hypernetwork')
+    if hasattr(shared, 'loaded_hypernetworks'):
+        shared.loaded_hypernetworks.clear()
+
+
 def get_training_option(filename):
     print(filename)
     if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
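Note (illustrative): set_accessible/remove_accessible are meant to bracket a training run so the scheduler hook above can find the in-training hypernetwork. A sketch of the intended pattern, with train_one_cycle as a hypothetical placeholder for the training body:

    def run_training(hypernetwork):
        set_accessible(hypernetwork)       # publish to shared.accessible_hypernetwork
        try:
            train_one_cycle(hypernetwork)  # hypothetical training body
        finally:
            remove_accessible()            # always detach, even on error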
@@ -43,6 +55,40 @@ def get_training_option(filename):
     return obj


+def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1,use_adamw_parameter=False, **adamW_kwarg_dict):
+    """ returns hypernetwork object binded with optimizer"""
+    hypernetwork = load_hypernetwork(hypernetwork_name)
+    hypernetwork.to(devices.device)
+    assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
+    if not isinstance(hypernetwork, Hypernetwork):
+        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
+    set_accessible(hypernetwork)
+    weights = hypernetwork.weights(True)
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
+    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        if use_adamw_parameter:
+            if hypernetwork.optimizer_name != 'AdamW':
+                raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
+            optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
+        else:
+            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
+        optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+            optim_to(optimizer, devices.device)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
+    return hypernetwork, optimizer, weights, optimizer_name
+
+
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
                        training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
                        create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
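Note (illustrative): the new helper bundles loading, device placement, and optimizer setup that previously lived inline in train_hypernetwork. A sketch of a call site; the name and argument values are examples only:

    hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(
        "my_hypernet(abcd1234)",    # UI name; an optional "(hash)" suffix is stripped inside
        learn_rate=5e-6,
        use_adamw_parameter=False,  # True requires the saved optimizer type to be AdamW
    )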
@@ -176,15 +222,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                           template_file, steps, save_hypernetwork_every, create_image_every,
                           log_directory, name="hypernetwork")

-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
-        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
-
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
+    tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0)
+    hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, **adamW_kwarg_dict)

     hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

@@ -204,7 +246,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     else:
         images_dir = None

-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()

     initial_step = hypernetwork.step or 0
@@ -251,29 +292,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)

-    weights = hypernetwork.weights(True)
-
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if hypernetwork.optimizer_name in optimizer_dict:
-        if use_adamw_parameter:
-            if hypernetwork.optimizer_name != 'AdamW':
-                raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!")
-            optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
-        else:
-            optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-        optimizer_name = hypernetwork.optimizer_name
-    else:
-        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
-        optimizer_name = 'AdamW'
-
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-            optim_to(optimizer, devices.device)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
     if use_beta_scheduler:
         scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                        cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@@ -421,7 +439,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 )

                 if preview_from_txt2img:
-                    p.prompt = preview_prompt
+                    p.prompt = preview_prompt + hypernetwork.extra_name()
+                    print(p.prompt)
                     p.negative_prompt = preview_negative_prompt
                     p.steps = preview_steps
                     p.sampler_name = sd_samplers.samplers[preview_sampler_index].name

@@ -430,7 +449,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height

@@ -465,6 +484,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                                                          forced_filename=forced_filename,
                                                          save_to_dirs=False)
                     last_saved_image += f", prompt: {preview_text}"
+                set_accessible(hypernetwork)

             shared.state.job_no = hypernetwork.step

@@ -487,6 +507,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
     set_scheduler(-1, False, False)
+    remove_accessible()
     report_statistics(loss_dict)
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
@@ -536,11 +557,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                                             dropout_structure, optional_info, weight_init_seed, normal_std,
                                             skip_connection)
     else:
-        load_hypernetwork(hypernetwork_name)
+        hypernetwork = load_hypernetwork(hypernetwork_name)
         hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
-        shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
+        hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
         shared.reload_hypernetworks()
-        load_hypernetwork(hypernetwork_name)
+        hypernetwork = load_hypernetwork(hypernetwork_name)
     if load_training_options != '':
         dump: dict = get_training_option(load_training_options)
         if dump and dump is not None:
@@ -660,11 +681,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
                           template_file, steps, save_hypernetwork_every, create_image_every,
                           log_directory, name="hypernetwork")
-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
+    hypernetwork.to(devices.device)
+    assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
+    if not isinstance(hypernetwork, Hypernetwork):
         raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
+    set_accessible(hypernetwork)
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -687,7 +708,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     else:
         images_dir = None

-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()

     initial_step = hypernetwork.step or 0

@@ -711,7 +731,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     shared.sd_model.first_stage_model.requires_grad_(False)
     torch.cuda.empty_cache()
     pin_memory = shared.opts.pin_memory
-
     ds = PersonalizedBase(data_root=data_root, width=training_width,
                           height=training_height,
                           repeats=shared.opts.training_image_repeats_per_epoch,

@@ -755,10 +774,10 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
         try:
             optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-            optim_to(optimizer, devices.device)
         except RuntimeError as e:
             print("Cannot resume from saved optimizer!")
             print(e)
+    optim_to(optimizer, devices.device)
     if use_beta_scheduler:
         scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
                                                        cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
@@ -895,7 +914,6 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 hypernetwork.eval()
                 if move_optimizer:
                     optim_to(optimizer, devices.cpu)
-                    gc.collect()
                 shared.sd_model.cond_stage_model.to(devices.device)
                 shared.sd_model.first_stage_model.to(devices.device)

@@ -906,7 +924,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 )

                 if preview_from_txt2img:
-                    p.prompt = preview_prompt
+                    p.prompt = preview_prompt + hypernetwork.extra_name()
                     p.negative_prompt = preview_negative_prompt
                     p.steps = preview_steps
                     p.sampler_name = sd_samplers.samplers[preview_sampler_index].name

@@ -915,7 +933,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = batch.cond_text[0]+ hypernetwork.extra_name()
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -950,6 +968,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                                                          forced_filename=forced_filename,
                                                          save_to_dirs=False)
                     last_saved_image += f", prompt: {preview_text}"
+                set_accessible(hypernetwork)

             shared.state.job_no = hypernetwork.step


@@ -970,6 +989,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     hypernetwork.eval()
     set_scheduler(-1, False, False)
     shared.parallel_processing_allowed = old_parallel_processing_allowed
+    remove_accessible()
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
     if shared.opts.training_enable_tensorboard:
@@ -5,7 +5,6 @@ import random

 from modules import shared, sd_hijack, devices
 from modules.call_queue import wrap_gradio_call
-from modules.hypernetworks.ui import keys
 from modules.paths import script_path
 from modules.ui import create_refresh_button, gr_show
 from webui import wrap_gradio_gpu_call

@@ -15,8 +14,11 @@ import gradio as gr


 def train_hypernetwork_ui(*args):
+    initial_hypernetwork = None
+    if hasattr(shared, 'loaded_hypernetwork'):
         initial_hypernetwork = shared.loaded_hypernetwork
+    else:
+        shared.loaded_hypernetworks = []
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

     try:
@@ -32,14 +34,21 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
+        if hasattr(shared, 'loaded_hypernetwork'):
             shared.loaded_hypernetwork = initial_hypernetwork
+        else:
+            shared.loaded_hypernetworks = []
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()


 def train_hypernetwork_ui_tuning(*args):
+    initial_hypernetwork = None
+    if hasattr(shared, 'loaded_hypernetwork'):
         initial_hypernetwork = shared.loaded_hypernetwork
+    else:
+        shared.loaded_hypernetworks = []

     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'


@@ -55,7 +64,10 @@ Training {'interrupted' if shared.state.interrupted else 'finished'}.
     except Exception:
         raise
     finally:
+        if hasattr(shared, 'loaded_hypernetwork'):
             shared.loaded_hypernetwork = initial_hypernetwork
+        else:
+            shared.loaded_hypernetworks = []
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()
@@ -1,5 +1,14 @@
 import torch

+import modules.shared
+
+
+def find_self(self):
+    for k, v in modules.shared.hypernetworks.items():
+        if v == self:
+            return k
+    return None
+

 def optim_to(optim:torch.optim.Optimizer, device="cpu"):
     def inplace_move(obj: torch.Tensor, target):
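Note (illustrative): find_self reverse-maps a hypernetwork object back to its key in the modules.shared.hypernetworks registry so prompts can reference it by its listed name. A sketch of a call, assuming the registry values compare equal to the object being looked up; otherwise it returns None and callers fall back to the object's own name:

    key = find_self(hypernetwork_obj)  # e.g. "anime_style", or None if not registered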
@@ -1,15 +1,10 @@
-import datetime
 import glob
-import html
 import inspect
 import os
 import sys
 import traceback
-from collections import defaultdict, deque
-from statistics import stdev, mean

 import torch
-import tqdm
 from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_

 import scripts.xy_grid

@@ -20,16 +15,10 @@ except (ImportError, ModuleNotFoundError):
     print("modules.hashes is not found, will use backup module from extension!")
     from .hashes_backup import sha256

-from .scheduler import CosineAnnealingWarmUpRestarts
-
 import modules.hypernetworks.hypernetwork
-from modules import devices, shared, sd_models, processing, sd_samplers, generation_parameters_copypaste
-from .hnutil import parse_dropout_structure, optim_to
-from modules.hypernetworks.hypernetwork import report_statistics, save_hypernetwork, stack_conds, optimizer_dict
-from modules.textual_inversion import textual_inversion
-from .dataset import PersonalizedBase
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules import devices, shared, sd_models, processing, generation_parameters_copypaste
+from .hnutil import parse_dropout_structure, find_self
+from .shared import version_flag


 def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
     w, b = layer.weight.data, layer.bias.data
@@ -217,10 +206,10 @@ class HypernetworkModule(torch.nn.Module):
             resnet_result = self.linear(x)
             residual = resnet_result - x
             if multiplier is None or not isinstance(multiplier, (int, float)):
-                multiplier = HypernetworkModule.multiplier
+                multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
             return x + multiplier * residual  # interpolate
         if multiplier is None or not isinstance(multiplier, (int, float)):
-            return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+            return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
         return x + self.linear(x) * multiplier

     def trainables(self, train=False):
@@ -317,6 +306,14 @@ class Hypernetwork:
             sha256v = sha256(self.filename, f'hypernet/{self.name}')
             return sha256v[0:10]

+    def extra_name(self):
+        if version_flag:
+            return ""
+        found = find_self(self)
+        if found is not None:
+            return f" <hypernet:{found}:1.0>"
+        return f" <hypernet:{self.name}:1.0>"
+
     def save(self, filename):
         state_dict = {}
         optimizer_saved_dict = {}
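Note (illustrative): extra_name() is what the training hunks above append to preview prompts. On newer webui builds (version_flag False) it yields an inline hypernetwork token; on legacy builds it returns an empty string because the hypernetwork is applied globally. A sketch, assuming a hypernetwork listed as "anime_style":

    prompt = "a portrait photo" + hypernetwork.extra_name()
    # on a newer build -> "a portrait photo <hypernet:anime_style:1.0>"
    # on a legacy build -> "a portrait photo"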
@@ -412,9 +409,18 @@ class Hypernetwork:
         self.eval()

     def to(self, device):
-        for values in self.layers.values():
-            values[0].to(device)
-            values[1].to(device)
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.to(device)

+        return self
+
+    def set_multiplier(self, multiplier):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.multiplier = multiplier
+
+        return self
+
     def __call__(self, context, *args, **kwargs):
         return self.forward(context, *args, **kwargs)
@@ -436,9 +442,13 @@ def list_hypernetworks(path):
     res = {}
     for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
+        idx = 0
+        while name in res:
+            idx += 1
+            name = name + f"({idx})"
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name+ f"({sd_models.model_hash(filename)})"] = filename
+            res[name] = filename
     for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
         name = os.path.splitext(os.path.basename(filename))[0]
         if name != "None":
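Note (illustrative): the idx/while loop above de-duplicates basenames that occur in more than one subfolder, instead of suffixing every entry with a model hash. With an assumed layout of hypernetworks/a/style.pt and hypernetworks/b/style.pt, the resulting keys would look like:

    # {"style": ".../a/style.pt", "style(1)": ".../b/style.pt"}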
@@ -451,7 +461,10 @@ def find_closest_first(keyset, target):
             return keys
     return None


 def load_hypernetwork(filename):
+    hypernetwork = None
     path = shared.hypernetworks.get(filename, None)
     if path is None:
         filename = find_closest_first(shared.hypernetworks.keys(), filename)
@@ -462,8 +475,12 @@ def load_hypernetwork(filename):
         print(f"Loading hypernetwork {filename}")
         if path.endswith(".pt"):
             try:
-                shared.loaded_hypernetwork = Hypernetwork()
-                shared.loaded_hypernetwork.load(path)
+                hypernetwork = Hypernetwork()
+                hypernetwork.load(path)
+                if hasattr(shared, 'loaded_hypernetwork'):
+                    shared.loaded_hypernetwork = hypernetwork
+                else:
+                    return hypernetwork

             except Exception:
                 print(f"Error loading hypernetwork {path}", file=sys.stderr)

@@ -472,18 +489,23 @@ def load_hypernetwork(filename):
             # Load Hypernetwork processing
             try:
                 from .hypernetworks import load as load_hns
+                if hasattr(shared, 'loaded_hypernetwork'):
                     shared.loaded_hypernetwork = load_hns(path)
+                else:
+                    hypernetwork = load_hns(path)
                 print(f"Loaded Hypernetwork Structure {path}")
+                return hypernetwork
             except Exception:
                 print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
         else:
             print(f"Tried to load unknown file extension: {filename}")
     else:
+        if hasattr(shared, 'loaded_hypernetwork'):
             if shared.loaded_hypernetwork is not None:
                 print(f"Unloading hypernetwork")

             shared.loaded_hypernetwork = None
+    return hypernetwork


 def apply_hypernetwork(hypernetwork, context, layer=None):
@@ -504,266 +526,32 @@ def apply_hypernetwork(hypernetwork, context, layer=None):
     return context_k, context_v


-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
-                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
-                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
-                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
-                       use_beta_scheduler=False, beta_repeat_epoch=4000,epoch_mult=1, warmup =10, min_lr=1e-7, gamma_rate=1):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
-    try:
-        if use_beta_scheduler:
-            print("Using Beta Scheduler")
-            beta_repeat_epoch = int(beta_repeat_epoch)
-            assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
-            min_lr = float(min_lr)
-            assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
-            gamma_rate = float(gamma_rate)
-            print(f"Using learn rate decay(per cycle) of {gamma_rate}")
-            assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
-            epoch_mult = int(float(epoch_mult))
-            assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!"
-            warmup = int(warmup)
-            assert warmup >= 1, "Warmup epoch should be larger than 0!"
-        else:
-            beta_repeat_epoch = 4000
-            epoch_mult=1
-            warmup=10
-            min_lr=1e-7
-            gamma_rate=1
-    except ValueError:
-        raise RuntimeError("Cannot use advanced LR scheduler settings!")
-    save_hypernetwork_every = save_hypernetwork_every or 0
-    create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, 1, template_file, steps,
-                                            save_hypernetwork_every, create_image_every, log_directory,
-                                            name="hypernetwork")
-
-    load_hypernetwork(hypernetwork_name)
-    assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
-    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
-        raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
-    shared.state.textinfo = "Initializing hypernetwork training..."
-    shared.state.job_count = steps
-    losses_list = []
-    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
-    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
-    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
-    unload = shared.opts.unload_models_when_training
-
-    if save_hypernetwork_every > 0:
-        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
-        os.makedirs(hypernetwork_dir, exist_ok=True)
-    else:
-        hypernetwork_dir = None
-
-    if create_image_every > 0:
-        images_dir = os.path.join(log_directory, "images")
-        os.makedirs(images_dir, exist_ok=True)
-    else:
-        images_dir = None
-
-    hypernetwork = shared.loaded_hypernetwork
-    checkpoint = sd_models.select_checkpoint()
-
-    ititial_step = hypernetwork.step or 0
-    if ititial_step >= steps:
-        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
-        return hypernetwork, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    # dataset loading may take a while, so input validations and early returns should be done before this
-    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = PersonalizedBase(data_root=data_root, width=training_width,
-                              height=training_height,
-                              repeats=shared.opts.training_image_repeats_per_epoch,
-                              placeholder_token=hypernetwork_name,
-                              model=shared.sd_model, device=devices.device,
-                              template_file=template_file, include_cond=True,
-                              batch_size=batch_size)
-
-    if unload:
-        shared.sd_model.cond_stage_model.to(devices.cpu)
-        shared.sd_model.first_stage_model.to(devices.cpu)
-
-    size = len(ds.indexes)
-    loss_dict = defaultdict(lambda: deque(maxlen=1024))
-    losses = torch.zeros((size,))
-    previous_mean_losses = [0]
-    previous_mean_loss = 0
-    print("Mean loss of {} elements".format(size))
-
-    weights = hypernetwork.weights(True)
-
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if hypernetwork.optimizer_name in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-        optimizer_name = hypernetwork.optimizer_name
-    else:
-        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
-        optimizer_name = 'AdamW'
-
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
-
-    scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
-    scheduler_beta.last_epoch =hypernetwork.step-1
-    steps_without_grad = 0
-
-    last_saved_file = "<none>"
-    last_saved_image = "<none>"
-    forced_filename = "<none>"
-
-    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entries in pbar:
-        hypernetwork.step = i + ititial_step
-        if use_beta_scheduler:
-            scheduler_beta.step(hypernetwork.step)
-        if len(loss_dict) > 0:
-            previous_mean_losses = [i[-1] for i in loss_dict.values()]
-            previous_mean_loss = mean(previous_mean_losses)
-        if not use_beta_scheduler:
-            scheduler.apply(optimizer, hypernetwork.step)
-        if i + ititial_step > steps:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss_infos = shared.sd_model(x, c)[1]
-            loss = loss_infos[
-                'val/loss_simple']  # + loss_infos['val/loss_vlb'] * 0.4 #its 'prior class preserving' loss
-            del x
-            del c
-
-        losses[hypernetwork.step % losses.shape[0]] = loss.item()
-        losses_list.append(loss.item())
-        for entry in entries:
-            loss_dict[entry.filename].append(loss.item())
-        optimizer.zero_grad()
-        weights[0].grad = None
-        loss.backward()
-
-        if weights[0].grad is None:
-            steps_without_grad += 1
-        else:
-            steps_without_grad = 0
-            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-        optimizer.step()
-
-        steps_done = hypernetwork.step + 1
-
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
-            raise RuntimeError("Loss diverged.")
-
-        if len(previous_mean_losses) > 1:
-            std = stdev(previous_mean_losses)
-        else:
-            std = 0
-        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
-        pbar.set_description(dataset_loss_info)
-
-        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
-            # Before saving, change name to match current checkpoint.
-            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
-            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
-
-        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-            "loss": f"{previous_mean_loss:.7f}",
-            "learn_rate": optimizer.param_groups[0]['lr']
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{hypernetwork_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-            optimizer.zero_grad()
-            optim_to(optimizer, devices.cpu)
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
-
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images) > 0 else None
-
-            if unload:
-                shared.sd_model.cond_stage_model.to(devices.cpu)
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            if image is not None:
-                shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
-                                                                     shared.opts.samples_format, processed.infotexts[0],
-                                                                     p=p, forced_filename=forced_filename,
-                                                                     save_to_dirs=False)
-                last_saved_image += f", prompt: {preview_text}"
-            optim_to(optimizer, devices.device)
-
-        shared.state.job_no = hypernetwork.step
-
-        shared.state.textinfo = f"""
-<p>
-Loss: {previous_mean_loss:.7f}<br/>
-Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
-Last saved image: {html.escape(last_saved_image)}<br/>
-</p>
-"""
-
-    report_statistics(loss_dict)
-
-    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-    hypernetwork.optimizer_name = optimizer_name
-    if shared.opts.save_optimizer_state:
-        hypernetwork.optimizer_state_dict = optimizer.state_dict()
-    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-    del optimizer
-    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
-    hypernetwork.eval()
-    return hypernetwork, filename
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+    if hypernetwork is None:
+        return context_k, context_v
+    if isinstance(hypernetwork, Hypernetwork):
+        hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
+        if hypernetwork_layers is None:
+            return context_k, context_v
+        if layer is not None:
+            layer.hyper_k = hypernetwork_layers[0]
+            layer.hyper_v = hypernetwork_layers[1]
+        context_k = hypernetwork_layers[0](context_k)
+        context_v = hypernetwork_layers[1](context_v)
+        return context_k, context_v
+    context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
+    return context_k, context_v


 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


 def apply_hypernetwork_strength(p, x, xs):
     apply_strength(x)


 def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
     index = position_in_batch + iteration * p.batch_size
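Note (illustrative): the removed in-extension train_hypernetwork is superseded by the beta trainer earlier in this commit, and apply_single_hypernetwork is the per-layer hook newer webui builds call during cross-attention. A sketch of a call site, with context_k/context_v standing in for the attention key/value context tensors and attn_layer as an assumed attention module:

    # context_k, context_v: tensors shaped (batch, tokens, dim); dim selects the
    # matching pair of HypernetworkModules from hypernetwork.layers
    context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=attn_layer)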
@@ -778,9 +566,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'name') else shared.loaded_hypernetwork.name),
-        "Hypernet hash": (None if shared.loaded_hypernetwork is None or not hasattr(shared.loaded_hypernetwork, 'filename') else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
-        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),

@@ -801,13 +586,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter

     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


 modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
 modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
-modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
-modules.hypernetworks.hypernetwork.apply_strength = apply_strength
+if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
+    modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
+else:
+    modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
+if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
+    modules.hypernetworks.hypernetwork.apply_strength = apply_strength
 modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
 modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule
-scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
+if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
+    scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength

 # Fix calculating hash for multiple hns
 processing.create_infotext = create_infotext
@@ -4,6 +4,8 @@ import os.path
 import torch

 from modules import devices, shared
+from .hnutil import find_self
+from .shared import version_flag

 lazy_load = False  # when this is enabled, HNs will be loaded when required.


@@ -79,6 +81,17 @@ class Forward:
     def __call__(self, *args, **kwargs):
         raise NotImplementedError

+    def set_multiplier(self, *args, **kwargs):
+        pass
+
+    def extra_name(self):
+        if version_flag:
+            return ""
+        found = find_self(self)
+        if found is not None:
+            return f" <hypernet:{found}:1.0>"
+        return f" <hypernet:{self.name}:1.0>"
+
     @staticmethod
     def parse(arg, name=None):
         arg = Forward.unpack(arg)
@@ -2,10 +2,12 @@
 from modules.shared import cmd_opts, opts
 import modules.shared

+version_flag = hasattr(modules.shared, 'loaded_hypernetwork')
+

 def reload_hypernetworks():
     from .hypernetwork import list_hypernetworks, load_hypernetwork
     modules.shared.hypernetworks = list_hypernetworks(cmd_opts.hypernetwork_dir)
+    if hasattr(modules.shared, 'loaded_hypernetwork'):
         load_hypernetwork(opts.sd_hypernetwork)

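Note (illustrative): version_flag is the single switch the extension uses throughout this commit to distinguish old single-hypernetwork webui builds (which expose shared.loaded_hypernetwork) from newer multi-hypernetwork builds. The same probe, as a standalone sketch:

    import modules.shared
    legacy_api = hasattr(modules.shared, 'loaded_hypernetwork')  # True on old builds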
@@ -1,8 +1,8 @@
-import html
 import os

-from modules import shared, sd_hijack, devices
-from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
+from modules import shared
+from .hypernetwork import Hypernetwork, load_hypernetwork


 def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
                              weight_init_seed=None, normal_std=0.01, skip_connection=False):

@@ -36,8 +36,8 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
     )
     hypernet.save(fn)
     shared.reload_hypernetworks()
-    load_hypernetwork(fn)
+    hypernet = load_hypernetwork(name)
+    assert hypernet is not None, f"Cannot load from {name}!"
     return hypernet


@@ -76,27 +76,3 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     shared.reload_hypernetworks()

     return name, f"Created: {fn}", ""
-
-
-def train_hypernetwork_ui(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
-
-    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
-
-    try:
-        sd_hijack.undo_optimizations()
-
-        hypernetwork, filename = train_hypernetwork(*args)
-
-        res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
-Hypernetwork saved to {html.escape(filename)}
-"""
-        return res, ""
-    except Exception:
-        raise
-    finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
-        shared.sd_model.cond_stage_model.to(devices.device)
-        shared.sd_model.first_stage_model.to(devices.device)
-        sd_hijack.apply_optimizations()
@@ -17,95 +17,6 @@ from webui import wrap_gradio_gpu_call

 setattr(shared.opts,'pin_memory', False)


-def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
-    with gr.Tab(label="Train_Beta") as train_beta:
-        gr.HTML(
-            value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
-        with gr.Row():
-            train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
-                                                  choices=[x for x in shared.hypernetworks.keys()])
-            create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
-                                  lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
-                                  "refresh_train_hypernetwork_name")
-        with gr.Row():
-            hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
-                                                 placeholder="Hypernetwork Learning rate", value="0.00001")
-            use_beta_scheduler_checkbox = gr.Checkbox(
-                label='Show advanced learn rate scheduler options(for Hypernetworks)')
-        with gr.Row(visible=False) as beta_scheduler_options:
-            use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
-            beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
-            epoch_mult = gr.Textbox(label='Epoch multiplier per cycle', placeholder="Cycles length multiplier every cycle", value="1")
-            warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", value="1")
-            min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
-                                placeholder="restricts decay value, but does not restrict gamma rate decay",
-                                value="1e-7")
-            gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
-                                    placeholder="Value should be in (0-1]", value="1")
-        batch_size = gr.Number(label='Batch size', value=1, precision=0)
-        dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
-        log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
-                                   value="textual_inversion")
-        template_file = gr.Textbox(label='Prompt template file',
-                                   value=os.path.join(script_path, "textual_inversion_templates",
-                                                      "style_filewords.txt"))
-        training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-        training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
-        steps = gr.Number(label='Max steps', value=100000, precision=0)
-        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500,
-                                       precision=0)
-        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable',
-                                         value=500, precision=0)
-        preview_from_txt2img = gr.Checkbox(
-            label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
-
-        with gr.Row():
-            interrupt_training = gr.Button(value="Interrupt")
-            train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
-        ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
-        ti_outcome = gr.HTML(elem_id="ti_error2", value="")
-        use_beta_scheduler_checkbox.change(
-            fn=lambda show: gr_show(show),
-            inputs=[use_beta_scheduler_checkbox],
-            outputs=[beta_scheduler_options],
-        )
-        interrupt_training.click(
-            fn=lambda: shared.state.interrupt(),
-            inputs=[],
-            outputs=[],
-        )
-        train_hypernetwork.click(
-            fn=wrap_gradio_gpu_call(ui.train_hypernetwork_ui, extra_outputs=[gr.update()]),
-            _js="start_training_textual_inversion",
-            inputs=[
-                train_hypernetwork_name,
-                hypernetwork_learn_rate,
-                batch_size,
-                dataset_directory,
-                log_directory,
-                training_width,
-                training_height,
-                steps,
-                create_image_every,
-                save_embedding_every,
-                template_file,
-                preview_from_txt2img,
-                *params.txt2img_preview_params,
-                use_beta_scheduler,
-                beta_repeat_epoch,
-                epoch_mult,
-                warmup,
-                min_lr,
-                gamma_rate
-            ],
-            outputs=[
-                ti_output,
-                ti_outcome,
-            ]
-        )
-    return [(train_beta, "Train_beta", "train_beta")]
-
 def create_extension_tab(params=None):
     with gr.Tab(label="Create Beta hypernetwork") as create_beta:
         new_hypernetwork_name = gr.Textbox(label="Name")