commit
bb875af507
|
|
@ -11,7 +11,6 @@ from torch.utils.data import Dataset, DataLoader, Sampler
|
|||
from torchvision import transforms
|
||||
|
||||
from ..hnutil import get_closest
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from random import Random
|
||||
import tqdm
|
||||
|
|
@ -26,6 +25,7 @@ random_state_manager = Random(None)
|
|||
shuffle = random_state_manager.shuffle
|
||||
choice = random_state_manager.choice
|
||||
choices = random_state_manager.choices
|
||||
randrange = random_state_manager.randrange
|
||||
|
||||
|
||||
def set_rng(seed=None):
|
||||
|
|
@ -47,10 +47,10 @@ class DatasetEntry:
|
|||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None,
|
||||
cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1,
|
||||
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(
|
||||
shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
seed = random.randrange(sys.maxsize)
|
||||
seed = randrange(sys.maxsize)
|
||||
set_rng(seed) # reset forked RNG state when we create dataset.
|
||||
print(f"Dataset seed was set to f{seed}")
|
||||
self.placeholder_token = placeholder_token
|
||||
|
|
@ -128,6 +128,10 @@ class PersonalizedBase(Dataset):
|
|||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "random":
|
||||
if latent_sampling_std != -1:
|
||||
assert latent_sampling_std > 0, f"Cannnot apply negative standard deviation {latent_sampling_std}"
|
||||
print(f"Applying patch, clipping std from {torch.max(latent_dist.std).item()} to {latent_sampling_std}...")
|
||||
latent_dist.std.clip_(latent_sampling_std)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
|
|
@ -154,7 +158,7 @@ class PersonalizedBase(Dataset):
|
|||
text = text.replace("[name]", self.placeholder_token)
|
||||
tags = filename_text.split(',')
|
||||
if self.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
tags = [t for t in tags if random_state_manager.random() > self.tag_drop_out]
|
||||
if self.shuffle_tags:
|
||||
shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99,
|
||||
adamw_eps=1e-8,
|
||||
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
|
||||
optional_gradient_norm_type=2,
|
||||
optional_gradient_norm_type=2, latent_sampling_std=-1,
|
||||
load_training_options=''):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
|
@ -86,6 +86,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
gradient_clip_opt = dump['gradient_clip_opt']
|
||||
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||
latent_sampling_std = dump.get('latent_sampling_std', -1)
|
||||
try:
|
||||
if use_adamw_parameter:
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
|
||||
|
|
@ -223,7 +224,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
include_cond=True, batch_size=batch_size,
|
||||
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
|
||||
tag_drop_out=tag_drop_out,
|
||||
latent_sampling_method=latent_sampling_method)
|
||||
latent_sampling_method=latent_sampling_method,
|
||||
latent_sampling_std=latent_sampling_std)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
|
|
@ -502,9 +504,10 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
optional_info = dump_hyper['optional_info']
|
||||
weight_init_seed = dump_hyper['weight_init_seed']
|
||||
normal_std = dump_hyper['normal_std']
|
||||
skip_connection = dump_hyper['skip_connection']
|
||||
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure,
|
||||
activation_func, weight_init, add_layer_norm, use_dropout,
|
||||
dropout_structure, optional_info, weight_init_seed, normal_std)
|
||||
dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection)
|
||||
else:
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(',1)[0] + time.strftime('%Y%m%d%H%M%S')
|
||||
|
|
@ -541,6 +544,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
gradient_clip_opt = dump['gradient_clip_opt']
|
||||
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||
latent_sampling_std = dump.get('latent_sampling_std', -1)
|
||||
else:
|
||||
raise RuntimeError(f"Cannot load from {load_training_options}!")
|
||||
else:
|
||||
|
|
@ -682,7 +686,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
include_cond=True, batch_size=batch_size,
|
||||
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
|
||||
tag_drop_out=tag_drop_out,
|
||||
latent_sampling_method=latent_sampling_method)
|
||||
latent_sampling_method=latent_sampling_method,
|
||||
latent_sampling_std=latent_sampling_std)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
|
|
|
|||
|
|
@ -21,9 +21,11 @@ from modules.textual_inversion.textual_inversion import save_embedding
|
|||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from ..tbutils import tensorboard_add, tensorboard_setup, tensorboard_add_scaler, tensorboard_add_image
|
||||
#apply OsError avoid here
|
||||
|
||||
# Loss entries whose write raised OSError are buffered here, keyed by log_directory + filename.
|
||||
delayed_values = {}
|
||||
|
||||
|
||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
|
@ -55,9 +57,9 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
**values,
|
||||
})
|
||||
except OSError:
|
||||
epoch, epoch_step = divmod(step-1, epoch_len)
|
||||
epoch, epoch_step = divmod(step - 1, epoch_len)
|
||||
if log_directory + filename in delayed_values:
|
||||
delayed_values[log_directory + filename].append((step , epoch, epoch_step, values))
|
||||
delayed_values[log_directory + filename].append((step, epoch, epoch_step, values))
|
||||
else:
|
||||
delayed_values[log_directory + filename] = [(step, epoch, epoch_step, values)]
|
||||
|
||||
|
|
@ -86,15 +88,19 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width,
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
|
||||
training_width,
|
||||
training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every,
|
||||
save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img,
|
||||
preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale,
|
||||
preview_seed, preview_width, preview_height,
|
||||
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1,warmup =10, min_lr=1e-7, gamma_rate=1, save_when_converge=False, create_when_converge=False,
|
||||
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7,
|
||||
gamma_rate=1, save_when_converge=False, create_when_converge=False,
|
||||
move_optimizer=True,
|
||||
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99,adamw_eps=1e-8,
|
||||
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, optional_gradient_norm_type=2
|
||||
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99,
|
||||
adamw_eps=1e-8,
|
||||
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
|
||||
optional_gradient_norm_type=2, latent_sampling_std=-1
|
||||
):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
|
@ -102,20 +108,23 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
try:
|
||||
if use_adamw_parameter:
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in [adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps]]
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
|
||||
[adamw_weight_decay, adamw_beta_1,
|
||||
adamw_beta_2, adamw_eps]]
|
||||
assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!"
|
||||
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, adamw_eps])), "Cannot use negative or >1 number for adamW parameters!"
|
||||
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2,
|
||||
adamw_eps])), "Cannot use negative or >1 number for adamW parameters!"
|
||||
adamW_kwarg_dict = {
|
||||
'weight_decay' : adamw_weight_decay,
|
||||
'betas' : (adamw_beta_1, adamw_beta_2),
|
||||
'eps' : adamw_eps
|
||||
'weight_decay': adamw_weight_decay,
|
||||
'betas': (adamw_beta_1, adamw_beta_2),
|
||||
'eps': adamw_eps
|
||||
}
|
||||
print('Using custom AdamW parameters')
|
||||
else:
|
||||
adamW_kwarg_dict = {
|
||||
'weight_decay' : 0.01,
|
||||
'betas' : (0.9, 0.99),
|
||||
'eps' : 1e-8
|
||||
'weight_decay': 0.01,
|
||||
'betas': (0.9, 0.99),
|
||||
'eps': 1e-8
|
||||
}
|
||||
if use_beta_scheduler:
|
||||
print("Using Beta Scheduler")
|
||||
|
|
@ -134,10 +143,10 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
print(f"Generate image when converges : {create_when_converge}")
|
||||
else:
|
||||
beta_repeat_epoch = 4000
|
||||
epoch_mult=1
|
||||
warmup=10
|
||||
min_lr=1e-7
|
||||
gamma_rate=1
|
||||
epoch_mult = 1
|
||||
warmup = 10
|
||||
min_lr = 1e-7
|
||||
gamma_rate = 1
|
||||
save_when_converge = False
|
||||
create_when_converge = False
|
||||
except ValueError:
|
||||
|
|
@ -153,6 +162,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
except ValueError:
|
||||
raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})")
|
||||
assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}"
|
||||
|
||||
def gradient_clipping(arg1):
|
||||
torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type)
|
||||
return
|
||||
|
|
@ -212,38 +222,40 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
detach_grad = shared.opts.disable_ema # test code that removes EMA
|
||||
detach_grad = shared.opts.disable_ema # test code that removes EMA
|
||||
if detach_grad:
|
||||
print("Disabling training for staged models!")
|
||||
shared.sd_model.cond_stage_model.requires_grad_(False)
|
||||
shared.sd_model.first_stage_model.requires_grad_(False)
|
||||
torch.cuda.empty_cache()
|
||||
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||
height=training_height,
|
||||
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||
placeholder_token=embedding_name, model=shared.sd_model,
|
||||
cond_model=shared.sd_model.cond_stage_model,
|
||||
device=devices.device, template_file=template_file,
|
||||
batch_size=batch_size, gradient_step=gradient_step,
|
||||
shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out,
|
||||
latent_sampling_method=latent_sampling_method)
|
||||
height=training_height,
|
||||
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||
placeholder_token=embedding_name, model=shared.sd_model,
|
||||
cond_model=shared.sd_model.cond_stage_model,
|
||||
device=devices.device, template_file=template_file,
|
||||
batch_size=batch_size, gradient_step=gradient_step,
|
||||
shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out,
|
||||
latent_sampling_method=latent_sampling_method,
|
||||
latent_sampling_std=latent_sampling_std)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
|
||||
batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad_(True)
|
||||
optimizer_name = 'AdamW' # hardcoded optimizer name now
|
||||
optimizer_name = 'AdamW' # hardcoded optimizer name now
|
||||
if use_adamw_parameter:
|
||||
optimizer = torch.optim.AdamW(params=[embedding.vec], lr=scheduler.learn_rate, **adamW_kwarg_dict)
|
||||
else:
|
||||
optimizer = torch.optim.AdamW(params=[embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
|
||||
if os.path.exists(filename + '.optim'): # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
if os.path.exists(
|
||||
filename + '.optim'): # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||
|
|
@ -259,8 +271,10 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
if move_optimizer:
|
||||
optim_to(optimizer, devices.device)
|
||||
if use_beta_scheduler:
|
||||
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
|
||||
scheduler_beta.last_epoch = embedding.step-1
|
||||
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch,
|
||||
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate,
|
||||
warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate)
|
||||
scheduler_beta.last_epoch = embedding.step - 1
|
||||
else:
|
||||
scheduler_beta = None
|
||||
for pg in optimizer.param_groups:
|
||||
|
|
@ -310,7 +324,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
if is_training_inpainting_model:
|
||||
if img_c is None:
|
||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width,
|
||||
training_height)
|
||||
|
||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||
else:
|
||||
|
|
@ -340,7 +355,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
epoch_step = embedding.step % steps_per_epoch
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if embedding_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and save_when_converge) or (save_embedding_every > 0 and steps_done % save_embedding_every == 0)):
|
||||
if embedding_dir is not None and (
|
||||
(use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and save_when_converge) or (
|
||||
save_embedding_every > 0 and steps_done % save_embedding_every == 0)):
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
|
|
@ -355,7 +372,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if images_dir is not None and ((use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and create_when_converge) or (create_image_every > 0 and steps_done % create_image_every == 0)):
|
||||
if images_dir is not None and (
|
||||
(use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and create_when_converge) or (
|
||||
create_image_every > 0 and steps_done % create_image_every == 0)):
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
rng_state = torch.get_rng_state()
|
||||
|
|
@ -429,7 +448,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.shorthash if hasattr(checkpoint, 'shorthash') else checkpoint.hash)
|
||||
footer_mid = '[{}]'.format(
|
||||
checkpoint.shorthash if hasattr(checkpoint, 'shorthash') else checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
|
|
@ -449,7 +469,6 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
|
|
@ -471,4 +490,4 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
return embedding, filename
|
||||
return embedding, filename
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import gradio as gr
|
|||
|
||||
|
||||
def train_hypernetwork_ui(*args):
|
||||
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
|
@ -40,7 +39,6 @@ Hypernetwork saved to {html.escape(filename)}
|
|||
|
||||
|
||||
def train_hypernetwork_ui_tuning(*args):
|
||||
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
|
@ -69,7 +67,7 @@ def save_training_setting(*args):
|
|||
template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
|
||||
gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
|
||||
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type = args
|
||||
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std = args
|
||||
dumped_locals = locals()
|
||||
dumped_locals.pop('args')
|
||||
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
|
||||
|
|
@ -82,7 +80,7 @@ def save_training_setting(*args):
|
|||
|
||||
|
||||
def save_hypernetwork_setting(*args):
|
||||
save_file_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std = args
|
||||
save_file_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection = args
|
||||
dumped_locals = locals()
|
||||
dumped_locals.pop('args')
|
||||
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_hypernetwork_' + '.json'
|
||||
|
|
@ -93,6 +91,7 @@ def save_hypernetwork_setting(*args):
|
|||
print(f"File saved as {filename}")
|
||||
return filename, ""
|
||||
|
||||
|
||||
def on_train_gamma_tab(params=None):
|
||||
dummy_component = gr.Label(visible=False)
|
||||
with gr.Tab(label="Train_Gamma") as train_gamma:
|
||||
|
|
@ -123,15 +122,18 @@ def on_train_gamma_tab(params=None):
|
|||
show_gradient_clip_checkbox = gr.Checkbox(
|
||||
label='Show Gradient Clipping Options(for both)')
|
||||
with gr.Row(visible=False) as adamW_options:
|
||||
adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01", value="0.01")
|
||||
adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01",
|
||||
value="0.01")
|
||||
adamw_beta_1 = gr.Textbox(label="AdamW beta1 parameter", placeholder="default = 0.9", value="0.9")
|
||||
adamw_beta_2 = gr.Textbox(label="AdamW beta2 parameter", placeholder="default = 0.99", value="0.99")
|
||||
adamw_eps = gr.Textbox(label="AdamW epsilon parameter", placeholder="default = 1e-8", value="1e-8")
|
||||
with gr.Row(visible=False) as beta_scheduler_options:
|
||||
use_beta_scheduler = gr.Checkbox(label='Use CosineAnnealingWarmupRestarts Scheduler')
|
||||
beta_repeat_epoch = gr.Textbox(label='Steps for cycle', placeholder="Cycles every nth Step", value="64")
|
||||
epoch_mult = gr.Textbox(label='Step multiplier per cycle', placeholder="Step length multiplier every cycle", value="1")
|
||||
warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", value="5")
|
||||
epoch_mult = gr.Textbox(label='Step multiplier per cycle', placeholder="Step length multiplier every cycle",
|
||||
value="1")
|
||||
warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step",
|
||||
value="5")
|
||||
min_lr = gr.Textbox(label='Minimum learning rate',
|
||||
placeholder="restricts decay value, but does not restrict gamma rate decay",
|
||||
value="6e-7")
|
||||
|
|
@ -144,7 +146,7 @@ def on_train_gamma_tab(params=None):
|
|||
gradient_clip_opt = gr.Radio(label="Gradient Clipping Options", choices=["None", "limit", "norm"])
|
||||
optional_gradient_clip_value = gr.Textbox(label="Limiting value", value="1e-1")
|
||||
optional_gradient_norm_type = gr.Textbox(label="Norm type", value="2")
|
||||
#change by feedback
|
||||
# change by feedback
|
||||
use_beta_adamW_checkbox.change(
|
||||
fn=lambda show: gr_show(show),
|
||||
inputs=[use_beta_adamW_checkbox],
|
||||
|
|
@ -165,7 +167,8 @@ def on_train_gamma_tab(params=None):
|
|||
inputs=[show_gradient_clip_checkbox],
|
||||
outputs=[gradient_clip_options],
|
||||
)
|
||||
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)", value=True)
|
||||
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)",
|
||||
value=True)
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
|
|
@ -191,10 +194,12 @@ def on_train_gamma_tab(params=None):
|
|||
with gr.Row():
|
||||
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once",
|
||||
choices=['once', 'deterministic', 'random'])
|
||||
latent_sampling_std_value = gr.Number(label="Standard deviation for sampling", value=-1)
|
||||
with gr.Row():
|
||||
save_training_option = gr.Button(value="Save training setting")
|
||||
save_file_name = gr.Textbox(label="File name to save setting as", value="")
|
||||
load_training_option = gr.Textbox(label="Load training option from saved json file. This will override settings above", value="")
|
||||
load_training_option = gr.Textbox(
|
||||
label="Load training option from saved json file. This will override settings above", value="")
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
|
|
@ -202,9 +207,9 @@ def on_train_gamma_tab(params=None):
|
|||
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
|
||||
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
|
||||
|
||||
#Full path to .json or simple names are recommended.
|
||||
# Full path to .json or simple names are recommended.
|
||||
save_training_option.click(
|
||||
fn = wrap_gradio_call(save_training_setting),
|
||||
fn=wrap_gradio_call(save_training_setting),
|
||||
inputs=[
|
||||
save_file_name,
|
||||
hypernetwork_learn_rate,
|
||||
|
|
@ -233,7 +238,8 @@ def on_train_gamma_tab(params=None):
|
|||
show_gradient_clip_checkbox,
|
||||
gradient_clip_opt,
|
||||
optional_gradient_clip_value,
|
||||
optional_gradient_norm_type],
|
||||
optional_gradient_norm_type,
|
||||
latent_sampling_std_value],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
|
|
@ -279,7 +285,8 @@ def on_train_gamma_tab(params=None):
|
|||
show_gradient_clip_checkbox,
|
||||
gradient_clip_opt,
|
||||
optional_gradient_clip_value,
|
||||
optional_gradient_norm_type
|
||||
optional_gradient_norm_type,
|
||||
latent_sampling_std_value
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
|
|
@ -327,6 +334,7 @@ def on_train_gamma_tab(params=None):
|
|||
gradient_clip_opt,
|
||||
optional_gradient_clip_value,
|
||||
optional_gradient_norm_type,
|
||||
latent_sampling_std_value,
|
||||
load_training_option
|
||||
|
||||
],
|
||||
|
|
@ -343,6 +351,7 @@ def on_train_gamma_tab(params=None):
|
|||
)
|
||||
return [(train_gamma, "Train Gamma", "train_gamma")]
|
||||
|
||||
|
||||
def on_train_tuning(params=None):
|
||||
dummy_component = gr.Label(visible=False)
|
||||
with gr.Tab(label="Train_Tuning") as train_tuning:
|
||||
|
|
@ -354,14 +363,18 @@ def on_train_tuning(params=None):
|
|||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
|
||||
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
|
||||
"refresh_train_hypernetwork_name")
|
||||
optional_new_hypernetwork_name = gr.Textbox(label="Hypernetwork name to create, leave it empty to use selected", value="")
|
||||
optional_new_hypernetwork_name = gr.Textbox(
|
||||
label="Hypernetwork name to create, leave it empty to use selected", value="")
|
||||
with gr.Row():
|
||||
load_hypernetworks_option = gr.Textbox(
|
||||
label="Load Hypernetwork creation option from saved json file", placeholder = ". filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||
label="Load Hypernetwork creation option from saved json file",
|
||||
placeholder=". filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||
with gr.Row():
|
||||
load_training_options = gr.Textbox(
|
||||
label="Load training option(s) from saved json file", placeholder = ". filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)", value=True)
|
||||
label="Load training option(s) from saved json file",
|
||||
placeholder=". filename cannot have ',' inside, and files should be splitted by ','.", value="")
|
||||
move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)",
|
||||
value=True)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
|
||||
value="textual_inversion")
|
||||
|
|
@ -404,4 +417,4 @@ def on_train_tuning(params=None):
|
|||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
return [(train_tuning, "Train Tuning", "train_tuning")]
|
||||
return [(train_tuning, "Train Tuning", "train_tuning")]
|
||||
|
|
|
|||
|
|
@ -31,6 +31,74 @@ from .dataset import PersonalizedBase
|
|||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
||||
def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
    """Initialise the ``weight`` (and ``bias``, when present) of ``layer`` in place.

    Args:
        layer: Module exposing a ``weight`` parameter and optionally a ``bias``
            parameter (this file calls it with ``torch.nn.Linear`` and
            ``torch.nn.LayerNorm`` instances).
        weight_init: Name of the scheme: ``"Normal"``, ``"XavierUniform"``,
            ``"XavierNormal"``, ``"KaimingUniform"`` or ``"KaimingNormal"``.
        normal_std: Standard deviation used by the ``"Normal"`` scheme.
        activation_func: Only consulted by the Kaiming schemes; ``"leakyrelu"``
            selects the ``leaky_relu`` nonlinearity, anything else ``relu``.

    Raises:
        KeyError: If ``weight_init`` is not one of the names above (and the
            layer is not a ``LayerNorm``).
    """
    w = layer.weight.data
    # Layers built with ``bias=False`` carry ``bias = None``; skip the bias
    # step for them instead of failing on ``None.data``.
    bias = getattr(layer, "bias", None)
    b = bias.data if bias is not None else None

    # LayerNorm always takes the Normal path regardless of ``weight_init``.
    if weight_init == "Normal" or isinstance(layer, torch.nn.LayerNorm):
        normal_(w, mean=0.0, std=normal_std)
        if b is not None:
            # std=0 collapses the distribution onto its mean, i.e. the bias becomes 0.
            normal_(b, mean=0.0, std=0)
        return

    if weight_init == 'XavierUniform':
        xavier_uniform_(w)
    elif weight_init == 'XavierNormal':
        xavier_normal_(w)
    elif weight_init in ('KaimingUniform', 'KaimingNormal'):
        nonlinearity = 'leaky_relu' if 'leakyrelu' == activation_func else 'relu'
        if weight_init == 'KaimingUniform':
            kaiming_uniform_(w, nonlinearity=nonlinearity)
        else:
            kaiming_normal_(w, nonlinearity=nonlinearity)
    else:
        raise KeyError(f"Key {weight_init} is not defined as initialization!")

    # Every non-Normal scheme zeroes the bias.
    if b is not None:
        zeros_(b)
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
    """Residual block: ``forward(x)`` returns ``shortcut(x) + self.linear(x)``.

    ``self.linear`` is a ``Sequential`` of a ``Linear(n_inputs, n_outputs)``,
    an optional ``LayerNorm``, an optional ``Dropout`` and an optional
    activation. The shortcut is a bias-free ``Linear`` when the keyword
    ``upsample_model="Linear"`` is given; otherwise the input is resized to
    ``n_outputs`` with ``torch.nn.functional.interpolate``.
    """

    def __init__(self, n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device=None, state_dict=None, **kwargs):
        """Build the block.

        Args:
            n_inputs: Input feature size of the main ``Linear``.
            n_outputs: Output feature size of the main ``Linear``.
            activation_func: ``"linear"``/``None`` for no activation, otherwise
                a key of ``HypernetworkModule.activation_dict``.
            weight_init: Scheme name forwarded to ``init_weight``.
            add_layer_norm: Append a ``LayerNorm(n_outputs)`` when truthy.
            dropout_p: Append ``Dropout(p=dropout_p)`` when greater than 0.
            normal_std: Standard deviation forwarded to ``init_weight``.
            device: If not ``None``, the module is moved there.
            state_dict: If not ``None``, loaded after construction.
            **kwargs: ``upsample_model`` selects the shortcut type.

        Raises:
            RuntimeError: If ``activation_func`` is not supported.
        """
        super().__init__()
        self.n_outputs = n_outputs
        self.upsample_layer = None
        self.upsample = kwargs.get("upsample_model", None)
        # Compare by value: `is "Linear"` tests object identity, which only
        # succeeds for interned strings and is a SyntaxWarning on Python 3.8+.
        if self.upsample == "Linear":
            self.upsample_layer = torch.nn.Linear(n_inputs, n_outputs, bias=False)

        linears = [torch.nn.Linear(n_inputs, n_outputs)]
        init_weight(linears[0], weight_init, normal_std, activation_func)
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(n_outputs))
            init_weight(linears[1], weight_init, normal_std, activation_func)
        if dropout_p > 0:
            linears.append(torch.nn.Dropout(p=dropout_p))

        if activation_func == "linear" or activation_func is None:
            pass
        elif activation_func in HypernetworkModule.activation_dict:
            linears.append(HypernetworkModule.activation_dict[activation_func]())
        else:
            raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

        self.linear = torch.nn.Sequential(*linears)
        if state_dict is not None:
            self.load_state_dict(state_dict)
        if device is not None:
            self.to(device)

    def trainables(self, train=False):
        """Switch every layer of ``self.linear`` to train/eval mode and return
        the weights and biases of its ``Linear`` and ``LayerNorm`` layers.

        NOTE(review): the parameters of ``self.upsample_layer`` are not part of
        the returned list — confirm whether that is intended.
        """
        layer_structure = []
        for layer in self.linear:
            if train:
                layer.train()
            else:
                layer.eval()
            if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                layer_structure += [layer.weight, layer.bias]
        return layer_structure

    def forward(self, x, **kwargs):
        """Return the shortcut of ``x`` added to ``self.linear(x)``."""
        if self.upsample_layer is None:
            # No learned shortcut: resize x to n_outputs so the sum is well-formed.
            interpolated = torch.nn.functional.interpolate(x, size=self.n_outputs, mode="nearest-exact")
        else:
            interpolated = self.upsample_layer(x)
        return interpolated + self.linear(x)
|
||||
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
|
|
@ -46,17 +114,25 @@ class HypernetworkModule(torch.nn.Module):
|
|||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None, device=None, generation_seed=None, normal_std=0.01):
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None, device=None, generation_seed=None, normal_std=0.01, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.skip_connection = skip_connection = kwargs.get('skip_connection', False)
|
||||
upsample_linear = kwargs.get('upsample_linear', None)
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
assert dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0, "Dropout Sequence should start and end with probability 0!"
|
||||
assert skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0, "Dropout Sequence should start and end with probability 0!"
|
||||
assert dropout_structure is None or len(dropout_structure) == len(layer_structure), "Dropout Sequence should match length with layer structure!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
if skip_connection:
|
||||
n_inputs, n_outputs = int(dim * layer_structure[i]), int(dim * layer_structure[i+1])
|
||||
dropout_p = dropout_structure[i+1]
|
||||
if activation_func is None:
|
||||
activation_func = "linear"
|
||||
linears.append(ResBlock(n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device, upsample_model=upsample_linear))
|
||||
continue
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
|
@ -131,6 +207,15 @@ class HypernetworkModule(torch.nn.Module):
|
|||
state_dict[to] = x
|
||||
|
||||
def forward(self, x, multiplier=None):
|
||||
if self.skip_connection:
|
||||
if self.training or multiplier is None or not isinstance(multiplier, (int, float)):
|
||||
return self.linear(x)
|
||||
else:
|
||||
resnet_result = self.linear(x)
|
||||
residual = resnet_result - x
|
||||
if multiplier is None or not isinstance(multiplier, (int, float)):
|
||||
multiplier = HypernetworkModule.multiplier
|
||||
return x + multiplier * residual # interpolate
|
||||
if multiplier is None or not isinstance(multiplier, (int, float)):
|
||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
||||
return x + self.linear(x) * multiplier
|
||||
|
|
@ -144,6 +229,8 @@ class HypernetworkModule(torch.nn.Module):
|
|||
layer.eval()
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
elif type(layer) == ResBlock:
|
||||
layer_structure += layer.trainables(train)
|
||||
return layer_structure
|
||||
|
||||
def set_train(self,mode=True):
|
||||
|
|
@ -176,6 +263,8 @@ class Hypernetwork:
|
|||
self.optimizer_state_dict = None
|
||||
self.dropout_structure = kwargs['dropout_structure'] if 'dropout_structure' in kwargs and use_dropout else None
|
||||
self.optional_info = kwargs.get('optional_info', None)
|
||||
self.skip_connection = kwargs.get('skip_connection', False)
|
||||
self.upsample_linear = kwargs.get('upsample_linear', None)
|
||||
generation_seed = kwargs.get('generation_seed', None)
|
||||
normal_std = kwargs.get('normal_std', 0.01)
|
||||
if self.dropout_structure is None:
|
||||
|
|
@ -184,9 +273,11 @@ class Hypernetwork:
|
|||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection,
|
||||
upsample_linear=self.upsample_linear),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection,
|
||||
upsample_linear=self.upsample_linear),
|
||||
)
|
||||
self.eval()
|
||||
|
||||
|
|
@ -237,6 +328,8 @@ class Hypernetwork:
|
|||
state_dict['dropout_structure'] = self.dropout_structure
|
||||
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
||||
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
||||
state_dict['skip_connection'] = self.skip_connection
|
||||
state_dict['upsample_linear'] = self.upsample_linear
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
|
@ -266,6 +359,8 @@ class Hypernetwork:
|
|||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Silent fix for HNs before 4918eb6
|
||||
self.skip_connection = state_dict.get('skip_connection', False)
|
||||
self.upsample_linear = state_dict.get('upsample_linear', False)
|
||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
|
|
@ -296,9 +391,9 @@ class Hypernetwork:
|
|||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from modules import shared, sd_hijack, devices
|
|||
from .hypernetwork import Hypernetwork, train_hypernetwork, load_hypernetwork
|
||||
|
||||
def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
||||
weight_init_seed=None, normal_std=0.01):
|
||||
weight_init_seed=None, normal_std=0.01, skip_connection=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
assert name, "Name cannot be empty!"
|
||||
|
|
@ -31,7 +31,8 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
|
|||
dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure),
|
||||
optional_info=optional_info,
|
||||
generation_seed=weight_init_seed if weight_init_seed != -1 else None,
|
||||
normal_std=normal_std
|
||||
normal_std=normal_std,
|
||||
skip_connection=skip_connection
|
||||
)
|
||||
hypernet.save(fn)
|
||||
shared.reload_hypernetworks()
|
||||
|
|
@ -41,7 +42,7 @@ def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=
|
|||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None,
|
||||
weight_init_seed=None, normal_std=0.01):
|
||||
weight_init_seed=None, normal_std=0.01, skip_connection=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
assert name, "Name cannot be empty!"
|
||||
|
|
@ -55,7 +56,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
if dropout_structure and type(dropout_structure) == str:
|
||||
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||
normal_std = float(normal_std)
|
||||
assert normal_std > 0, "Normal Standard Deviation should be bigger than 0!"
|
||||
assert normal_std >= 0, "Normal Standard Deviation should be bigger than 0!"
|
||||
hypernet = Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
|
|
@ -67,7 +68,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure),
|
||||
optional_info=optional_info,
|
||||
generation_seed=weight_init_seed if weight_init_seed != -1 else None,
|
||||
normal_std=normal_std
|
||||
normal_std=normal_std,
|
||||
skip_connection=skip_connection
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ def create_extension_tab(params=None):
|
|||
new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0",
|
||||
label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15",
|
||||
placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
|
||||
skip_connection = gr.Checkbox(label="Use skip-connection. Won't work without extension!")
|
||||
optional_info = gr.Textbox("", label="Optional information about Hypernetwork", placeholder="Training information, dateset, etc")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||
|
||||
|
|
@ -165,7 +166,8 @@ def create_extension_tab(params=None):
|
|||
new_hypernetwork_dropout_structure,
|
||||
optional_info,
|
||||
generation_seed if generation_seed.visible else None,
|
||||
normal_std if normal_std.visible else 0.01],
|
||||
normal_std if normal_std.visible else 0.01,
|
||||
skip_connection],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
|
|
@ -185,7 +187,8 @@ def create_extension_tab(params=None):
|
|||
new_hypernetwork_dropout_structure,
|
||||
optional_info,
|
||||
generation_seed if generation_seed.visible else None,
|
||||
normal_std if normal_std.visible else 0.01
|
||||
normal_std if normal_std.visible else 0.01,
|
||||
skip_connection
|
||||
],
|
||||
outputs=[
|
||||
new_hypernetwork_name,
|
||||
|
|
|
|||
Loading…
Reference in New Issue