diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py index 5e06d3d..34a1018 100644 --- a/patches/external_pr/dataset.py +++ b/patches/external_pr/dataset.py @@ -169,6 +169,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.filename = [entry.filename for entry in data] # self.emb_index = [entry.emb_index for entry in data] # print(self.latent_sample.device) diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 57c13dd..78b609e 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -7,13 +7,14 @@ import os import sys import traceback import inspect +from collections import defaultdict, deque import torch import tqdm from modules import shared, sd_models, devices, processing, sd_samplers -from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork +from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork, report_statistics from modules.textual_inversion.learn_schedule import LearnRateScheduler from .textual_inversion import validate_train_inputs, write_loss from ..hypernetwork import Hypernetwork, load_hypernetwork @@ -155,7 +156,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, loss_step = 0 _loss_step = 0 # internal # size = len(ds.indexes) - # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + loss_dict = defaultdict(lambda: deque(maxlen=1024)) # losses = torch.zeros((size,)) # previous_mean_losses = [0] # previous_mean_loss = 0 @@ -196,7 +197,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.cond_stage_model.to(devices.cpu) else: c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) - loss = shared.sd_model(x, c)[0] / gradient_step + loss = shared.sd_model(x, c)[0] + for filenames in batch.filename: + loss_dict[filenames].append(loss.item()) + loss /= gradient_step del x del c @@ -314,8 +318,7 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False pbar.close() hypernetwork.eval() - # report_statistics(loss_dict) - + report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name if shared.opts.save_optimizer_state: