{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "c442uQJ_gUgy"
},
"source": [
"# **Deforum Stable Diffusion v0.4**\n",
"[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). To use this notebook, first download the ckpt file and place it on your Google Drive. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).\n",
"\n",
"Notebook by [deforum](https://discord.gg/upmXXsrwZc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4knibRpAQ06"
},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2g-f7cQmf2Nt",
"cellView": "form"
},
"source": [
"#@markdown **NVIDIA GPU**\n",
"import subprocess\n",
"sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"print(sub_p_res)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "TxIOPT0G5Lx1"
},
"source": [
"#@markdown **Model and Output Paths**\n",
"# ask for the link\n",
"print(\"Local Path Variables:\\n\")\n",
"\n",
"models_path = \"/content/models\" #@param {type:\"string\"}\n",
"output_path = \"/content/output\" #@param {type:\"string\"}\n",
"\n",
"#@markdown **Google Drive Path Variables (Optional)**\n",
"mount_google_drive = True #@param {type:\"boolean\"}\n",
"force_remount = False\n",
"\n",
"if mount_google_drive:\n",
"    from google.colab import drive # type: ignore\n",
"    try:\n",
"        drive_path = \"/content/drive\"\n",
"        drive.mount(drive_path,force_remount=force_remount)\n",
"        models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n",
"        output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n",
"        models_path = models_path_gdrive\n",
"        output_path = output_path_gdrive\n",
"    except:\n",
"        print(\"...error mounting drive or with drive path variables\")\n",
"        print(\"...reverting to default path variables\")\n",
"\n",
"import os\n",
"os.makedirs(models_path, exist_ok=True)\n",
"os.makedirs(output_path, exist_ok=True)\n",
"\n",
"print(f\"models_path: {models_path}\")\n",
"print(f\"output_path: {output_path}\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"id": "VRNl2mfepEIe",
"cellView": "form"
},
"source": [
"#@markdown **Setup Environment**\n",
"\n",
"setup_environment = True #@param {type:\"boolean\"}\n",
"print_subprocess = False #@param {type:\"boolean\"}\n",
"\n",
"if setup_environment:\n",
"    import subprocess, time\n",
"    print(\"Setting up environment...\")\n",
"    start_time = time.time()\n",
"    all_process = [\n",
"        ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
"        ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],\n",
"        ['git', 'clone', '-b', 'dev', 'https://github.com/deforum/stable-diffusion'],\n",
"        ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n",
"        ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
"        ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],\n",
"        ['git', 'clone', 'https://github.com/shariqfarooq123/AdaBins.git'],\n",
"        ['git', 'clone', 'https://github.com/isl-org/MiDaS.git'],\n",
"        ['git', 'clone', 'https://github.com/MSFTserver/pytorch3d-lite.git'],\n",
"    ]\n",
"    for process in all_process:\n",
"        running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"        if print_subprocess:\n",
"            print(running)\n",
"\n",
"    print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"    with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n",
"        f.write('')\n",
"\n",
"    end_time = time.time()\n",
"    print(f\"Environment set up in {end_time-start_time:.0f} seconds\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"id": "81qmVZbrm4uu",
"cellView": "form"
},
"source": [
"#@markdown **Python Definitions**\n",
"import json\n",
"from IPython import display\n",
"\n",
"import gc, math, os, pathlib, subprocess, sys, time\n",
"import cv2\n",
"import numpy as np\n",
"import pandas as pd\n",
"import random\n",
"import requests\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from contextlib import contextmanager, nullcontext\n",
"from einops import rearrange, repeat\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from pytorch_lightning import seed_everything\n",
"from skimage.exposure import match_histograms\n",
"from torchvision.utils import make_grid\n",
"from tqdm import tqdm, trange\n",
"from types import SimpleNamespace\n",
"from torch import autocast\n",
"from scipy.ndimage import gaussian_filter\n",
"\n",
"sys.path.extend([\n",
"    'src/taming-transformers',\n",
"    'src/clip',\n",
"    'stable-diffusion/',\n",
"    'k-diffusion',\n",
"    'pytorch3d-lite',\n",
"    'AdaBins',\n",
"    'MiDaS',\n",
"])\n",
"\n",
"import py3d_tools as p3d\n",
"\n",
"from helpers import DepthModel, sampler_fn\n",
"from k_diffusion.external import CompVisDenoiser\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"\n",
"def sanitize(prompt):\n",
"    whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')\n",
"    tmp = ''.join(filter(whitelist.__contains__, prompt))\n",
"    return tmp.replace(' ', '_')\n",
"\n",
"from functools import reduce\n",
"def construct_RotationMatrixHomogenous(rotation_angles):\n",
"    assert(type(rotation_angles)==list and len(rotation_angles)==3)\n",
"    RH = np.eye(4,4)\n",
"    cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3])\n",
"    return RH\n",
"\n",
"# https://en.wikipedia.org/wiki/Rotation_matrix\n",
"def getRotationMatrixManual(rotation_angles):\n",
"\n",
"    rotation_angles = [np.deg2rad(x) for x in rotation_angles]\n",
"\n",
"    phi = rotation_angles[0] # around x\n",
"    gamma = rotation_angles[1] # around y\n",
"    theta = rotation_angles[2] # around z\n",
"\n",
"    # X rotation\n",
"    Rphi = np.eye(4,4)\n",
"    sp = np.sin(phi)\n",
"    cp = np.cos(phi)\n",
"    Rphi[1,1] = cp\n",
"    Rphi[2,2] = Rphi[1,1]\n",
"    Rphi[1,2] = -sp\n",
"    Rphi[2,1] = sp\n",
"\n",
"    # Y rotation\n",
"    Rgamma = np.eye(4,4)\n",
"    sg = np.sin(gamma)\n",
"    cg = np.cos(gamma)\n",
"    Rgamma[0,0] = cg\n",
"    Rgamma[2,2] = Rgamma[0,0]\n",
"    Rgamma[0,2] = sg\n",
"    Rgamma[2,0] = -sg\n",
"\n",
"    # Z rotation (in-image-plane)\n",
"    Rtheta = np.eye(4,4)\n",
"    st = np.sin(theta)\n",
"    ct = np.cos(theta)\n",
"    Rtheta[0,0] = ct\n",
"    Rtheta[1,1] = Rtheta[0,0]\n",
"    Rtheta[0,1] = -st\n",
"    Rtheta[1,0] = st\n",
"\n",
"    R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta])\n",
"\n",
"    return R\n",
"\n",
"\n",
"def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength):\n",
"\n",
"    ptsIn2D = ptsIn[0,:]\n",
"    ptsOut2D = ptsOut[0,:]\n",
"    ptsOut2Dlist = []\n",
"    ptsIn2Dlist = []\n",
"\n",
"    for i in range(0,4):\n",
"        ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]])\n",
"        ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]])\n",
"\n",
"    pin = np.array(ptsIn2Dlist) + [W/2.,H/2.]\n",
"    pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength)\n",
"    pin = pin.astype(np.float32)\n",
"    pout = pout.astype(np.float32)\n",
"\n",
"    return pin, pout\n",
"\n",
"def warpMatrix(W, H, theta, phi, gamma, scale, fV):\n",
"\n",
"    # M is to be estimated\n",
"    M = np.eye(4, 4)\n",
"\n",
"    fVhalf = np.deg2rad(fV/2.)\n",
"    d = np.sqrt(W*W+H*H)\n",
"    sideLength = scale*d/np.cos(fVhalf)\n",
"    h = d/(2.0*np.sin(fVhalf))\n",
"    n = h-(d/2.0)\n",
"    f = h+(d/2.0)\n",
"\n",
"    # Translation along Z-axis by -h\n",
"    T = np.eye(4,4)\n",
"    T[2,3] = -h\n",
"\n",
"    # Rotation matrices around x,y,z\n",
"    R = getRotationMatrixManual([phi, gamma, theta])\n",
"\n",
"\n",
"    # Projection Matrix\n",
"    P = np.eye(4,4)\n",
"    P[0,0] = 1.0/np.tan(fVhalf)\n",
"    P[1,1] = P[0,0]\n",
"    P[2,2] = -(f+n)/(f-n)\n",
"    P[2,3] = -(2.0*f*n)/(f-n)\n",
"    P[3,2] = -1.0\n",
"\n",
"    # pythonic matrix multiplication\n",
"    F = reduce(lambda x,y : np.matmul(x,y), [P, T, R])\n",
"\n",
"    # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way.\n",
"    # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3);\n",
"    ptsIn = np.array([[\n",
"        [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.]\n",
"    ]])\n",
"    ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype))\n",
"    ptsOut = cv2.perspectiveTransform(ptsIn, F)\n",
"\n",
"    ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength)\n",
"\n",
"    # check float32 otherwise OpenCV throws an error\n",
"    assert(ptsInPt2f.dtype == np.float32)\n",
"    assert(ptsOutPt2f.dtype == np.float32)\n",
"    M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f)\n",
"\n",
"    return M33, sideLength\n",
"\n",
"def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):\n",
"    angle = keys.angle_series[frame_idx]\n",
"    zoom = keys.zoom_series[frame_idx]\n",
"    translation_x = keys.translation_x_series[frame_idx]\n",
"    translation_y = keys.translation_y_series[frame_idx]\n",
"\n",
"    center = (args.W // 2, args.H // 2)\n",
"    trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])\n",
"    rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)\n",
"    trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
"    rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
"    if anim_args.flip_2d_perspective:\n",
"        perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx]\n",
"        perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx]\n",
"        perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx]\n",
"        perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx]\n",
"        M,sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv)\n",
"        post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]])\n",
"        post_trans_mat = np.vstack([post_trans_mat, [0,0,1]])\n",
"        bM = np.matmul(M, post_trans_mat)\n",
"        xform = np.matmul(np.matmul(bM, rot_mat), trans_mat) # chain all three; np.matmul(a, b, c) would treat c as an out= buffer\n",
"    else:\n",
"        xform = np.matmul(rot_mat, trans_mat)\n",
"\n",
"    return cv2.warpPerspective(\n",
"        prev_img_cv2,\n",
"        xform,\n",
"        (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),\n",
"        borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE\n",
"    )\n",
"\n",
"def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx):\n",
"    TRANSLATION_SCALE = 1.0/200.0 # matches Disco\n",
"    translate_xyz = [\n",
"        -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE,\n",
"        keys.translation_y_series[frame_idx] * TRANSLATION_SCALE,\n",
"        -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE\n",
"    ]\n",
"    rotate_xyz = [\n",
"        math.radians(keys.rotation_3d_x_series[frame_idx]),\n",
"        math.radians(keys.rotation_3d_y_series[frame_idx]),\n",
"        math.radians(keys.rotation_3d_z_series[frame_idx])\n",
"    ]\n",
"    rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n",
"    result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)\n",
"    torch.cuda.empty_cache()\n",
"    return result\n",
"\n",
"def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:\n",
"    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
"\n",
"def get_output_folder(output_path, batch_folder):\n",
"    out_path = os.path.join(output_path,time.strftime('%Y-%m'))\n",
"    if batch_folder != \"\":\n",
"        out_path = os.path.join(out_path, batch_folder)\n",
"    os.makedirs(out_path, exist_ok=True)\n",
"    return out_path\n",
"\n",
"def load_img(path, shape, use_alpha_as_mask=False):\n",
"    # use_alpha_as_mask: Read the alpha channel of the image as the mask image\n",
"    if path.startswith('http://') or path.startswith('https://'):\n",
"        image = Image.open(requests.get(path, stream=True).raw)\n",
"    else:\n",
"        image = Image.open(path)\n",
"\n",
"    if use_alpha_as_mask:\n",
"        image = image.convert('RGBA')\n",
"    else:\n",
"        image = image.convert('RGB')\n",
"\n",
"    image = image.resize(shape, resample=Image.LANCZOS)\n",
"\n",
"    mask_image = None\n",
"    if use_alpha_as_mask:\n",
"        # Split alpha channel into a mask_image\n",
"        red, green, blue, alpha = image.split()\n",
"        mask_image = alpha.convert('L')\n",
"        image = image.convert('RGB')\n",
"\n",
"    image = np.array(image).astype(np.float16) / 255.0\n",
"    image = image[None].transpose(0, 3, 1, 2)\n",
"    image = torch.from_numpy(image)\n",
"    image = 2.*image - 1.\n",
"\n",
"    return image, mask_image\n",
"\n",
"def load_mask_latent(mask_input, shape):\n",
"    # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object\n",
"    # shape (list-like len(4)): shape of the image to match, usually latent_image.shape\n",
"\n",
"    if isinstance(mask_input, str): # mask input is probably a file name\n",
"        if mask_input.startswith('http://') or mask_input.startswith('https://'):\n",
"            mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')\n",
"        else:\n",
"            mask_image = Image.open(mask_input).convert('RGBA')\n",
"    elif isinstance(mask_input, Image.Image):\n",
"        mask_image = mask_input\n",
"    else:\n",
"        raise Exception(\"mask_input must be a PIL image or a file name\")\n",
"\n",
"    mask_w_h = (shape[-1], shape[-2])\n",
"    mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS)\n",
"    mask = mask.convert(\"L\")\n",
"    return mask\n",
"\n",
"def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):\n",
"    # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object\n",
"    # shape (list-like len(4)): shape of the image to match, usually latent_image.shape\n",
"    # mask_brightness_adjust (non-negative float): amount to adjust brightness of the image,\n",
"    #     0 is black, 1 is no adjustment, >1 is brighter\n",
"    # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image,\n",
"    #     0 is a flat grey image, 1 is no adjustment, >1 is more contrast\n",
"\n",
"    mask = load_mask_latent(mask_input, mask_shape)\n",
"\n",
"    # Mask brightness/contrast adjustments\n",
"    if mask_brightness_adjust != 1:\n",
"        mask = TF.adjust_brightness(mask, mask_brightness_adjust)\n",
"    if mask_contrast_adjust != 1:\n",
"        mask = TF.adjust_contrast(mask, mask_contrast_adjust)\n",
"\n",
"    # Mask image to array\n",
"    mask = np.array(mask).astype(np.float32) / 255.0\n",
"    mask = np.tile(mask,(4,1,1))\n",
"    mask = np.expand_dims(mask,axis=0)\n",
"    mask = torch.from_numpy(mask)\n",
"\n",
"    if args.invert_mask:\n",
"        mask = ( (mask - 0.5) * -1) + 0.5\n",
"\n",
"    mask = np.clip(mask,0,1)\n",
"    return mask\n",
"\n",
"def maintain_colors(prev_img, color_match_sample, mode):\n",
"    if mode == 'Match Frame 0 RGB':\n",
"        return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
"    elif mode == 'Match Frame 0 HSV':\n",
"        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
"        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
"        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
"        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
"    else: # Match Frame 0 LAB\n",
"        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n",
"        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n",
"        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n",
"        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n",
"\n",
"\n",
"#\n",
"# Callback functions\n",
"#\n",
"class SamplerCallback(object):\n",
"    # Creates the callback function to be passed into the samplers for each step\n",
"    def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None,\n",
"                 verbose=False):\n",
"        self.sampler_name = args.sampler\n",
"        self.dynamic_threshold = args.dynamic_threshold\n",
"        self.static_threshold = args.static_threshold\n",
"        self.mask = mask\n",
"        self.init_latent = init_latent\n",
"        self.sigmas = sigmas\n",
"        self.sampler = sampler\n",
"        self.verbose = verbose\n",
"\n",
"        self.batch_size = args.n_samples\n",
"        self.save_sample_per_step = args.save_sample_per_step\n",
"        self.show_sample_per_step = args.show_sample_per_step\n",
"        self.paths_to_image_steps = [os.path.join( args.outdir, f\"{args.timestring}_{index:02}_{args.seed}\") for index in range(args.n_samples) ]\n",
"\n",
"        if self.save_sample_per_step:\n",
"            for path in self.paths_to_image_steps:\n",
"                os.makedirs(path, exist_ok=True)\n",
"\n",
"        self.step_index = 0\n",
"\n",
"        self.noise = None\n",
"        if init_latent is not None:\n",
"            self.noise = torch.randn_like(init_latent, device=device)\n",
"\n",
"        self.mask_schedule = None\n",
"        if sigmas is not None and len(sigmas) > 0:\n",
"            self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas))\n",
"        elif sigmas is not None and len(sigmas) == 0:\n",
"            self.mask = None # no mask needed if no steps (usually happens because strength==1.0)\n",
"\n",
"        if self.sampler_name in [\"plms\",\"ddim\"]:\n",
"            if mask is not None:\n",
"                assert sampler is not None, \"Callback function for stable-diffusion samplers requires sampler variable\"\n",
"\n",
"        if self.sampler_name in [\"plms\",\"ddim\"]:\n",
"            # Callback function formatted for compvis latent diffusion samplers\n",
"            self.callback = self.img_callback_\n",
"        else:\n",
"            # Default callback function uses k-diffusion sampler variables\n",
"            self.callback = self.k_callback_\n",
"\n",
"        self.verbose_print = print if verbose else lambda *args, **kwargs: None\n",
"\n",
"    def view_sample_step(self, latents, path_name_modifier=''):\n",
"        samples = model.decode_first_stage(latents)\n",
"        if self.save_sample_per_step:\n",
"            fname = f'{path_name_modifier}_{self.step_index:05}.png'\n",
"            for i, sample in enumerate(samples):\n",
"                sample = sample.double().cpu().add(1).div(2).clamp(0, 1)\n",
"                sample = torch.tensor(np.array(sample))\n",
"                grid = make_grid(sample, 4).cpu()\n",
"                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))\n",
"        if self.show_sample_per_step:\n",
"            print(path_name_modifier)\n",
"            self.display_images(samples)\n",
"        return\n",
"\n",
"    def display_images(self, images):\n",
"        images = images.double().cpu().add(1).div(2).clamp(0, 1)\n",
"        images = torch.tensor(np.array(images))\n",
"        grid = make_grid(images, 4).cpu()\n",
"        display.display(TF.to_pil_image(grid))\n",
"        return\n",
"\n",
"    # The callback function is applied to the image at each step\n",
"    def dynamic_thresholding_(self, img, threshold):\n",
"        # Dynamic thresholding from Imagen paper (May 2022)\n",
"        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n",
"        s = np.max(np.append(s,1.0))\n",
"        torch.clamp_(img, -1*s, s)\n",
"        img.div_(s)\n",
"\n",
"    # Callback for samplers in the k-diffusion repo, called thus:\n",
"    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n",
"    def k_callback_(self, args_dict):\n",
"        self.step_index = args_dict['i']\n",
"        if self.dynamic_threshold is not None:\n",
"            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)\n",
"        if self.static_threshold is not None:\n",
"            torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold)\n",
"        if self.mask is not None:\n",
"            init_noise = self.init_latent + self.noise * args_dict['sigma']\n",
"            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 )\n",
"            new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1)\n",
"            args_dict['x'].copy_(new_img)\n",
"\n",
"        self.view_sample_step(args_dict['denoised'], \"x0_pred\")\n",
"\n",
"    # Callback for Compvis samplers\n",
"    # Function that is called on the image (img) and step (i) at each step\n",
"    def img_callback_(self, img, i):\n",
"        self.step_index = i\n",
"        # Thresholding functions\n",
"        if self.dynamic_threshold is not None:\n",
"            self.dynamic_thresholding_(img, self.dynamic_threshold)\n",
"        if self.static_threshold is not None:\n",
"            torch.clamp_(img, -1*self.static_threshold, self.static_threshold)\n",
"        if self.mask is not None:\n",
"            i_inv = len(self.sigmas) - i - 1\n",
"            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(device), noise=self.noise)\n",
"            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 )\n",
"            new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1)\n",
"            img.copy_(new_img)\n",
"\n",
"        self.view_sample_step(img, \"x\")\n",
"\n",
"def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
"    sample = ((sample.astype(float) / 255.0) * 2) - 1\n",
"    sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)\n",
"    sample = torch.from_numpy(sample)\n",
"    return sample\n",
"\n",
"def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:\n",
"    sample_f32 = rearrange(sample.squeeze().cpu().numpy(), \"c h w -> h w c\").astype(np.float32)\n",
"    sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)\n",
"    sample_int8 = (sample_f32 * 255)\n",
"    return sample_int8.astype(type)\n",
"\n",
"def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):\n",
"    # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion\n",
"    w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]\n",
"\n",
"    aspect_ratio = float(w)/float(h)\n",
"    near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov\n",
"    persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device)\n",
"    persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)\n",
"\n",
"    # range of [-1,1] is important to torch grid_sample's padding handling\n",
"    y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device))\n",
"    z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)\n",
"    xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)\n",
"\n",
"    xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]\n",
"    xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]\n",
"\n",
"    offset_xy = xyz_new_cam_xy - xyz_old_cam_xy\n",
"    # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation.\n",
"    identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0)\n",
"    # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.\n",
"    coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)\n",
"    offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)\n",
"\n",
"    image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)\n",
"    new_image = torch.nn.functional.grid_sample(\n",
"        image_tensor.add(1/512 - 0.0001).unsqueeze(0),\n",
"        offset_coords_2d,\n",
"        mode=anim_args.sampling_mode,\n",
"        padding_mode=anim_args.padding_mode,\n",
"        align_corners=False\n",
"    )\n",
"\n",
"    # convert back to cv2 style numpy array\n",
"    result = rearrange(\n",
"        new_image.squeeze().clamp(0,255),\n",
"        'c h w -> h w c'\n",
"    ).cpu().numpy().astype(prev_img_cv2.dtype)\n",
"    return result\n",
"\n",
"def generate(args, return_latent=False, return_sample=False, return_c=False):\n",
"    seed_everything(args.seed)\n",
"    os.makedirs(args.outdir, exist_ok=True)\n",
"\n",
"    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)\n",
"    model_wrap = CompVisDenoiser(model)\n",
"    batch_size = args.n_samples\n",
"    prompt = args.prompt\n",
"    assert prompt is not None\n",
"    data = [batch_size * [prompt]]\n",
"    precision_scope = autocast if args.precision == \"autocast\" else nullcontext\n",
"\n",
"    init_latent = None\n",
"    mask_image = None\n",
"    init_image = None\n",
"    if args.init_latent is not None:\n",
"        init_latent = args.init_latent\n",
"    elif args.init_sample is not None:\n",
"        with precision_scope(\"cuda\"):\n",
"            init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))\n",
"    elif args.use_init and args.init_image is not None and args.init_image != '':\n",
"        init_image, mask_image = load_img(args.init_image,\n",
"                                          shape=(args.W, args.H),\n",
"                                          use_alpha_as_mask=args.use_alpha_as_mask)\n",
"        init_image = init_image.to(device)\n",
"        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n",
"        with precision_scope(\"cuda\"):\n",
"            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space\n",
"\n",
"    if not args.use_init and args.strength > 0 and args.strength_0_no_init:\n",
"        print(\"\\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.\")\n",
"        print(\"If you want to force strength > 0 with no init, please set strength_0_no_init to False.\\n\")\n",
"        args.strength = 0\n",
"\n",
"    # Mask functions\n",
"    if args.use_mask:\n",
"        assert args.mask_file is not None or mask_image is not None, \"use_mask==True: A mask image is required. Please provide a mask_file or use an init image with an alpha channel\"\n",
"        assert args.use_init, \"use_mask==True: use_init is required for a mask\"\n",
"        assert init_latent is not None, \"use_mask==True: A latent init image is required for a mask\"\n",
"\n",
"\n",
"        mask = prepare_mask(args.mask_file if mask_image is None else mask_image,\n",
"                            init_latent.shape,\n",
"                            args.mask_contrast_adjust,\n",
"                            args.mask_brightness_adjust)\n",
"\n",
"        if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:\n",
"            raise Warning(\"use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.\")\n",
"\n",
"        mask = mask.to(device)\n",
"        mask = repeat(mask, '1 ... -> b ...', b=batch_size)\n",
"    else:\n",
"        mask = None\n",
"\n",
"    assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), \"Need an init image when use_mask == True and overlay_mask == True\"\n",
"\n",
"    t_enc = int((1.0-args.strength) * args.steps)\n",
"\n",
"    # Noise schedule for the k-diffusion samplers (used for masking)\n",
"    k_sigmas = model_wrap.get_sigmas(args.steps)\n",
"    k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]\n",
"\n",
"    if args.sampler in ['plms','ddim']:\n",
"        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)\n",
"\n",
"    callback = SamplerCallback(args=args,\n",
"                               mask=mask,\n",
"                               init_latent=init_latent,\n",
"                               sigmas=k_sigmas,\n",
"                               sampler=sampler,\n",
"                               verbose=False).callback\n",
"\n",
"    results = []\n",
"    with torch.no_grad():\n",
"        with precision_scope(\"cuda\"):\n",
"            with model.ema_scope():\n",
"                for prompts in data:\n",
"                    uc = None\n",
"                    if args.scale != 1.0:\n",
"                        uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
"                    if isinstance(prompts, tuple):\n",
"                        prompts = list(prompts)\n",
"                    c = model.get_learned_conditioning(prompts)\n",
"\n",
"                    if args.init_c is not None:\n",
"                        c = args.init_c\n",
"\n",
"                    if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
"                        samples = sampler_fn(\n",
"                            c=c,\n",
"                            uc=uc,\n",
"                            args=args,\n",
"                            model_wrap=model_wrap,\n",
"                            init_latent=init_latent,\n",
"                            t_enc=t_enc,\n",
"                            device=device,\n",
"                            cb=callback)\n",
"                    else:\n",
"                        # args.sampler == 'plms' or args.sampler == 'ddim'\n",
"                        if init_latent is not None and args.strength > 0:\n",
"                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
"                        else:\n",
"                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
"                        if args.sampler == 'ddim':\n",
"                            samples = sampler.decode(z_enc,\n",
"                                                     c,\n",
"                                                     t_enc,\n",
"                                                     unconditional_guidance_scale=args.scale,\n",
"                                                     unconditional_conditioning=uc,\n",
"                                                     img_callback=callback)\n",
"                        elif args.sampler == 'plms': # no \"decode\" function in plms, so use \"sample\"\n",
"                            shape = [args.C, args.H // args.f, args.W // args.f]\n",
"                            samples, _ = sampler.sample(S=args.steps,\n",
"                                                        conditioning=c,\n",
"                                                        batch_size=args.n_samples,\n",
"                                                        shape=shape,\n",
"                                                        verbose=False,\n",
"                                                        unconditional_guidance_scale=args.scale,\n",
"                                                        unconditional_conditioning=uc,\n",
"                                                        eta=args.ddim_eta,\n",
"                                                        x_T=z_enc,\n",
"                                                        img_callback=callback)\n",
"                        else:\n",
"                            raise Exception(f\"Sampler {args.sampler} not recognised.\")\n",
"\n",
"\n",
"                    if return_latent:\n",
"                        results.append(samples.clone())\n",
"\n",
"                    x_samples = model.decode_first_stage(samples)\n",
"\n",
"                    if args.use_mask and args.overlay_mask:\n",
"                        # Overlay the masked image after the image is generated\n",
"                        if args.init_sample is not None:\n",
"                            img_original = args.init_sample\n",
"                        elif init_image is not None:\n",
"                            img_original = init_image\n",
"                        else:\n",
"                            raise Exception(\"Cannot overlay the masked image without an init image to overlay\")\n",
"\n",
"                        mask_fullres = prepare_mask(args.mask_file if mask_image is None else mask_image,\n",
"                                                    img_original.shape,\n",
"                                                    args.mask_contrast_adjust,\n",
"                                                    args.mask_brightness_adjust)\n",
"                        mask_fullres = mask_fullres[:,:3,:,:]\n",
"                        mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=batch_size)\n",
"\n",
"                        mask_fullres[mask_fullres < mask_fullres.max()] = 0\n",
"                        mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur)\n",
"                        mask_fullres = torch.Tensor(mask_fullres).to(device)\n",
"\n",
"                        x_samples = img_original * mask_fullres + x_samples * ((mask_fullres * -1.0) + 1)\n",
"\n",
"\n",
"                    if return_sample:\n",
"                        results.append(x_samples.clone())\n",
"\n",
"                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
"                    if return_c:\n",
"                        results.append(c.clone())\n",
"\n",
"                    for x_sample in x_samples:\n",
"                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
"                        image = Image.fromarray(x_sample.astype(np.uint8))\n",
"                        results.append(image)\n",
"    return results"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "CIUJ7lWI4v53"
},
"source": [
"#@markdown **Select and Load Model**\n",
"\n",
"model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
"model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n",
"custom_config_path = \"\" #@param {type:\"string\"}\n",
"custom_checkpoint_path = \"\" #@param {type:\"string\"}\n",
"\n",
"load_on_run_all = True #@param {type: 'boolean'}\n",
"half_precision = True # check\n",
"check_sha256 = True #@param {type:\"boolean\"}\n",
"\n",
"model_map = {\n",
"    \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n",
"    \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n",
"    \"sd-v1-3-full-ema.ckpt\": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},\n",
"    \"sd-v1-3.ckpt\": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},\n",
"    \"sd-v1-2-full-ema.ckpt\": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},\n",
"    \"sd-v1-2.ckpt\": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},\n",
"    \"sd-v1-1-full-ema.ckpt\": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},\n",
"    \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n",
"}\n",
"\n",
"# config path\n",
"ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n",
"if os.path.exists(ckpt_config_path):\n",
"    print(f\"{ckpt_config_path} exists\")\n",
"else:\n",
"    ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n",
"print(f\"Using config: {ckpt_config_path}\")\n",
"\n",
"# checkpoint path or download\n",
"ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n",
"ckpt_valid = True\n",
"if os.path.exists(ckpt_path):\n",
"    print(f\"{ckpt_path} exists\")\n",
"else:\n",
"    print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n",
"    ckpt_valid = False\n",
"\n",
"if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n",
"    import hashlib\n",
"    print(\"\\n...checking sha256\")\n",
"    with open(ckpt_path, \"rb\") as f:\n",
"        bytes = f.read()\n",
"        hash = hashlib.sha256(bytes).hexdigest()\n",
"        del bytes\n",
"    if model_map[model_checkpoint][\"sha256\"] == hash:\n",
"        print(\"hash is correct\\n\")\n",
"    else:\n",
"        print(\"hash is not correct\\n\")\n",
"        ckpt_valid = False\n",
"\n",
"if ckpt_valid:\n",
"    print(f\"Using ckpt: {ckpt_path}\")\n",
"\n",
"def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n",
"    map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n",
"    print(f\"Loading model from {ckpt}\")\n",
"    pl_sd = torch.load(ckpt, map_location=map_location)\n",
"    if \"global_step\" in pl_sd:\n",
"        print(f\"Global Step: {pl_sd['global_step']}\")\n",
"    sd = pl_sd[\"state_dict\"]\n",
"    model = instantiate_from_config(config.model)\n",
"    m, u = model.load_state_dict(sd, strict=False)\n",
"    if len(m) > 0 and verbose:\n",
"        print(\"missing keys:\")\n",
"        print(m)\n",
"    if len(u) > 0 and verbose:\n",
"        print(\"unexpected keys:\")\n",
"        print(u)\n",
"\n",
"    if half_precision:\n",
"        model = model.half().to(device)\n",
"    else:\n",
"        model = model.to(device)\n",
"    model.eval()\n",
"    return model\n",
"\n",
"if load_on_run_all and ckpt_valid:\n",
"    local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n",
"    model = load_model_from_config(local_config, f\"{ckpt_path}\", half_precision=half_precision)\n",
"    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"    model = model.to(device)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "ov3r4RD1tzsT"
},
"source": [
"# Settings"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0j7rgxvLvfay"
},
"source": [
"### Animation Settings"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "8HJN2TE3vh-J"
},
"source": [
"\n",
"def DeforumAnimArgs():\n",
"\n",
"    #@markdown ####**Animation:**\n",
"    animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}\n",
"    max_frames = 1000 #@param {type:\"number\"}\n",
"    border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}\n",
"\n",
"    #@markdown ####**Motion Parameters:**\n",
"    angle = \"0:(0)\"#@param {type:\"string\"}\n",
"    zoom = \"0:(1.04)\"#@param {type:\"string\"}\n",
"    translation_x = \"0:(10*sin(2*3.14*t/10))\"#@param {type:\"string\"}\n",
"    translation_y = \"0:(0)\"#@param {type:\"string\"}\n",
"    translation_z = \"0:(10)\"#@param {type:\"string\"}\n",
"    rotation_3d_x = \"0:(0)\"#@param {type:\"string\"}\n",
"    rotation_3d_y = \"0:(0)\"#@param {type:\"string\"}\n",
"    rotation_3d_z = \"0:(0)\"#@param {type:\"string\"}\n",
"    flip_2d_perspective = False #@param {type:\"boolean\"}\n",
"    perspective_flip_theta = \"0:(0)\"#@param {type:\"string\"}\n",
"    perspective_flip_phi = \"0:(t%15)\"#@param {type:\"string\"}\n",
"    perspective_flip_gamma = \"0:(0)\"#@param {type:\"string\"}\n",
"    perspective_flip_fv = \"0:(53)\"#@param {type:\"string\"}\n",
"    noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n",
"    strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n",
"    contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
"\n",
"    #@markdown ####**Coherence:**\n",
"    color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
"    diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}\n",
"\n",
"    #@markdown ####**3D Depth Warping:**\n",
"    use_depth_warping = True #@param {type:\"boolean\"}\n",
"    midas_weight = 0.3#@param {type:\"number\"}\n",
"    near_plane = 200\n",
"    far_plane = 10000\n",
"    fov = 40#@param {type:\"number\"}\n",
"    padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}\n",
"    sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}\n",
"    save_depth_maps = False #@param {type:\"boolean\"}\n",
"\n",
"    #@markdown ####**Video Input:**\n",
"    video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
"    extract_nth_frame = 1#@param {type:\"number\"}\n",
"    overwrite_extracted_frames = True #@param {type:\"boolean\"}\n",
"    use_mask_video = False #@param {type:\"boolean\"}\n",
"    video_mask_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
"\n",
"    #@markdown ####**Interpolation:**\n",
"    interpolate_key_frames = False #@param {type:\"boolean\"}\n",
"    interpolate_x_frames = 4 #@param {type:\"number\"}\n",
"\n",
"    #@markdown ####**Resume Animation:**\n",
"    resume_from_timestring = False #@param {type:\"boolean\"}\n",
"    resume_timestring = \"20220829210106\" #@param {type:\"string\"}\n",
"\n",
"    return locals()\n",
"\n",
"class DeformAnimKeys():\n",
"    def __init__(self, anim_args):\n",
"        self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames)\n",
"        self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames)\n",
"        self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames)\n",
"        self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames)\n",
"        self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames)\n",
"        self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames)\n",
"        self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames)\n",
"        self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames)\n",
"        self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames)\n",
"        self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames)\n",
"        self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames)\n",
"        self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames)\n",
"        self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames)\n",
"        self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames)\n",
"        self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames)\n",
"\n",
"\n",
"def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'):\n",
"    import numexpr\n",
"    import re\n",
"    float_pattern = r'^(?=.)([+-]?([0-9]*)(\\.([0-9]+))?)$'\n",
"    key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n",
"\n",
"    for i in range(0, max_frames):\n",
"        if i in key_frames:\n",
"            value = key_frames[i]\n",
"            value_is_number = re.match(float_pattern, value)\n",
"            # if it's only a number, leave the rest for the default interpolation\n",
"            if value_is_number:\n",
"                t = i\n",
"                key_frame_series[i] = value\n",
"            if not value_is_number:\n",
"                t = i\n",
"                key_frame_series[i] = numexpr.evaluate(value)\n",
"    key_frame_series = key_frame_series.astype(float)\n",
"\n",
"    if interp_method == 'Cubic' and len(key_frames.items()) <= 3:\n",
"        interp_method = 'Quadratic'\n",
"    if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
"        interp_method = 'Linear'\n",
"\n",
"    key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
"    key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
"    key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both')\n",
"    if integer:\n",
"        return key_frame_series.astype(int)\n",
"    return key_frame_series\n",
"\n",
"def parse_key_frames(string, prompt_parser=None):\n",
"    import re\n",
"    # because math expressions (e.g. sin(t)) may themselves contain brackets,\n",
"    # the value is extracted as everything enclosed by the outer brackets,\n",
"    # followed by a comma or the end of the line\n",
"    pattern = r'((?P<frame>[0-9]+):[\\s]*\\((?P<param>[\\S\\s]*?)\\)([,][\\s]?|[\\s]?$))'\n",
"    frames = dict()\n",
"    for match_object in re.finditer(pattern, string):\n",
"        frame = int(match_object.groupdict()['frame'])\n",
"        param = match_object.groupdict()['param']\n",
"        if prompt_parser:\n",
"            frames[frame] = prompt_parser(param)\n",
"        else:\n",
"            frames[frame] = param\n",
"    if frames == {} and len(string) != 0:\n",
"        raise RuntimeError('Key Frame string not correctly formatted')\n",
"    return frames"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "63UOJvU3xdPS"
},
"source": [
"### Prompts\n",
"`animation_mode: None` batches over the list of *prompts*; `animation_mode: 2D` uses the *animation_prompts* key-frame sequence."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2ujwkGZTcGev"
},
"source": [
"\n",
"prompts = [\n",
"    \"a beautiful forest by Asher Brown Durand, trending on Artstation\", # the first prompt\n",
"    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", # the second prompt\n",
"    #\"this third prompt is skipped because it is commented out\",\n",
"]\n",
"\n",
"animation_prompts = {\n",
"    0: \"a beautiful apple, trending on Artstation\",\n",
"    20: \"a beautiful banana, trending on Artstation\",\n",
"    30: \"a beautiful coconut, trending on Artstation\",\n",
"    40: \"a beautiful durian, trending on Artstation\",\n",
"}"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "s8RAo2zI-vQm"
},
"source": [
"# Run"
]
},
{
"cell_type": "code",
"metadata": {
"id": "qH74gBWDd2oq",
"cellView": "form"
},
"source": [
"override_settings_with_file = False #@param {type:\"boolean\"}\n",
"custom_settings_file = \"/content/drive/MyDrive/Settings.txt\"#@param {type:\"string\"}\n",
"\n",
"def DeforumArgs():\n",
"    #@markdown **Image Settings**\n",
"    W = 512 #@param\n",
"    H = 512 #@param\n",
"    W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n",
|
|
"\n",
|
|
" #@markdown **Sampling Settings**\n",
|
|
" seed = -1 #@param\n",
|
|
" sampler = 'klms' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n",
|
|
" steps = 50 #@param\n",
|
|
" scale = 7 #@param\n",
|
|
" ddim_eta = 0.0 #@param\n",
|
|
" dynamic_threshold = None\n",
|
|
" static_threshold = None \n",
|
|
"\n",
|
|
" #@markdown **Save & Display Settings**\n",
|
|
" save_samples = True #@param {type:\"boolean\"}\n",
|
|
" save_settings = True #@param {type:\"boolean\"}\n",
|
|
" display_samples = True #@param {type:\"boolean\"}\n",
|
|
" save_sample_per_step = False #@param {type:\"boolean\"}\n",
|
|
" show_sample_per_step = False #@param {type:\"boolean\"}\n",
|
|
"\n",
|
|
" #@markdown **Batch Settings**\n",
|
|
" n_batch = 1 #@param\n",
|
|
" batch_name = \"StableFun\" #@param {type:\"string\"}\n",
|
|
" filename_format = \"{timestring}_{index}_{prompt}.png\" #@param [\"{timestring}_{index}_{seed}.png\",\"{timestring}_{index}_{prompt}.png\"]\n",
|
|
" seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
|
|
" make_grid = False #@param {type:\"boolean\"}\n",
|
|
" grid_rows = 2 #@param \n",
|
|
" outdir = get_output_folder(output_path, batch_name)\n",
|
|
"\n",
|
|
" #@markdown **Init Settings**\n",
|
|
" use_init = False #@param {type:\"boolean\"}\n",
|
|
" strength = 0.0 #@param {type:\"number\"}\n",
|
|
" strength_0_no_init = True # Set the strength to 0 automatically when no init image is used\n",
|
|
" init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n",
|
|
" # Whiter areas of the mask are areas that change more\n",
|
|
" use_mask = False #@param {type:\"boolean\"}\n",
|
|
" use_alpha_as_mask = False # use the alpha channel of the init image as the mask\n",
|
|
" mask_file = \"https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg\" #@param {type:\"string\"}\n",
|
|
" invert_mask = False #@param {type:\"boolean\"}\n",
|
|
" # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.\n",
|
|
" mask_brightness_adjust = 1.0 #@param {type:\"number\"}\n",
|
|
" mask_contrast_adjust = 1.0 #@param {type:\"number\"}\n",
|
|
" # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding\n",
|
|
" overlay_mask = True #@param {type:\"boolean\"}\n",
|
|
" # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)\n",
|
|
" mask_overlay_blur = 5 #@param {type:\"number\"}\n",
|
|
"\n",
|
|
" n_samples = 1 # doesnt do anything\n",
|
|
" precision = 'autocast' \n",
|
|
" C = 4\n",
|
|
" f = 8\n",
|
|
"\n",
|
|
" prompt = \"\"\n",
|
|
" timestring = \"\"\n",
|
|
" init_latent = None\n",
|
|
" init_sample = None\n",
|
|
" init_c = None\n",
|
|
"\n",
|
|
" return locals()\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def next_seed(args):\n",
|
|
" if args.seed_behavior == 'iter':\n",
|
|
" args.seed += 1\n",
|
|
" elif args.seed_behavior == 'fixed':\n",
|
|
" pass # always keep seed the same\n",
|
|
" else:\n",
|
|
" args.seed = random.randint(0, 2**32 - 1)\n",
|
|
" return args.seed\n",
|
|
"\n",
|
|
"def render_image_batch(args):\n",
|
|
" args.prompts = {k: f\"{v:05d}\" for v, k in enumerate(prompts)}\n",
|
|
" \n",
|
|
" # create output folder for the batch\n",
|
|
" os.makedirs(args.outdir, exist_ok=True)\n",
|
|
" if args.save_settings or args.save_samples:\n",
|
|
" print(f\"Saving to {os.path.join(args.outdir, args.timestring)}_*\")\n",
|
|
"\n",
|
|
" # save settings for the batch\n",
|
|
" if args.save_settings:\n",
|
|
" filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
|
|
" with open(filename, \"w+\", encoding=\"utf-8\") as f:\n",
|
|
" json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n",
|
|
"\n",
|
|
" index = 0\n",
|
|
" \n",
|
|
" # function for init image batching\n",
|
|
" init_array = []\n",
|
|
" if args.use_init:\n",
|
|
" if args.init_image == \"\":\n",
|
|
" raise FileNotFoundError(\"No path was given for init_image\")\n",
|
|
" if args.init_image.startswith('http://') or args.init_image.startswith('https://'):\n",
|
|
" init_array.append(args.init_image)\n",
|
|
" elif not os.path.isfile(args.init_image):\n",
|
|
" if args.init_image[-1] != \"/\": # avoids path error by adding / to end if not there\n",
|
|
" args.init_image += \"/\" \n",
|
|
" for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array\n",
|
|
" if image.split(\".\")[-1] in (\"png\", \"jpg\", \"jpeg\"):\n",
|
|
" init_array.append(args.init_image + image)\n",
|
|
" else:\n",
|
|
" init_array.append(args.init_image)\n",
|
|
" else:\n",
|
|
" init_array = [\"\"]\n",
|
|
"\n",
|
|
" # when doing large batches don't flood browser with images\n",
|
|
" clear_between_batches = args.n_batch >= 32\n",
|
|
"\n",
|
|
" for iprompt, prompt in enumerate(prompts): \n",
|
|
" args.prompt = prompt\n",
|
|
" print(f\"Prompt {iprompt+1} of {len(prompts)}\")\n",
|
|
" print(f\"{args.prompt}\")\n",
|
|
"\n",
|
|
" all_images = []\n",
|
|
"\n",
|
|
" for batch_index in range(args.n_batch):\n",
|
|
" if clear_between_batches and batch_index % 32 == 0: \n",
|
|
" display.clear_output(wait=True) \n",
|
|
" print(f\"Batch {batch_index+1} of {args.n_batch}\")\n",
|
|
" \n",
|
|
" for image in init_array: # iterates the init images\n",
|
|
" args.init_image = image\n",
|
|
" results = generate(args)\n",
|
|
" for image in results:\n",
|
|
" if args.make_grid:\n",
|
|
" all_images.append(T.functional.pil_to_tensor(image))\n",
|
|
" if args.save_samples:\n",
|
|
" if args.filename_format == \"{timestring}_{index}_{prompt}.png\":\n",
|
|
" filename = f\"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png\"\n",
|
|
" else:\n",
|
|
" filename = f\"{args.timestring}_{index:05}_{args.seed}.png\"\n",
|
|
" image.save(os.path.join(args.outdir, filename))\n",
|
|
" if args.display_samples:\n",
|
|
" display.display(image)\n",
|
|
" index += 1\n",
|
|
" args.seed = next_seed(args)\n",
|
|
"\n",
|
|
" #print(len(all_images))\n",
|
|
" if args.make_grid:\n",
|
|
" grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n",
|
|
" grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
|
|
" filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n",
|
|
" grid_image = Image.fromarray(grid.astype(np.uint8))\n",
|
|
" grid_image.save(os.path.join(args.outdir, filename))\n",
|
|
" display.clear_output(wait=True) \n",
|
|
" display.display(grid_image)\n",
|
|
"\n",
|
|
"\n",
|
|
"def render_animation(args, anim_args):\n",
|
|
" # animations use key framed prompts\n",
|
|
" args.prompts = animation_prompts\n",
|
|
"\n",
|
|
" # expand key frame strings to values\n",
|
|
" keys = DeformAnimKeys(anim_args)\n",
|
|
"\n",
|
|
" # resume animation\n",
|
|
" start_frame = 0\n",
|
|
" if anim_args.resume_from_timestring:\n",
|
|
" for tmp in os.listdir(args.outdir):\n",
|
|
" if tmp.split(\"_\")[0] == anim_args.resume_timestring:\n",
|
|
" start_frame += 1\n",
|
|
" start_frame = start_frame - 1\n",
|
|
"\n",
|
|
" # create output folder for the batch\n",
|
|
" os.makedirs(args.outdir, exist_ok=True)\n",
|
|
" print(f\"Saving animation frames to {args.outdir}\")\n",
|
|
"\n",
|
|
" # save settings for the batch\n",
|
|
" settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
|
|
" with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
|
|
" s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
|
|
" json.dump(s, f, ensure_ascii=False, indent=4)\n",
|
|
" \n",
|
|
" # resume from timestring\n",
|
|
" if anim_args.resume_from_timestring:\n",
|
|
" args.timestring = anim_args.resume_timestring\n",
|
|
"\n",
|
|
" # expand prompts out to per-frame\n",
|
|
" prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n",
|
|
" for i, prompt in animation_prompts.items():\n",
|
|
" prompt_series[i] = prompt\n",
|
|
" prompt_series = prompt_series.ffill().bfill()\n",
|
|
"\n",
|
|
" # check for video inits\n",
|
|
" using_vid_init = anim_args.animation_mode == 'Video Input'\n",
|
|
"\n",
|
|
" # load depth model for 3D\n",
|
|
" predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps\n",
|
|
" if predict_depths:\n",
|
|
" depth_model = DepthModel(device)\n",
|
|
" depth_model.load_midas(models_path)\n",
|
|
" if anim_args.midas_weight < 1.0:\n",
|
|
" depth_model.load_adabins()\n",
|
|
" else:\n",
|
|
" depth_model = None\n",
|
|
" anim_args.save_depth_maps = False\n",
|
|
"\n",
|
|
" # state for interpolating between diffusion steps\n",
"    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)\n",
"    turbo_prev_image, turbo_prev_frame_idx = None, 0\n",
"    turbo_next_image, turbo_next_frame_idx = None, 0\n",
"\n",
"    # resume animation\n",
"    prev_sample = None\n",
"    color_match_sample = None\n",
"    if anim_args.resume_from_timestring:\n",
"        last_frame = start_frame-1\n",
"        if turbo_steps > 1:\n",
"            last_frame -= last_frame%turbo_steps\n",
"        path = os.path.join(args.outdir,f\"{args.timestring}_{last_frame:05}.png\")\n",
"        img = cv2.imread(path)\n",
"        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"        prev_sample = sample_from_cv2(img)\n",
"        if anim_args.color_coherence != 'None':\n",
"            color_match_sample = img\n",
"        if turbo_steps > 1:\n",
"            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame\n",
"            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx\n",
"            start_frame = last_frame+turbo_steps\n",
"\n",
"    args.n_samples = 1\n",
"    frame_idx = start_frame\n",
"    while frame_idx < anim_args.max_frames:\n",
"        print(f\"Rendering animation frame {frame_idx} of {anim_args.max_frames}\")\n",
"        noise = keys.noise_schedule_series[frame_idx]\n",
"        strength = keys.strength_schedule_series[frame_idx]\n",
"        contrast = keys.contrast_schedule_series[frame_idx]\n",
"        depth = None\n",
"\n",
"        # emit in-between frames\n",
"        if turbo_steps > 1:\n",
"            tween_frame_start_idx = max(0, frame_idx-turbo_steps)\n",
"            for tween_frame_idx in range(tween_frame_start_idx, frame_idx):\n",
"                tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)\n",
"                print(f\"  creating in between frame {tween_frame_idx} tween:{tween:0.2f}\")\n",
"\n",
"                advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx\n",
"                advance_next = tween_frame_idx > turbo_next_frame_idx\n",
"\n",
"                if depth_model is not None:\n",
"                    assert turbo_next_image is not None\n",
"                    depth = depth_model.predict(turbo_next_image, anim_args)\n",
"\n",
"                if anim_args.animation_mode == '2D':\n",
"                    if advance_prev:\n",
"                        turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx)\n",
"                    if advance_next:\n",
"                        turbo_next_image = anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx)\n",
"                else: # '3D'\n",
"                    if advance_prev:\n",
"                        turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx)\n",
"                    if advance_next:\n",
"                        turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx)\n",
"                turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx\n",
"\n",
"                if turbo_prev_image is not None and tween < 1.0:\n",
"                    img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween\n",
"                else:\n",
"                    img = turbo_next_image\n",
"\n",
"                filename = f\"{args.timestring}_{tween_frame_idx:05}.png\"\n",
"                cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))\n",
"                if anim_args.save_depth_maps:\n",
"                    depth_model.save(os.path.join(args.outdir, f\"{args.timestring}_depth_{tween_frame_idx:05}.png\"), depth)\n",
"            if turbo_next_image is not None:\n",
"                prev_sample = sample_from_cv2(turbo_next_image)\n",
"\n",
"        # apply transforms to previous frame\n",
"        if prev_sample is not None:\n",
"            if anim_args.animation_mode == '2D':\n",
"                prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx)\n",
"            else: # '3D'\n",
"                prev_img_cv2 = sample_to_cv2(prev_sample)\n",
"                depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None\n",
"                prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx)\n",
"\n",
"            # apply color matching\n",
"            if anim_args.color_coherence != 'None':\n",
"                if color_match_sample is None:\n",
"                    color_match_sample = prev_img.copy()\n",
"                else:\n",
"                    prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n",
"\n",
"            # apply scaling\n",
"            contrast_sample = prev_img * contrast\n",
"            # apply frame noising\n",
"            noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)\n",
"\n",
"            # use transformed previous frame as init for current\n",
"            args.use_init = True\n",
"            if half_precision:\n",
"                args.init_sample = noised_sample.half().to(device)\n",
"            else:\n",
"                args.init_sample = noised_sample.to(device)\n",
"            args.strength = max(0.0, min(1.0, strength))\n",
"\n",
"        # grab prompt for current frame\n",
"        args.prompt = prompt_series[frame_idx]\n",
"        print(f\"{args.prompt} {args.seed}\")\n",
"        if not using_vid_init:\n",
"            print(f\"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}\")\n",
"            print(f\"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}\")\n",
"            print(f\"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}\")\n",
"\n",
"        # grab init image for current frame\n",
"        if using_vid_init:\n",
"            init_frame = os.path.join(args.outdir, 'inputframes', f\"{frame_idx+1:05}.jpg\")\n",
"            print(f\"Using video init frame {init_frame}\")\n",
"            args.init_image = init_frame\n",
"            if anim_args.use_mask_video:\n",
"                mask_frame = os.path.join(args.outdir, 'maskframes', f\"{frame_idx+1:05}.jpg\")\n",
"                args.mask_file = mask_frame\n",
"\n",
"        # sample the diffusion model\n",
"        sample, image = generate(args, return_latent=False, return_sample=True)\n",
"        if not using_vid_init:\n",
"            prev_sample = sample\n",
"\n",
"        if turbo_steps > 1:\n",
"            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx\n",
"            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx\n",
"            frame_idx += turbo_steps\n",
"        else:\n",
"            filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"            image.save(os.path.join(args.outdir, filename))\n",
"            if anim_args.save_depth_maps:\n",
"                if depth is None:\n",
"                    depth = depth_model.predict(sample_to_cv2(sample), anim_args)\n",
"                depth_model.save(os.path.join(args.outdir, f\"{args.timestring}_depth_{frame_idx:05}.png\"), depth)\n",
"            frame_idx += 1\n",
"\n",
"        display.clear_output(wait=True)\n",
"        display.display(image)\n",
"\n",
"        args.seed = next_seed(args)\n",
"\n",
"def vid2frames(video_path, frames_path, n=1, overwrite=True):\n",
"    if not os.path.exists(frames_path) or overwrite:\n",
"        try:\n",
"            for f in pathlib.Path(frames_path).glob('*.jpg'):\n",
"                f.unlink()\n",
"        except OSError:\n",
"            pass\n",
"        assert os.path.exists(video_path), f\"Video input {video_path} does not exist\"\n",
"\n",
"        vidcap = cv2.VideoCapture(video_path)\n",
"        success, image = vidcap.read()\n",
"        count = 0\n",
"        t = 1\n",
"        while success:\n",
"            if count % n == 0:\n",
"                cv2.imwrite(frames_path + os.path.sep + f\"{t:05}.jpg\", image) # save frame as JPEG file\n",
"                t += 1\n",
"            success, image = vidcap.read()\n",
"            count += 1\n",
"        print(f\"Converted {count} frames\")\n",
"    else:\n",
"        print(\"Frames already unpacked\")\n",
"\n",
"def render_input_video(args, anim_args):\n",
"    # create a folder for the video input frames to live in\n",
"    video_in_frame_path = os.path.join(args.outdir, 'inputframes')\n",
"    os.makedirs(video_in_frame_path, exist_ok=True)\n",
"\n",
"    # save the video frames from input video\n",
"    print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame} frames) to {video_in_frame_path}...\")\n",
"    vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)\n",
"\n",
"    # determine max frames from length of input frames\n",
"    anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])\n",
"    args.use_init = True\n",
"    print(f\"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}\")\n",
"\n",
"    if anim_args.use_mask_video:\n",
"        # create a folder for the mask video input frames to live in\n",
"        mask_in_frame_path = os.path.join(args.outdir, 'maskframes')\n",
"        os.makedirs(mask_in_frame_path, exist_ok=True)\n",
"\n",
"        # save the video frames from mask video\n",
"        print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame} frames) to {mask_in_frame_path}...\")\n",
"        vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)\n",
"        args.use_mask = True\n",
"        args.overlay_mask = True\n",
"\n",
"    render_animation(args, anim_args)\n",
"\n",
"def render_interpolation(args, anim_args):\n",
"    # animations use key framed prompts\n",
"    args.prompts = animation_prompts\n",
"\n",
"    # create output folder for the batch\n",
"    os.makedirs(args.outdir, exist_ok=True)\n",
"    print(f\"Saving animation frames to {args.outdir}\")\n",
"\n",
"    # save settings for the batch\n",
"    settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
"    with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
"        s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
"        json.dump(s, f, ensure_ascii=False, indent=4)\n",
"\n",
"    # Interpolation Settings\n",
"    args.n_samples = 1\n",
"    args.seed_behavior = 'fixed' # force a fixed seed for now, because only one seed is available\n",
"    prompts_c_s = [] # cache all the text embeddings\n",
"\n",
"    print(\"Preparing for interpolation of the following...\")\n",
"\n",
"    for i, prompt in animation_prompts.items():\n",
"        args.prompt = prompt\n",
"\n",
"        # sample the diffusion model\n",
"        results = generate(args, return_c=True)\n",
"        c, image = results[0], results[1]\n",
"        prompts_c_s.append(c)\n",
"\n",
"        # display.clear_output(wait=True)\n",
"        display.display(image)\n",
"\n",
"        args.seed = next_seed(args)\n",
"\n",
"    display.clear_output(wait=True)\n",
"    print(\"Interpolation start...\")\n",
"\n",
"    frame_idx = 0\n",
"\n",
"    if anim_args.interpolate_key_frames:\n",
"        for i in range(len(prompts_c_s)-1):\n",
"            dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n",
"            if dist_frames <= 0:\n",
"                print(\"key frames duplicated or reversed. interpolation skipped.\")\n",
"                return\n",
"            else:\n",
"                for j in range(dist_frames):\n",
"                    # interpolate the text embedding\n",
"                    prompt1_c = prompts_c_s[i]\n",
"                    prompt2_c = prompts_c_s[i+1]\n",
"                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n",
"\n",
"                    # sample the diffusion model\n",
"                    results = generate(args)\n",
"                    image = results[0]\n",
"\n",
"                    filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"                    image.save(os.path.join(args.outdir, filename))\n",
"                    frame_idx += 1\n",
"\n",
"                    display.clear_output(wait=True)\n",
"                    display.display(image)\n",
"\n",
"                    args.seed = next_seed(args)\n",
"\n",
"    else:\n",
"        for i in range(len(prompts_c_s)-1):\n",
"            for j in range(anim_args.interpolate_x_frames+1):\n",
"                # interpolate the text embedding\n",
"                prompt1_c = prompts_c_s[i]\n",
"                prompt2_c = prompts_c_s[i+1]\n",
"                args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
"\n",
"                # sample the diffusion model\n",
"                results = generate(args)\n",
"                image = results[0]\n",
"\n",
"                filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"                image.save(os.path.join(args.outdir, filename))\n",
"                frame_idx += 1\n",
"\n",
"                display.clear_output(wait=True)\n",
"                display.display(image)\n",
"\n",
"                args.seed = next_seed(args)\n",
"\n",
"    # generate the last prompt\n",
"    args.init_c = prompts_c_s[-1]\n",
"    results = generate(args)\n",
"    image = results[0]\n",
"    filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"    image.save(os.path.join(args.outdir, filename))\n",
"\n",
"    display.clear_output(wait=True)\n",
"    display.display(image)\n",
"    args.seed = next_seed(args)\n",
"\n",
"    # clear init_c\n",
"    args.init_c = None\n",
"\n",
"\n",
"args_dict = DeforumArgs()\n",
"anim_args_dict = DeforumAnimArgs()\n",
"\n",
"if override_settings_with_file:\n",
"    print(f\"reading custom settings from {custom_settings_file}\")\n",
"    if not os.path.isfile(custom_settings_file):\n",
"        print('The custom settings file does not exist. The in-notebook settings will be used instead.')\n",
"    else:\n",
"        with open(custom_settings_file, \"r\") as f:\n",
"            jdata = json.loads(f.read())\n",
"            animation_prompts = jdata[\"prompts\"]\n",
"            for k in args_dict:\n",
"                if k in jdata:\n",
"                    args_dict[k] = jdata[k]\n",
"                else:\n",
"                    print(f\"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}\")\n",
"            for k in anim_args_dict:\n",
"                if k in jdata:\n",
"                    anim_args_dict[k] = jdata[k]\n",
"                else:\n",
"                    print(f\"key {k} doesn't exist in the custom settings data! using the default value of {anim_args_dict[k]}\")\n",
"            print(args_dict)\n",
"            print(anim_args_dict)\n",
"\n",
"args = SimpleNamespace(**args_dict)\n",
"anim_args = SimpleNamespace(**anim_args_dict)\n",
"\n",
"args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
"args.strength = max(0.0, min(1.0, args.strength))\n",
"\n",
"if args.seed == -1:\n",
"    args.seed = random.randint(0, 2**32 - 1)\n",
"if not args.use_init:\n",
"    args.init_image = None\n",
"if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n",
"    print(\"Init images aren't supported with PLMS yet, switching to KLMS\")\n",
"    args.sampler = 'klms'\n",
"if args.sampler != 'ddim':\n",
"    args.ddim_eta = 0\n",
"\n",
"if anim_args.animation_mode == 'None':\n",
"    anim_args.max_frames = 1\n",
"elif anim_args.animation_mode == 'Video Input':\n",
"    args.use_init = True\n",
"\n",
"# clean up unused memory\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"\n",
"# dispatch to appropriate renderer\n",
"if anim_args.animation_mode in ('2D', '3D'):\n",
"    render_animation(args, anim_args)\n",
"elif anim_args.animation_mode == 'Video Input':\n",
"    render_input_video(args, anim_args)\n",
"elif anim_args.animation_mode == 'Interpolation':\n",
"    render_interpolation(args, anim_args)\n",
"else:\n",
"    render_image_batch(args)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "4zV0J_YbMCTx"
},
"source": [
"# Create video from frames"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "no2jP8HTMBM0"
},
"source": [
"skip_video_for_run_all = True #@param {type: 'boolean'}\n",
"fps = 12 #@param {type:\"number\"}\n",
"#@markdown **Manual Settings**\n",
"use_manual_settings = False #@param {type:\"boolean\"}\n",
"image_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png\" #@param {type:\"string\"}\n",
"mp4_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4\" #@param {type:\"string\"}\n",
"render_steps = True #@param {type: 'boolean'}\n",
"path_name_modifier = \"x0_pred\" #@param [\"x0_pred\",\"x\"]\n",
"\n",
"\n",
"if skip_video_for_run_all:\n",
"    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
"else:\n",
"    import os\n",
"    import subprocess\n",
"    from base64 import b64encode\n",
"\n",
"    print(f\"{image_path} -> {mp4_path}\")\n",
"\n",
"    if use_manual_settings:\n",
"        max_frames = \"200\" #@param {type:\"string\"}\n",
"    else:\n",
"        if render_steps: # render steps from a single image\n",
"            fname = f\"{path_name_modifier}_%05d.png\"\n",
"            all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]\n",
"            newest_dir = max(all_step_dirs, key=os.path.getmtime)\n",
"            image_path = os.path.join(newest_dir, fname)\n",
"            print(f\"Reading images from {image_path}\")\n",
"            mp4_path = os.path.join(newest_dir, f\"{args.timestring}_{path_name_modifier}.mp4\")\n",
"            max_frames = str(args.steps)\n",
"        else: # render images for a video\n",
"            image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n",
"            mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n",
"            max_frames = str(anim_args.max_frames)\n",
"\n",
"    # make video\n",
"    cmd = [\n",
"        'ffmpeg',\n",
"        '-y',\n",
"        '-vcodec', 'png',\n",
"        '-r', str(fps),\n",
"        '-start_number', str(0),\n",
"        '-pattern_type', 'sequence',\n",
"        '-i', image_path,\n",
"        '-frames:v', max_frames,\n",
"        '-c:v', 'libx264',\n",
"        '-vf',\n",
"        f'fps={fps}',\n",
"        '-pix_fmt', 'yuv420p',\n",
"        '-crf', '17',\n",
"        '-preset', 'veryfast',\n",
"        mp4_path\n",
"    ]\n",
"    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
"    stdout, stderr = process.communicate()\n",
"    if process.returncode != 0:\n",
"        print(stderr)\n",
"        raise RuntimeError(stderr)\n",
"\n",
"    with open(mp4_path, 'rb') as f:\n",
"        mp4 = f.read()\n",
"    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"    display.display( display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>') )"
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"provenance": [],
"private_outputs": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}