diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py index 1389617..84af89f 100644 --- a/patches/external_pr/dataset.py +++ b/patches/external_pr/dataset.py @@ -130,8 +130,8 @@ class PersonalizedBase(Dataset): elif latent_sampling_method == "random": if latent_sampling_std != -1: assert latent_sampling_std > 0, f"Cannnot apply negative standard deviation {latent_sampling_std}" - print(f"Applying patch, changing std from {latent_dist.std} to {latent_sampling_std}...") - latent_dist.std = latent_sampling_std + print(f"Applying patch, clipping std from {torch.max(latent_dist.std).item()} to {latent_sampling_std}...") + latent_dist.std.clip_(latent_sampling_std) entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) if not (self.tag_drop_out != 0 or self.shuffle_tags):