122 lines
5.5 KiB
Python
122 lines
5.5 KiB
Python
import torch
|
|
import os
|
|
import torchvision.transforms.functional as TF
|
|
from torchvision.utils import make_grid
|
|
import numpy as np
|
|
from IPython import display
|
|
|
|
#
|
|
# Callback functions
|
|
#
|
|
class SamplerCallback:
    """Build the per-step callback passed to diffusion samplers.

    ``self.callback`` is the function the sampler should call after each step. On every call it:

    * optionally applies dynamic and/or static thresholding to the current latent (in place);
    * if a mask is set, overwrites the masked positions with the (noised) init latent;
    * optionally decodes the latent and saves and/or displays the result.

    The callback signature depends on ``args.sampler``:

    * ``"plms"`` / ``"ddim"`` (CompVis samplers): ``callback(img, i)``
    * anything else (k-diffusion samplers): ``callback(args_dict)``

    Args:
        args: Settings object. Reads ``sampler``, ``dynamic_threshold``, ``static_threshold``,
            ``n_samples``, ``save_sample_per_step``, ``show_sample_per_step``, ``outdir``,
            ``timestring`` and ``seed``.
        root: Object providing ``model`` (its ``decode_first_stage`` is used for previews)
            and ``device``.
        mask: Optional mask tensor. A position is replaced by the noised init latent at
            step ``i`` when its mask value is non-zero and ``>= mask_schedule[i]``.
        init_latent: Latent that masked positions are reset to (after re-noising).
        sigmas: Noise levels of the run, used to build ``mask_schedule``. Required when
            ``mask`` is given. An empty ``sigmas`` disables masking.
        sampler: CompVis sampler. Required for ``"plms"``/``"ddim"`` with a mask, because its
            ``stochastic_encode`` produces the noised init latent.
        verbose: When true, ``self.verbose_print`` is ``print``; otherwise it is a no-op.

    Raises:
        ValueError: if ``mask`` is given but ``sigmas`` is ``None``.
        AssertionError: if a CompVis sampler is selected with a mask but no ``sampler``.
    """

    def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        self.model = root.model
        self.device = root.device
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose

        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        # One output directory per sample in the batch, used when saving per-step images.
        self.paths_to_image_steps = [
            os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}")
            for index in range(args.n_samples)
        ]
        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0

        # Noise drawn once here and reused at every step when re-noising the init latent.
        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=self.device)

        self.mask_schedule = None
        if sigmas is not None:
            if len(sigmas) > 0:
                # Noise levels normalised to max 1.0 and sorted ascending; step i compares
                # the mask against entry i.
                self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
            else:
                # No mask needed if there are no steps (usually happens because strength == 1.0).
                self.mask = None
        elif mask is not None:
            # Both callbacks index mask_schedule, which can only be built from sigmas.
            raise ValueError("SamplerCallback: a mask requires sigmas to build the mask schedule")

        if self.sampler_name in ("plms", "ddim"):
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            # Callback function formatted for CompVis latent diffusion samplers.
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables.
            self.callback = self.k_callback_

        self.verbose_print = print if verbose else lambda *args, **kwargs: None

    @staticmethod
    def _images_to_grid(images):
        """Map image values from [-1, 1] to [0, 1] (clamped) and tile them, 4 per row."""
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        return make_grid(images, 4).cpu()

    def view_sample_step(self, latents, path_name_modifier=''):
        """Decode ``latents`` and save and/or display them, if either option is enabled.

        Each decoded sample ``i`` is saved to
        ``paths_to_image_steps[i]/<path_name_modifier>_<step_index:05>.png``.
        Nothing is decoded when both ``save_sample_per_step`` and ``show_sample_per_step``
        are off.
        """
        if not (self.save_sample_per_step or self.show_sample_per_step):
            return
        samples = self.model.decode_first_stage(latents)
        if self.save_sample_per_step:
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                grid = self._images_to_grid(sample)
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            print(path_name_modifier)
            self.display_images(samples)

    def display_images(self, images):
        """Show a batch of images with values in [-1, 1] as one grid in the IPython display."""
        display.display(TF.to_pil_image(self._images_to_grid(images)))

    def dynamic_thresholding_(self, img, threshold):
        """Apply dynamic thresholding (Imagen paper, May 2022) to ``img`` in place.

        Takes the ``threshold``-th percentile of ``|img|`` over every axis except the first.
        The largest of those values, but at least 1.0, becomes ``s``. ``img`` is then clamped
        to ``[-s, s]`` and divided by ``s``.
        """
        magnitudes = img.detach().abs().cpu().numpy()
        percentiles = np.percentile(magnitudes, threshold, axis=tuple(range(1, img.ndim)))
        s = float(np.max(np.append(percentiles, 1.0)))
        img.clamp_(-s, s)
        img.div_(s)

    def _apply_thresholds_(self, img):
        """Apply the configured dynamic and/or static thresholding to ``img`` in place."""
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            img.clamp_(-self.static_threshold, self.static_threshold)

    def _apply_mask_(self, img, init_noise, step):
        """Overwrite, in place, the masked positions of ``img`` with ``init_noise``.

        A position is masked at ``step`` when its mask value is non-zero and at least
        ``self.mask_schedule[step]``.
        """
        is_masked = torch.logical_and(self.mask >= self.mask_schedule[step], self.mask != 0)
        # Select rather than blend with 0/1 weights: x * 0 is NaN when x is inf/NaN, which
        # would leak the discarded branch into the result.
        img.copy_(torch.where(is_masked, init_noise.to(img.dtype), img))

    def k_callback_(self, args_dict):
        """Callback for k-diffusion samplers.

        Called as ``callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i],
        'denoised': denoised})``. Thresholds and masks ``args_dict['x']`` in place. The
        ``denoised`` prediction is then previewed under the name ``"x0_pred"``.
        """
        step = args_dict['i']
        x = args_dict['x']
        self.step_index = step
        self._apply_thresholds_(x)
        if self.mask is not None:
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            self._apply_mask_(x, init_noise, step)
        self.view_sample_step(args_dict['denoised'], "x0_pred")

    def img_callback_(self, img, i):
        """Callback for CompVis samplers, called with the current latent ``img`` at step ``i``.

        Thresholds and masks ``img`` in place, then previews it under the name ``"x"``.
        """
        self.step_index = i
        self._apply_thresholds_(img)
        if self.mask is not None:
            # Index counted from the other end of sigmas; presumably CompVis steps run
            # opposite to the timestep order expected by stochastic_encode -- TODO confirm.
            i_inv = len(self.sigmas) - i - 1
            timesteps = torch.tensor([i_inv] * self.batch_size).to(self.device)
            init_noise = self.sampler.stochastic_encode(self.init_latent, timesteps, noise=self.noise)
            self._apply_mask_(img, init_noise, i)
        self.view_sample_step(img, "x")