81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
from typing import List, Tuple
|
|
from einops import rearrange
|
|
import numpy as np, os, torch
|
|
from PIL import Image
|
|
from torchvision.utils import make_grid
|
|
import time
|
|
|
|
|
|
def get_output_folder(output_path, batch_folder):
    """Create and return the dated output directory for a run.

    The directory is ``<output_path>/<YYYY-MM>`` (current local year and
    month as formatted by ``time.strftime``), with ``batch_folder`` appended
    as a further sub-directory when one is given. Missing directories are
    created; directories that already exist are reused.

    Args:
        output_path: Root directory under which outputs are stored.
        batch_folder: Optional sub-folder name. An empty string or ``None``
            means "no batch sub-folder".

    Returns:
        The path of the directory, which exists on disk when this returns.
    """
    out_path = os.path.join(output_path, time.strftime('%Y-%m'))

    # Truthiness check instead of `!= ""`: a None batch folder is treated the
    # same as an empty name rather than raising TypeError in os.path.join.
    if batch_folder:
        out_path = os.path.join(out_path, batch_folder)

    os.makedirs(out_path, exist_ok=True)
    return out_path
|
|
|
|
|
|
def _chw_to_pil(chw: torch.Tensor) -> Image.Image:
    """Convert a (c, h, w) tensor into an 8-bit PIL image.

    Values are scaled by 255, so the input is presumably expected to lie in
    [0, 1] (TODO confirm against callers).
    """
    hwc = rearrange(chw.detach().cpu().numpy(), "c h w -> h w c")
    # Clip before the cast: converting floats outside [0, 255] to uint8 gives
    # undefined / wrapped results, which shows up as speckled pixels.
    return Image.fromarray(np.clip(255.0 * hwc, 0.0, 255.0).astype(np.uint8))


def save_samples(
    args, x_samples: torch.Tensor, seed: int, n_rows: int
) -> Tuple[Image.Image, List[Image.Image]]:
    """Convert generated samples to images and optionally save them to disk.

    Args:
        args: Stable deforum diffusion arguments. Reads ``display_samples``,
            ``save_samples``, ``display_grid``, ``save_grid``, ``outdir`` and
            ``timestring``.
        x_samples: Batch of samples to convert; iterated along its first
            dimension, each item being laid out as (c, h, w).
        seed: Seed for the experiment; used in the output file names.
        n_rows: Forwarded to ``make_grid`` as ``nrow``, i.e. the number of
            images placed in each row of the grid.

    Returns:
        A tuple ``(grid_image, generated_images)``. ``grid_image`` is ``None``
        unless ``args.display_grid`` or ``args.save_grid`` is set, and
        ``generated_images`` is empty unless ``args.display_samples`` or
        ``args.save_samples`` is set.
    """
    images = []
    grid_image = None

    # individual samples
    if args.display_samples or args.save_samples:
        for index, x_sample in enumerate(x_samples):
            image = _chw_to_pil(x_sample)
            images.append(image)
            if args.save_samples:
                image.save(
                    os.path.join(
                        args.outdir, f"{args.timestring}_{index:02}_{seed}.png"
                    )
                )

    # grid of all samples
    if args.display_grid or args.save_grid:
        # make_grid takes the (b, c, h, w) batch directly; the previous
        # stack-then-flatten round trip produced the very same tensor.
        grid = make_grid(x_samples, nrow=n_rows, padding=0)
        grid_image = _chw_to_pil(grid)
        if args.save_grid:
            grid_image.save(
                os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png")
            )

    # return grid_image and individual sample images
    return grid_image, images
|
|
|
|
def save_image(image, image_type, filename, args, video_args, root):
    """Write a frame to ``args.outdir`` or queue it in the in-RAM frame cache.

    Args:
        image: The frame to store. With ``image_type == 'cv2'`` it is handed
            to ``cv2.imwrite``; otherwise it must provide ``save(path)``.
        image_type: Tag describing ``image``; ``'cv2'`` and ``'PIL'`` are the
            values understood by ``dump_frames_cache``.
        filename: File name, joined onto ``args.outdir``.
        args: Object providing ``outdir``.
        video_args: Object providing ``store_frames_in_ram``.
        root: Object providing the ``frames_cache`` list.
    """
    path = os.path.join(args.outdir, filename)

    if video_args.store_frames_in_ram:
        # Deferred write: dump_frames_cache flushes these entries later.
        root.frames_cache.append({'path': path, 'image': image, 'image_type': image_type})
    elif image_type == 'cv2':
        # Same writer dump_frames_cache uses for 'cv2' entries, so a frame is
        # saved identically whether or not it went through the cache.
        cv2.imwrite(path, image)
    else:
        image.save(path)
|
|
|
|
import cv2, gc
|
|
|
|
def reset_frames_cache(root):
    """Discard all cached frames on ``root`` and trigger a garbage collection."""
    # Rebind to a brand-new list (not an in-place clear), then collect so the
    # dropped frames can be reclaimed right away.
    root.frames_cache = list()
    gc.collect()
|
|
|
|
def dump_frames_cache(root):
    """Write every frame queued in ``root.frames_cache`` to its stored path.

    Entries tagged ``'cv2'`` are written with ``cv2.imwrite``; entries tagged
    ``'PIL'`` are written through the image's own ``save`` method. Entries
    with any other tag are skipped.
    """
    for entry in root.frames_cache:
        kind = entry['image_type']
        if kind == 'PIL':
            entry['image'].save(entry['path'])
        elif kind == 'cv2':
            cv2.imwrite(entry['path'], entry['image'])
    # TODO: the cache is intentionally left populated here; frame erasing is
    # planned to be added as a separate function later.
|