diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 006cfd0..6bc2d83 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -14,7 +14,7 @@ import tqdm
 from modules import shared, sd_models, devices, processing, sd_samplers
 from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork, report_statistics
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.textual_inversion import tensorboard_setup, tensorboard_add, tensorboard_add_image
+from ..tbutils import tensorboard_setup, tensorboard_add, tensorboard_add_image, tensorboard_log_hyperparameter
 from .textual_inversion import validate_train_inputs, write_loss
 from ..hypernetwork import Hypernetwork, load_hypernetwork
 from . import sd_hijack_checkpoint
@@ -379,7 +379,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                 mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
                 tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
-                                learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+                                learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
             if images_dir is not None and (
                     use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
                     create_image_every > 0 and steps_done % create_image_every == 0):
@@ -483,6 +483,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                             load_hypernetworks_option='', load_training_options=''):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
+    base_hypernetwork_name = hypernetwork_name
     if load_hypernetworks_option != '':
         timeStr = time.strftime('%Y%m%d%H%M%S')
         dump_hyper: dict = get_training_option(load_hypernetworks_option)
@@ -626,7 +627,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
+    base_log_directory = log_directory
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
     unload = shared.opts.unload_models_when_training
@@ -653,7 +654,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
     if shared.opts.training_enable_tensorboard:
         print("Tensorboard logging enabled")
-        tensorboard_writer = tensorboard_setup(log_directory)
+        tensorboard_writer = tensorboard_setup(os.path.join(base_log_directory, base_hypernetwork_name))
+
     else:
         tensorboard_writer = None
     # dataset loading may take a while, so input validations and early returns should be done before this
@@ -831,7 +833,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                 mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
                 tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
-                                learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+                                learn_rate=scheduler.learn_rate, epoch_num=epoch_num, base_name=hypernetwork_name)
             if images_dir is not None and (
                     use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
                     create_image_every > 0 and steps_done % create_image_every == 0):
@@ -913,6 +915,33 @@ Last saved image: {html.escape(last_saved_image)}
     pbar.close()
     hypernetwork.eval()
     shared.parallel_processing_allowed = old_parallel_processing_allowed
+    if shared.opts.training_enable_tensorboard:
+        tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
+                                       GA_steps=gradient_step,
+                                       batch_size=batch_size,
+                                       layer_structure=hypernetwork.layer_structure,
+                                       activation=hypernetwork.activation_func,
+                                       weight_init=hypernetwork.weight_init,
+                                       dropout_structure=hypernetwork.dropout_structure,
+                                       max_steps=steps,
+                                       latent_sampling_method=latent_sampling_method,
+                                       template=template_file,
+                                       CosineAnnealing=use_beta_scheduler,
+                                       beta_repeat_epoch=beta_repeat_epoch,
+                                       epoch_mult=epoch_mult,
+                                       warmup=warmup,
+                                       min_lr=min_lr,
+                                       gamma_rate=gamma_rate,
+                                       adamW_opts=use_adamw_parameter,
+                                       adamW_decay=adamw_weight_decay,
+                                       adamW_beta_1=adamw_beta_1,
+                                       adamW_beta_2=adamw_beta_2,
+                                       adamW_eps=adamw_eps,
+                                       gradient_clip=gradient_clip_opt,
+                                       gradient_clip_value=optional_gradient_clip_value,
+                                       gradient_clip_norm_type=optional_gradient_norm_type,
+                                       loss=mean_loss
+                                       )
     report_statistics(loss_dict)
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
diff --git a/patches/tbutils.py b/patches/tbutils.py
new file mode 100644
index 0000000..6454ab4
--- /dev/null
+++ b/patches/tbutils.py
@@ -0,0 +1,69 @@
+import os
+
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from modules import shared
+
+
+def tensorboard_setup(log_directory):
+    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+    return SummaryWriter(
+        log_dir=os.path.join(log_directory, "tensorboard"),
+        flush_secs=shared.opts.training_tensorboard_flush_every)
+
+def tensorboard_log_hyperparameter(tensorboard_writer: SummaryWriter, **kwargs):
+    for keys in kwargs:
+        if type(kwargs[keys]) not in [bool, str, float, int, None]:
+            kwargs[keys] = str(kwargs[keys])
+    tensorboard_writer.add_hparams({
+        'lr': kwargs.get('lr', 0.01),
+        'GA steps': kwargs.get('GA_steps', 1),
+        'bsize': kwargs.get('batch_size', 1),
+        'layer structure': kwargs.get('layer_structure', '1,2,1'),
+        'activation': kwargs.get('activation', 'Linear'),
+        'weight_init': kwargs.get('weight_init', 'Normal'),
+        'dropout_structure': kwargs.get('dropout_structure', '0,0,0'),
+        'steps': kwargs.get('max_steps', 10000),
+        'latent sampling': kwargs.get('latent_sampling_method', 'once'),
+        'template file': kwargs.get('template', 'nothing'),
+        'CosineAnnealing': kwargs.get('CosineAnnealing', False),
+        'beta_repeat epoch': kwargs.get('beta_repeat_epoch', 0),
+        'epoch_mult': kwargs.get('epoch_mult', 1),
+        'warmup_step': kwargs.get('warmup', 5),
+        'min_lr': kwargs.get('min_lr', 6e-7),
+        'decay': kwargs.get('gamma_rate', 1),
+        'adamW': kwargs.get('adamW_opts', False),
+        'adamW_decay': kwargs.get('adamW_decay', 0.01),
+        'adamW_beta1': kwargs.get('adamW_beta_1', 0.9),
+        'adamW_beta2': kwargs.get('adamW_beta_2', 0.99),
+        'adamW_eps': kwargs.get('adamW_eps', 1e-8),
+        'gradient_clip_opt': kwargs.get('gradient_clip', 'None'),
+        'gradient_clip_value': kwargs.get('gradient_clip_value', 1e-1),
+        'gradient_clip_norm': kwargs.get('gradient_clip_norm_type', 2)
+    },
+        {'hparam/loss': kwargs.get('loss', 0.0)}
+    )
+def tensorboard_add(tensorboard_writer: SummaryWriter, loss, global_step, step, learn_rate, epoch_num, base_name=""):
+    prefix = base_name + "/" if base_name else ""
+    tensorboard_add_scaler(tensorboard_writer, prefix + "Loss/train", loss, global_step)
+    tensorboard_add_scaler(tensorboard_writer, prefix + f"Loss/train/epoch-{epoch_num}", loss, step)
+    tensorboard_add_scaler(tensorboard_writer, prefix + "Learn rate/train", learn_rate, global_step)
+    tensorboard_add_scaler(tensorboard_writer, prefix + f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
+
+
+def tensorboard_add_scaler(tensorboard_writer: SummaryWriter, tag, value, step):
+    tensorboard_writer.add_scalar(tag=tag,
+                                  scalar_value=value, global_step=step)
+
+
+def tensorboard_add_image(tensorboard_writer: SummaryWriter, tag, pil_image, step, base_name=""):
+    # Convert a PIL image to a torch tensor
+    prefix = base_name + "/" if base_name else ""
+    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+                                 len(pil_image.getbands()))
+    img_tensor = img_tensor.permute((2, 0, 1))
+
+    tensorboard_writer.add_image(prefix + tag, img_tensor, global_step=step)
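
Note (not part of the patch): the tbutils.py helpers above are thin wrappers around
torch.utils.tensorboard.SummaryWriter. The sketch below reproduces the same logging
layout with raw SummaryWriter calls so it can be checked outside the webui; the log
directory, tag prefix, and all values are invented for illustration.

    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

    # Writer rooted the same way tensorboard_setup() roots it: <log dir>/tensorboard
    writer = SummaryWriter(log_dir="logs/demo/tensorboard", flush_secs=30)

    # Per-step scalars, grouped under a per-hypernetwork prefix as
    # tensorboard_add() does when base_name is passed.
    prefix = "my_hypernetwork/"
    for step, loss in enumerate([0.31, 0.27, 0.24], start=1):
        writer.add_scalar(prefix + "Loss/train", loss, global_step=step)
        writer.add_scalar(prefix + "Learn rate/train", 5e-5, global_step=step)

    # One hparams record per run, as tensorboard_log_hyperparameter() does;
    # values that are not bool/str/float/int must be stringified first.
    writer.add_hparams(
        {"lr": 5e-5, "bsize": 2, "layer structure": str([1, 2, 1])},
        {"hparam/loss": 0.24},
    )

    # A preview image logged as a CHW uint8 tensor, as tensorboard_add_image() does.
    img = torch.as_tensor(np.zeros((64, 64, 3), dtype=np.uint8)).permute(2, 0, 1)
    writer.add_image(prefix + "Validation/preview", img, global_step=3)

    writer.close()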