272 lines
13 KiB
Python
272 lines
13 KiB
Python
import torch
|
|
from PIL import Image
|
|
import requests
|
|
import numpy as np
|
|
import torchvision.transforms.functional as TF
|
|
from pytorch_lightning import seed_everything
|
|
import os
|
|
from ldm.models.diffusion.plms import PLMSSampler
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from k_diffusion.external import CompVisDenoiser
|
|
from torch import autocast
|
|
from contextlib import nullcontext
|
|
from einops import rearrange, repeat
|
|
|
|
from .prompt import get_uc_and_c
|
|
from .k_samplers import sampler_fn
|
|
from scipy.ndimage import gaussian_filter
|
|
|
|
from .callback import SamplerCallback
|
|
|
|
def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
    """Perturb *sample* with scaled Gaussian noise.

    A fresh standard-normal tensor matching ``sample``'s shape and device
    is scaled by ``noise_amt`` and added element-wise.
    """
    perturbation = torch.randn(sample.shape, device=sample.device)
    return sample + noise_amt * perturbation
|
|
|
|
def load_img(path, shape, use_alpha_as_mask=False):
    """Load an image from disk or a URL as a normalized tensor.

    Args:
        path: Local file path or an http(s) URL.
        shape: (width, height) the image is resized to.
        use_alpha_as_mask: When True the image is opened as RGBA and its
            alpha channel is split off and returned as a greyscale mask.

    Returns:
        Tuple of (image tensor of shape (1, 3, H, W) scaled to [-1, 1],
        greyscale PIL mask image or None).
    """
    is_url = path.startswith('http://') or path.startswith('https://')
    if is_url:
        image = Image.open(requests.get(path, stream=True).raw)
    else:
        image = Image.open(path)

    image = image.convert('RGBA' if use_alpha_as_mask else 'RGB')
    image = image.resize(shape, resample=Image.LANCZOS)

    mask_image = None
    if use_alpha_as_mask:
        # Pull the alpha channel out as a separate greyscale mask image.
        alpha = image.split()[-1]
        mask_image = alpha.convert('L')
        image = image.convert('RGB')

    # HWC uint8 -> NCHW float16, rescaled from [0, 255] to [-1, 1].
    arr = np.array(image).astype(np.float16) / 255.0
    arr = arr[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(arr)
    tensor = tensor * 2. - 1.

    return tensor, mask_image
|
|
|
|
def load_mask_latent(mask_input, shape):
    """Resolve a mask source into a greyscale PIL image sized for *shape*.

    Args:
        mask_input: Path/URL string, or an already-loaded PIL Image.
        shape: list-like of length 4 (usually latent_image.shape); only
            the trailing (height, width) entries are used for resizing.

    Returns:
        Greyscale ("L" mode) PIL image of size (shape[-1], shape[-2]).

    Raises:
        Exception: If mask_input is neither a string nor a PIL Image.
    """
    if isinstance(mask_input, Image.Image):
        mask_image = mask_input
    elif isinstance(mask_input, str):
        # The string is either a URL or a local file name.
        if mask_input.startswith(('http://', 'https://')):
            mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')
        else:
            mask_image = Image.open(mask_input).convert('RGBA')
    else:
        raise Exception("mask_input must be a PIL image or a file name")

    target_size = (shape[-1], shape[-2])
    resized = mask_image.resize(target_size, resample=Image.LANCZOS)
    return resized.convert("L")
|
|
|
|
def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0, invert_mask=False):
    """Load a mask and prepare it as a latent-shaped torch tensor.

    Args:
        mask_input: Path/URL to the mask image or a PIL Image object.
        mask_shape: list-like of length 4, shape of the image to match,
            usually latent_image.shape.
        mask_brightness_adjust: non-negative float; 0 is black, 1 is no
            adjustment, >1 is brighter.
        mask_contrast_adjust: non-negative float; 0 is a flat grey image,
            1 is no adjustment, >1 is more contrast.
        invert_mask: When True, mask values are flipped (1 - mask).

    Returns:
        Tensor of shape (1, 4, H, W) with values clamped to [0, 1].
    """
    mask = load_mask_latent(mask_input, mask_shape)

    # Brightness/contrast adjustments on the greyscale PIL image.
    if mask_brightness_adjust != 1:
        mask = TF.adjust_brightness(mask, mask_brightness_adjust)
    if mask_contrast_adjust != 1:
        mask = TF.adjust_contrast(mask, mask_contrast_adjust)

    # PIL image -> float array in [0, 1], tiled across the 4 latent
    # channels, with a leading batch dimension.
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = np.expand_dims(mask, axis=0)
    mask = torch.from_numpy(mask)

    if invert_mask:
        mask = ((mask - 0.5) * -1) + 0.5

    # Clamp explicitly with torch. The previous np.clip(mask, 0, 1) on a
    # torch.Tensor only returned a Tensor because NumPy delegates to the
    # Tensor.clip method — version-dependent and fragile.
    mask = torch.clamp(mask, 0, 1)
    return mask
|
|
|
|
def generate(args, root, frame = 0, return_latent=False, return_sample=False, return_c=False):
    """Run one txt2img/img2img diffusion pass for a single prompt batch.

    Args:
        args: Namespace of generation settings (prompt, sampler, steps,
            seed, init/mask options, guidance scale, output size, ...).
        root: Shared state holder; must provide ``.model`` (the latent
            diffusion model) and ``.device``.
        frame: Frame index forwarded to weighted-prompt conditioning.
        return_latent: Also append the raw latent samples to the results.
        return_sample: Also append the decoded (pre-clamp) samples.
        return_c: Also append the conditioning tensor.

    Returns:
        List containing the requested extras (latent, sample,
        conditioning — in that order, when enabled) followed by one PIL
        image per generated sample.
    """
    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    sampler = PLMSSampler(root.model) if args.sampler == 'plms' else DDIMSampler(root.model)
    model_wrap = CompVisDenoiser(root.model)
    batch_size = args.n_samples
    prompt = args.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    precision_scope = autocast if args.precision == "autocast" else nullcontext

    # Resolve the init latent from, in priority order: a precomputed
    # latent, a decoded sample, or an init image on disk/URL.
    init_latent = None
    mask_image = None
    init_image = None
    if args.init_latent is not None:
        init_latent = args.init_latent
    elif args.init_sample is not None:
        with precision_scope("cuda"):
            init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(args.init_sample))
    elif args.use_init and args.init_image is not None and args.init_image != '':
        init_image, mask_image = load_img(args.init_image,
                                          shape=(args.W, args.H),
                                          use_alpha_as_mask=args.use_alpha_as_mask)
        init_image = init_image.to(root.device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        with precision_scope("cuda"):
            init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(init_image))  # move to latent space

    if not args.use_init and args.strength > 0 and args.strength_0_no_init:
        print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.")
        print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n")
        args.strength = 0

    # Mask functions
    if args.use_mask:
        assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel"
        assert args.use_init, "use_mask==True: use_init is required for a mask"
        assert init_latent is not None, "use_mask==True: An latent init image is required for a mask"

        # Brightness is passed before contrast to match prepare_mask's
        # parameter order (the previous call had the two swapped).
        mask = prepare_mask(args.mask_file if mask_image is None else mask_image,
                            init_latent.shape,
                            args.mask_brightness_adjust,
                            args.mask_contrast_adjust,
                            args.invert_mask)

        if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:
            raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")

        mask = mask.to(root.device)
        mask = repeat(mask, '1 ... -> b ...', b=batch_size)
    else:
        mask = None

    assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True"

    # Number of denoising steps skipped for img2img (strength controls
    # how far along the schedule encoding starts).
    t_enc = int((1.0-args.strength) * args.steps)

    # Noise schedule for the k-diffusion samplers (used for masking)
    k_sigmas = model_wrap.get_sigmas(args.steps)
    k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]

    if args.sampler in ['plms','ddim']:
        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)

    callback = SamplerCallback(args=args,
                               root=root,
                               mask=mask,
                               init_latent=init_latent,
                               sigmas=k_sigmas,
                               sampler=sampler,
                               verbose=False).callback

    results = []
    with torch.no_grad():
        with precision_scope("cuda"):
            with root.model.ema_scope():
                for prompts in data:
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    if args.prompt_weighting:
                        uc, c = get_uc_and_c(prompts, root.model, args, frame)
                    else:
                        uc = root.model.get_learned_conditioning(batch_size * [""])
                        c = root.model.get_learned_conditioning(prompts)

                    if args.scale == 1.0:
                        # Classifier-free guidance is a no-op at scale 1.
                        uc = None
                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=root.device,
                            cb=callback)
                    else:
                        # args.sampler == 'plms' or args.sampler == 'ddim':
                        if init_latent is not None and args.strength > 0:
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(root.device))
                        else:
                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=root.device)
                        if args.sampler == 'ddim':
                            samples = sampler.decode(z_enc,
                                                     c,
                                                     t_enc,
                                                     unconditional_guidance_scale=args.scale,
                                                     unconditional_conditioning=uc,
                                                     img_callback=callback)
                        elif args.sampler == 'plms': # no "decode" function in plms, so use "sample"
                            shape = [args.C, args.H // args.f, args.W // args.f]
                            samples, _ = sampler.sample(S=args.steps,
                                                        conditioning=c,
                                                        batch_size=args.n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=args.scale,
                                                        unconditional_conditioning=uc,
                                                        eta=args.ddim_eta,
                                                        x_T=z_enc,
                                                        img_callback=callback)
                        else:
                            raise Exception(f"Sampler {args.sampler} not recognised.")

                    if return_latent:
                        results.append(samples.clone())

                    x_samples = root.model.decode_first_stage(samples)

                    if args.use_mask and args.overlay_mask:
                        # Overlay the masked image after the image is generated
                        if args.init_sample is not None:
                            img_original = args.init_sample
                        elif init_image is not None:
                            img_original = init_image
                        else:
                            raise Exception("Cannot overlay the masked image without an init image to overlay")

                        # Full-resolution mask. Brightness precedes contrast
                        # (matching prepare_mask's signature), and the
                        # `args.inver_mask` attribute typo is fixed — it
                        # raised AttributeError whenever this branch ran.
                        mask_fullres = prepare_mask(args.mask_file if mask_image is None else mask_image,
                                                    img_original.shape,
                                                    args.mask_brightness_adjust,
                                                    args.mask_contrast_adjust,
                                                    args.invert_mask)
                        mask_fullres = mask_fullres[:,:3,:,:]
                        mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=batch_size)

                        # Binarize at the max value, then feather the edges.
                        mask_fullres[mask_fullres < mask_fullres.max()] = 0
                        mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur)
                        mask_fullres = torch.Tensor(mask_fullres).to(root.device)

                        x_samples = img_original * mask_fullres + x_samples * ((mask_fullres * -1.0) + 1)

                    if return_sample:
                        results.append(x_samples.clone())

                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results