Add loss statistics again

Re-enabled by request; the loss statistics code was working anyway.
beta-apply-bigger-batch-sizes
aria1th 2022-11-29 19:45:21 +09:00
parent e3d7c692af
commit 6e6e02aea6
2 changed files with 9 additions and 5 deletions

View File

@ -169,6 +169,7 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
self.filename = [entry.filename for entry in data]
# self.emb_index = [entry.emb_index for entry in data]
# print(self.latent_sample.device)

View File

@ -7,13 +7,14 @@ import os
import sys
import traceback
import inspect
from collections import defaultdict, deque
import torch
import tqdm
from modules import shared, sd_models, devices, processing, sd_samplers
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork, report_statistics
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from .textual_inversion import validate_train_inputs, write_loss
from ..hypernetwork import Hypernetwork, load_hypernetwork
@ -155,7 +156,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
loss_step = 0
_loss_step = 0 # internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
loss_dict = defaultdict(lambda: deque(maxlen=1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
@ -196,7 +197,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
loss = shared.sd_model(x, c)[0]
for filenames in batch.filename:
loss_dict[filenames].append(loss.item())
loss /= gradient_step
del x
del c
@ -314,8 +318,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
hypernetwork.eval()
# report_statistics(loss_dict)
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state: