fix unboundLocalError
parent
95b9506d33
commit
f747420a27
|
|
@ -116,11 +116,11 @@ class PersonalizedBase(Dataset):
|
|||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
weight = torch.ones_like(latent_sample)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
weight = torch.ones_like(latent_sample)
|
||||
if latent_sampling_method == "once" or (
|
||||
latent_sampling_method == "deterministic" and not isinstance(latent_dist,
|
||||
DiagonalGaussianDistribution)):
|
||||
|
|
@ -209,6 +209,7 @@ class GroupedBatchSampler(Sampler):
|
|||
self.probs = [e % batch_size/self.n_rand_batches/batch_size if self.n_rand_batches > 0 else 0 for e in expected]
|
||||
self.batch_size = batch_size
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
|
|
|
|||
|
|
@ -550,7 +550,8 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
if hypernetwork is not None:
|
||||
hypernetwork.eval()
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
|
@ -815,6 +816,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
weights = hypernetwork.weights(True)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
if hypernetwork.optimizer_name == 'DAdaptAdamW':
|
||||
use_dadaptation = True
|
||||
optimizer = None
|
||||
|
|
@ -1070,6 +1072,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
</p>
|
||||
"""
|
||||
except Exception:
|
||||
if pbar is not None:
|
||||
pbar.set_description(traceback.format_exc())
|
||||
shared.state.textinfo = traceback.format_exc()
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
finally:
|
||||
pbar.leave = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue