diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py index ece14f3..56ea1af 100644 --- a/patches/external_pr/dataset.py +++ b/patches/external_pr/dataset.py @@ -116,21 +116,19 @@ class PersonalizedBase(Dataset): npimage = (npimage / 127.5 - 1.0).astype(np.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) - latent_sample = None - + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + weight = torch.ones_like(latent_sample) with torch.autocast("cuda"): latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) if latent_sampling_method == "once" or ( latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "random": if latent_sampling_std != -1: @@ -144,8 +142,6 @@ class PersonalizedBase(Dataset): if use_weight and 'A' in image.getbands(): alpha_channel = image.getchannel('A') if use_weight and alpha_channel is not None: - if latent_sample is None: - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) channels, *latent_size = latent_sample.shape weight_img = alpha_channel.resize(latent_size) npweight = np.array(weight_img).astype(np.float32) @@ -155,12 +151,8 @@ class PersonalizedBase(Dataset): weight -= weight.min() weight /= weight.mean() elif use_weight: - if latent_sample is None: - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later weight = torch.ones_like(latent_sample) - else: - weight = None entry.weight = weight if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text)