fix weight usage
parent
c3ada0eeb4
commit
95b9506d33
|
|
@ -116,21 +116,19 @@ class PersonalizedBase(Dataset):
|
|||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = None
|
||||
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
weight = torch.ones_like(latent_sample)
|
||||
with torch.autocast("cuda"):
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
if latent_sampling_method == "once" or (
|
||||
latent_sampling_method == "deterministic" and not isinstance(latent_dist,
|
||||
DiagonalGaussianDistribution)):
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
latent_sampling_method = "once"
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "deterministic":
|
||||
# Works only for DiagonalGaussianDistribution
|
||||
latent_dist.std = 0
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "random":
|
||||
if latent_sampling_std != -1:
|
||||
|
|
@ -144,8 +142,6 @@ class PersonalizedBase(Dataset):
|
|||
if use_weight and 'A' in image.getbands():
|
||||
alpha_channel = image.getchannel('A')
|
||||
if use_weight and alpha_channel is not None:
|
||||
if latent_sample is None:
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
channels, *latent_size = latent_sample.shape
|
||||
weight_img = alpha_channel.resize(latent_size)
|
||||
npweight = np.array(weight_img).astype(np.float32)
|
||||
|
|
@ -155,12 +151,8 @@ class PersonalizedBase(Dataset):
|
|||
weight -= weight.min()
|
||||
weight /= weight.mean()
|
||||
elif use_weight:
|
||||
if latent_sample is None:
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||
weight = torch.ones_like(latent_sample)
|
||||
else:
|
||||
weight = None
|
||||
entry.weight = weight
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
|
|
|
|||
Loading…
Reference in New Issue