diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py index 56ea1af..bc5bed4 100644 --- a/patches/external_pr/dataset.py +++ b/patches/external_pr/dataset.py @@ -116,11 +116,11 @@ class PersonalizedBase(Dataset): npimage = (npimage / 127.5 - 1.0).astype(np.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - weight = torch.ones_like(latent_sample) + with torch.autocast("cuda"): latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) - + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + weight = torch.ones_like(latent_sample) if latent_sampling_method == "once" or ( latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): @@ -209,6 +209,7 @@ class GroupedBatchSampler(Sampler): self.probs = [e % batch_size/self.n_rand_batches/batch_size if self.n_rand_batches > 0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 8c5e892..291c6cf 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -550,7 +550,8 @@ Last saved image: {html.escape(last_saved_image)}
finally: pbar.leave = False pbar.close() - hypernetwork.eval() + if hypernetwork is not None: + hypernetwork.eval() shared.parallel_processing_allowed = old_parallel_processing_allowed if hasattr(sd_hijack_checkpoint, 'remove'): sd_hijack_checkpoint.remove() @@ -815,6 +816,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory, shared.sd_model.first_stage_model.to(devices.cpu) weights = hypernetwork.weights(True) + optimizer_name = hypernetwork.optimizer_name if hypernetwork.optimizer_name == 'DAdaptAdamW': use_dadaptation = True optimizer = None @@ -1070,6 +1072,9 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: + if pbar is not None: + pbar.set_description(traceback.format_exc()) + shared.state.textinfo = traceback.format_exc() print(traceback.format_exc(), file=sys.stderr) finally: pbar.leave = False