{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "c442uQJ_gUgy"
},
"source": [
"# **Deforum Stable Diffusion v0.2**\n",
"[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).\n",
"\n",
"Notebook by [deforum](https://discord.gg/upmXXsrwZc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4knibRpAQ06"
},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2g-f7cQmf2Nt",
"cellView": "form"
},
"source": [
"#@markdown **NVIDIA GPU**\n",
"import subprocess\n",
"sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"print(sub_p_res)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "TxIOPT0G5Lx1"
},
"source": [
"#@markdown **Model and Output Paths**\n",
"# ask for the link\n",
"print(\"Local Path Variables:\\n\")\n",
"\n",
"models_path = \"/content/models\" #@param {type:\"string\"}\n",
"output_path = \"/content/output\" #@param {type:\"string\"}\n",
"\n",
"#@markdown **Google Drive Path Variables (Optional)**\n",
"mount_google_drive = True #@param {type:\"boolean\"}\n",
"force_remount = False\n",
"\n",
"if mount_google_drive:\n",
" from google.colab import drive # type: ignore\n",
" try:\n",
" drive_path = \"/content/drive\"\n",
" drive.mount(drive_path,force_remount=force_remount)\n",
" models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n",
" output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n",
" models_path = models_path_gdrive\n",
" output_path = output_path_gdrive\n",
" except:\n",
" print(\"...error mounting drive or with drive path variables\")\n",
" print(\"...reverting to default path variables\")\n",
"\n",
"import os\n",
"os.makedirs(models_path, exist_ok=True)\n",
"os.makedirs(output_path, exist_ok=True)\n",
"\n",
"print(f\"models_path: {models_path}\")\n",
"print(f\"output_path: {output_path}\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"id": "VRNl2mfepEIe",
"cellView": "form"
},
"source": [
"#@markdown **Setup Environment**\n",
"\n",
"setup_environment = True #@param {type:\"boolean\"}\n",
"print_subprocess = False #@param {type:\"boolean\"}\n",
"\n",
"if setup_environment:\n",
" import subprocess\n",
" print(\"...setting up environment\")\n",
" all_process = [['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
" ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'],\n",
" ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
" ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq'],\n",
" ]\n",
" for process in all_process:\n",
" running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
" if print_subprocess:\n",
" print(running)\n",
" \n",
" print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
" with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n",
" f.write('')"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"id": "81qmVZbrm4uu",
"cellView": "form"
},
"source": [
"#@markdown **Python Definitions**\n",
"import json\n",
"from IPython import display\n",
"\n",
"import argparse, glob, os, pathlib, subprocess, sys, time\n",
"import cv2\n",
"import numpy as np\n",
"import pandas as pd\n",
"import random\n",
"import requests\n",
"import shutil\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from contextlib import contextmanager, nullcontext\n",
"from einops import rearrange, repeat\n",
"from itertools import islice\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from pytorch_lightning import seed_everything\n",
"from skimage.exposure import match_histograms\n",
"from torchvision.utils import make_grid\n",
"from tqdm import tqdm, trange\n",
"from types import SimpleNamespace\n",
"from torch import autocast\n",
"\n",
"sys.path.append('./src/taming-transformers')\n",
"sys.path.append('./src/clip')\n",
"sys.path.append('./stable-diffusion/')\n",
"sys.path.append('./k-diffusion')\n",
"\n",
"from helpers import save_samples, sampler_fn\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"\n",
"from k_diffusion import sampling\n",
"from k_diffusion.external import CompVisDenoiser\n",
"\n",
"class CFGDenoiser(nn.Module):\n",
" def __init__(self, model):\n",
" super().__init__()\n",
" self.inner_model = model\n",
"\n",
" def forward(self, x, sigma, uncond, cond, cond_scale):\n",
" x_in = torch.cat([x] * 2)\n",
" sigma_in = torch.cat([sigma] * 2)\n",
" cond_in = torch.cat([uncond, cond])\n",
" uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)\n",
" return uncond + (cond - uncond) * cond_scale\n",
"\n",
"def add_noise(sample: torch.Tensor, noise_amt: float):\n",
" return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
"\n",
"def get_output_folder(output_path, batch_folder):\n",
" out_path = os.path.join(output_path,time.strftime('%Y-%m'))\n",
" if batch_folder != \"\":\n",
" out_path = os.path.join(out_path, batch_folder)\n",
" os.makedirs(out_path, exist_ok=True)\n",
" return out_path\n",
"\n",
"def load_img(path, shape):\n",
" if path.startswith('http://') or path.startswith('https://'):\n",
" image = Image.open(requests.get(path, stream=True).raw).convert('RGB')\n",
" else:\n",
" image = Image.open(path).convert('RGB')\n",
"\n",
" image = image.resize(shape, resample=Image.LANCZOS)\n",
" image = np.array(image).astype(np.float16) / 255.0\n",
" image = image[None].transpose(0, 3, 1, 2)\n",
" image = torch.from_numpy(image)\n",
" return 2.*image - 1.\n",
"\n",
"def maintain_colors(prev_img, color_match_sample, mode):\n",
" if mode == 'Match Frame 0 RGB':\n",
" return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
" elif mode == 'Match Frame 0 HSV':\n",
" prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
" color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
" matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
" return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
" else: # Match Frame 0 LAB\n",
" prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n",
" color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n",
" matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n",
" return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n",
"\n",
"def make_callback(sampler, dynamic_threshold=None, static_threshold=None): \n",
" # Creates the callback function to be passed into the samplers\n",
" # The callback function is applied to the image after each step\n",
" def dynamic_thresholding_(img, threshold):\n",
" # Dynamic thresholding from Imagen paper (May 2022)\n",
" s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n",
" s = np.max(np.append(s,1.0))\n",
" torch.clamp_(img, -1*s, s)\n",
" torch.FloatTensor.div_(img, s)\n",
"\n",
" # Callback for samplers in the k-diffusion repo, called thus:\n",
" # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n",
" def k_callback(args_dict):\n",
" if static_threshold is not None:\n",
" torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold)\n",
" if dynamic_threshold is not None:\n",
" dynamic_thresholding_(args_dict['x'], dynamic_threshold)\n",
"\n",
" # Function that is called on the image (img) and step (i) at each step\n",
" def img_callback(img, i):\n",
" # Thresholding functions\n",
" if dynamic_threshold is not None:\n",
" dynamic_thresholding_(img, dynamic_threshold)\n",
" if static_threshold is not None:\n",
" torch.clamp_(img, -1*static_threshold, static_threshold)\n",
"\n",
" if sampler in [\"plms\",\"ddim\"]: \n",
" # Callback function formated for compvis latent diffusion samplers\n",
" callback = img_callback\n",
" else: \n",
" # Default callback function uses k-diffusion sampler variables\n",
" callback = k_callback\n",
"\n",
" return callback\n",
"\n",
"def generate(args, return_latent=False, return_sample=False, return_c=False):\n",
" seed_everything(args.seed)\n",
" os.makedirs(args.outdir, exist_ok=True)\n",
"\n",
" if args.sampler == 'plms':\n",
" sampler = PLMSSampler(model)\n",
" else:\n",
" sampler = DDIMSampler(model)\n",
"\n",
" model_wrap = CompVisDenoiser(model) \n",
" batch_size = args.n_samples\n",
" prompt = args.prompt\n",
" assert prompt is not None\n",
" data = [batch_size * [prompt]]\n",
"\n",
" init_latent = None\n",
" if args.init_latent is not None:\n",
" init_latent = args.init_latent\n",
" elif args.init_sample is not None:\n",
" init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))\n",
" elif args.init_image != None and args.init_image != '':\n",
" init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)\n",
" init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n",
" init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space \n",
"\n",
" sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False)\n",
"\n",
" t_enc = int((1.0-args.strength) * args.steps)\n",
"\n",
" start_code = None\n",
" if args.fixed_code and init_latent == None:\n",
" start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
"\n",
" callback = make_callback(sampler=args.sampler,\n",
" dynamic_threshold=args.dynamic_threshold, \n",
" static_threshold=args.static_threshold)\n",
"\n",
" results = []\n",
" precision_scope = autocast if args.precision == \"autocast\" else nullcontext\n",
" with torch.no_grad():\n",
" with precision_scope(\"cuda\"):\n",
" with model.ema_scope():\n",
" for prompts in data:\n",
" uc = None\n",
" if args.scale != 1.0:\n",
" uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
" if isinstance(prompts, tuple):\n",
" prompts = list(prompts)\n",
" c = model.get_learned_conditioning(prompts)\n",
"\n",
" if args.init_c != None:\n",
" c = args.init_c\n",
"\n",
" if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
" samples = sampler_fn(\n",
" c=c, \n",
" uc=uc, \n",
" args=args, \n",
" model_wrap=model_wrap, \n",
" init_latent=init_latent, \n",
" t_enc=t_enc, \n",
" device=device, \n",
" cb=callback)\n",
" else:\n",
"\n",
" if init_latent != None:\n",
" z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
" samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,\n",
" unconditional_conditioning=uc,)\n",
" else:\n",
" if args.sampler == 'plms' or args.sampler == 'ddim':\n",
" shape = [args.C, args.H // args.f, args.W // args.f]\n",
" samples, _ = sampler.sample(S=args.steps,\n",
" conditioning=c,\n",
" batch_size=args.n_samples,\n",
" shape=shape,\n",
" verbose=False,\n",
" unconditional_guidance_scale=args.scale,\n",
" unconditional_conditioning=uc,\n",
" eta=args.ddim_eta,\n",
" x_T=start_code,\n",
" img_callback=callback)\n",
"\n",
" if return_latent:\n",
" results.append(samples.clone())\n",
"\n",
" x_samples = model.decode_first_stage(samples)\n",
" if return_sample:\n",
" results.append(x_samples.clone())\n",
"\n",
" x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
" if return_c:\n",
" results.append(c.clone())\n",
"\n",
" for x_sample in x_samples:\n",
" x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
" image = Image.fromarray(x_sample.astype(np.uint8))\n",
" results.append(image)\n",
" return results\n",
"\n",
"def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
" sample = ((sample.astype(float) / 255.0) * 2) - 1\n",
" sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)\n",
" sample = torch.from_numpy(sample)\n",
" return sample\n",
"\n",
"def sample_to_cv2(sample: torch.Tensor) -> np.ndarray:\n",
" sample_f32 = rearrange(sample.squeeze().cpu().numpy(), \"c h w -> h w c\").astype(np.float32)\n",
" sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)\n",
" sample_int8 = (sample_f32 * 255).astype(np.uint8)\n",
" return sample_int8"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "CIUJ7lWI4v53"
},
"source": [
"#@markdown **Select and Load Model**\n",
"\n",
"model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
"model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n",
"custom_config_path = \"\" #@param {type:\"string\"}\n",
"custom_checkpoint_path = \"\" #@param {type:\"string\"}\n",
"\n",
"check_sha256 = True #@param {type:\"boolean\"}\n",
"\n",
"load_on_run_all = True #@param {type: 'boolean'}\n",
"half_precision = True # needs to be fixed\n",
"\n",
"model_map = {\n",
" \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n",
" \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n",
" \"sd-v1-3-full-ema.ckpt\": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},\n",
" \"sd-v1-3.ckpt\": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},\n",
" \"sd-v1-2-full-ema.ckpt\": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},\n",
" \"sd-v1-2.ckpt\": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},\n",
" \"sd-v1-1-full-ema.ckpt\": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},\n",
" \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n",
"}\n",
"\n",
"# config path\n",
"ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n",
"if os.path.exists(ckpt_config_path):\n",
" print(f\"{ckpt_config_path} exists\")\n",
"else:\n",
" ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n",
"print(f\"Using config: {ckpt_config_path}\")\n",
"\n",
"# checkpoint path or download\n",
"ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n",
"ckpt_valid = True\n",
"if os.path.exists(ckpt_path):\n",
" print(f\"{ckpt_path} exists\")\n",
"else:\n",
" print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n",
" ckpt_valid = False\n",
"\n",
"if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n",
" import hashlib\n",
" print(\"\\n...checking sha256\")\n",
" with open(ckpt_path, \"rb\") as f:\n",
" bytes = f.read() \n",
" hash = hashlib.sha256(bytes).hexdigest()\n",
" del bytes\n",
" if model_map[model_checkpoint][\"sha256\"] == hash:\n",
" print(\"hash is correct\\n\")\n",
" else:\n",
" print(\"hash in not correct\\n\")\n",
" ckpt_valid = False\n",
"\n",
"if ckpt_valid:\n",
" print(f\"Using ckpt: {ckpt_path}\")\n",
"\n",
"def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n",
" map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n",
" print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt, map_location=map_location)\n",
" if \"global_step\" in pl_sd:\n",
" print(f\"Global Step: {pl_sd['global_step']}\")\n",
" sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n",
" if len(m) > 0 and verbose:\n",
" print(\"missing keys:\")\n",
" print(m)\n",
" if len(u) > 0 and verbose:\n",
" print(\"unexpected keys:\")\n",
" print(u)\n",
"\n",
" if half_precision:\n",
" model = model.half().to(device)\n",
" else:\n",
" model = model.to(device)\n",
" model.eval()\n",
" return model\n",
"\n",
"if load_on_run_all and ckpt_valid:\n",
" local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n",
" model = load_model_from_config(local_config, f\"{ckpt_path}\",half_precision=half_precision)\n",
" device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
" model = model.to(device)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "ov3r4RD1tzsT"
},
"source": [
"# Settings"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0j7rgxvLvfay"
},
"source": [
"### Animation Settings"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "8HJN2TE3vh-J"
},
"source": [
"\n",
"def DeforumAnimArgs():\n",
"\n",
" #@markdown ####**Animation:**\n",
" animation_mode = 'None' #@param ['None', '2D', 'Video Input', 'Interpolation'] {type:'string'}\n",
" max_frames = 1000#@param {type:\"number\"}\n",
" border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'}\n",
"\n",
" #@markdown ####**Motion Parameters:**\n",
" key_frames = True #@param {type:\"boolean\"}\n",
" interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n",
" angle = \"0:(0)\"#@param {type:\"string\"}\n",
" zoom = \"0: (1.04)\"#@param {type:\"string\"}\n",
" translation_x = \"0: (0)\"#@param {type:\"string\"}\n",
" translation_y = \"0: (0)\"#@param {type:\"string\"}\n",
" noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n",
" strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n",
" contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
"\n",
" #@markdown ####**Coherence:**\n",
" color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
"\n",
" #@markdown ####**Video Input:**\n",
" video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
" extract_nth_frame = 1#@param {type:\"number\"}\n",
"\n",
" #@markdown ####**Interpolation:**\n",
" interpolate_key_frames = False #@param {type:\"boolean\"}\n",
" interpolate_x_frames = 4 #@param {type:\"number\"}\n",
" \n",
" #@markdown ####**Resume Animation:**\n",
" resume_from_timestring = False #@param {type:\"boolean\"}\n",
" resume_timestring = \"20220829210106\" #@param {type:\"string\"}\n",
"\n",
" return locals()\n",
"\n",
"anim_args = SimpleNamespace(**DeforumAnimArgs())\n",
"\n",
"def make_xform_2d(width, height, translation_x, translation_y, angle, scale):\n",
" center = (width // 2, height // 2)\n",
" trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])\n",
" rot_mat = cv2.getRotationMatrix2D(center, angle, scale)\n",
" trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
" rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
" return np.matmul(rot_mat, trans_mat)\n",
"\n",
"def parse_key_frames(string, prompt_parser=None):\n",
" import re\n",
" pattern = r'((?P[0-9]+):[\\s]*[\\(](?P[\\S\\s]*?)[\\)])'\n",
" frames = dict()\n",
" for match_object in re.finditer(pattern, string):\n",
" frame = int(match_object.groupdict()['frame'])\n",
" param = match_object.groupdict()['param']\n",
" if prompt_parser:\n",
" frames[frame] = prompt_parser(param)\n",
" else:\n",
" frames[frame] = param\n",
" if frames == {} and len(string) != 0:\n",
" raise RuntimeError('Key Frame string not correctly formatted')\n",
" return frames\n",
"\n",
"def get_inbetweens(key_frames, integer=False):\n",
" key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n",
"\n",
" for i, value in key_frames.items():\n",
" key_frame_series[i] = value\n",
" key_frame_series = key_frame_series.astype(float)\n",
" \n",
" interp_method = anim_args.interp_spline\n",
" if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
" interp_method = 'Quadratic' \n",
" if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
" interp_method = 'Linear'\n",
" \n",
" key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
" key_frame_series[anim_args.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
" key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
" if integer:\n",
" return key_frame_series.astype(int)\n",
" return key_frame_series\n",
"\n",
"\n",
"if anim_args.animation_mode == 'None':\n",
" anim_args.max_frames = 1\n",
"\n",
"if anim_args.key_frames:\n",
" angle_series = get_inbetweens(parse_key_frames(anim_args.angle))\n",
" zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom))\n",
" translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x))\n",
" translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))\n",
" noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))\n",
" strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))\n",
" contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "63UOJvU3xdPS"
},
"source": [
"### Prompts\n",
"`animation_mode: None` batches on list of *prompts*. `animation_mode: 2D` uses *animation_prompts* key frame sequence"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2ujwkGZTcGev"
},
"source": [
"\n",
"prompts = [\n",
" \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n",
" \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n",
" #\"the third prompt I don't want it I commented it with an\",\n",
"]\n",
"\n",
"animation_prompts = {\n",
" 0: \"a beautiful apple, trending on Artstation\",\n",
" 20: \"a beautiful banana, trending on Artstation\",\n",
" 30: \"a beautiful coconut, trending on Artstation\",\n",
" 40: \"a beautiful durian, trending on Artstation\",\n",
"}"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "s8RAo2zI-vQm"
},
"source": [
"# Run"
]
},
{
"cell_type": "code",
"metadata": {
"id": "qH74gBWDd2oq",
"cellView": "form"
},
"source": [
"def DeforumArgs():\n",
" #@markdown **Save & Display Settings**\n",
" batch_name = \"StableFun\" #@param {type:\"string\"}\n",
" outdir = get_output_folder(output_path, batch_name)\n",
" save_settings = True #@param {type:\"boolean\"}\n",
" save_samples = True #@param {type:\"boolean\"}\n",
" display_samples = True #@param {type:\"boolean\"}\n",
"\n",
" #@markdown **Image Settings**\n",
" n_samples = 1 # hidden\n",
" W = 512 #@param\n",
" H = 512 #@param\n",
" W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n",
"\n",
" #@markdown **Init Settings**\n",
" use_init = False #@param {type:\"boolean\"}\n",
" strength = 0.5 #@param {type:\"number\"}\n",
" init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n",
"\n",
" #@markdown **Sampling Settings**\n",
" seed = -1 #@param\n",
" sampler = 'klms' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n",
" steps = 50 #@param\n",
" scale = 7 #@param\n",
" ddim_eta = 0.0 #@param\n",
" dynamic_threshold = None\n",
" static_threshold = None \n",
"\n",
" #@markdown **Batch Settings**\n",
" n_batch = 4 #@param\n",
" seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
"\n",
" #@markdown **Grid Settings**\n",
" make_grid = False #@param {type:\"boolean\"}\n",
" grid_rows = 2 #@param \n",
"\n",
" precision = 'autocast' \n",
" fixed_code = True\n",
" C = 4\n",
" f = 8\n",
"\n",
" prompt = \"\"\n",
" timestring = \"\"\n",
" init_latent = None\n",
" init_sample = None\n",
" init_c = None\n",
"\n",
" return locals()\n",
"\n",
"\n",
"args = SimpleNamespace(**DeforumArgs())\n",
"args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
"args.strength = max(0.0, min(1.0, args.strength))\n",
"\n",
"\n",
"if args.seed == -1:\n",
" args.seed = random.randint(0, 2**32)\n",
"if anim_args.animation_mode == 'Video Input':\n",
" args.use_init = True\n",
"if not args.use_init:\n",
" args.init_image = None\n",
" args.strength = 0\n",
"if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n",
" print(f\"Init images aren't supported with PLMS yet, switching to KLMS\")\n",
" args.sampler = 'klms'\n",
"if args.sampler != 'ddim':\n",
" args.ddim_eta = 0\n",
"\n",
"\n",
"def next_seed(args):\n",
" if args.seed_behavior == 'iter':\n",
" args.seed += 1\n",
" elif args.seed_behavior == 'fixed':\n",
" pass # always keep seed the same\n",
" else:\n",
" args.seed = random.randint(0, 2**32)\n",
" return args.seed\n",
"\n",
"def render_image_batch(args):\n",
" args.prompts = prompts\n",
" \n",
" # create output folder for the batch\n",
" os.makedirs(args.outdir, exist_ok=True)\n",
" if args.save_settings or args.save_samples:\n",
" print(f\"Saving to {os.path.join(args.outdir, args.timestring)}_*\")\n",
"\n",
" # save settings for the batch\n",
" if args.save_settings:\n",
" filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
" with open(filename, \"w+\", encoding=\"utf-8\") as f:\n",
" json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n",
"\n",
" index = 0\n",
" \n",
" # function for init image batching\n",
" init_array = []\n",
" if args.use_init:\n",
" if args.init_image == \"\":\n",
" raise FileNotFoundError(\"No path was given for init_image\")\n",
" if args.init_image.startswith('http://') or args.init_image.startswith('https://'):\n",
" init_array.append(args.init_image)\n",
" elif not os.path.isfile(args.init_image):\n",
" if args.init_image[-1] != \"/\": # avoids path error by adding / to end if not there\n",
" args.init_image += \"/\" \n",
" for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array\n",
" if image.split(\".\")[-1] in (\"png\", \"jpg\", \"jpeg\"):\n",
" init_array.append(args.init_image + image)\n",
" else:\n",
" init_array.append(args.init_image)\n",
" else:\n",
" init_array = [\"\"]\n",
"\n",
" # when doing large batches don't flood browser with images\n",
" clear_between_batches = args.n_batch >= 32\n",
"\n",
" for iprompt, prompt in enumerate(prompts): \n",
" args.prompt = prompt\n",
"\n",
" all_images = []\n",
"\n",
" for batch_index in range(args.n_batch):\n",
" if clear_between_batches: \n",
" display.clear_output(wait=True) \n",
" print(f\"Batch {batch_index+1} of {args.n_batch}\")\n",
" \n",
" for image in init_array: # iterates the init images\n",
" args.init_image = image\n",
" results = generate(args)\n",
" for image in results:\n",
" if args.make_grid:\n",
" all_images.append(T.functional.pil_to_tensor(image))\n",
" if args.save_samples:\n",
" filename = f\"{args.timestring}_{index:05}_{args.seed}.png\"\n",
" image.save(os.path.join(args.outdir, filename))\n",
" if args.display_samples:\n",
" display.display(image)\n",
" index += 1\n",
" args.seed = next_seed(args)\n",
"\n",
" #print(len(all_images))\n",
" if args.make_grid:\n",
" grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n",
" grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
" filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n",
" grid_image = Image.fromarray(grid.astype(np.uint8))\n",
" grid_image.save(os.path.join(args.outdir, filename))\n",
" display.clear_output(wait=True) \n",
" display.display(grid_image)\n",
"\n",
"\n",
"def render_animation(args, anim_args):\n",
" # animations use key framed prompts\n",
" args.prompts = animation_prompts\n",
" \n",
" # resume animation\n",
" start_frame = 0\n",
" if anim_args.resume_from_timestring:\n",
" for tmp in os.listdir(args.outdir):\n",
" if tmp.split(\"_\")[0] == anim_args.resume_timestring:\n",
" start_frame += 1\n",
" start_frame = start_frame - 1\n",
"\n",
" # create output folder for the batch\n",
" os.makedirs(args.outdir, exist_ok=True)\n",
" print(f\"Saving animation frames to {args.outdir}\")\n",
"\n",
" # save settings for the batch\n",
" settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
" with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
" s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
" json.dump(s, f, ensure_ascii=False, indent=4)\n",
" \n",
" # resume from timestring\n",
" if anim_args.resume_from_timestring:\n",
" args.timestring = anim_args.resume_timestring\n",
"\n",
" # expand prompts out to per-frame\n",
" prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n",
" for i, prompt in animation_prompts.items():\n",
" prompt_series[i] = prompt\n",
" prompt_series = prompt_series.ffill().bfill()\n",
"\n",
" # check for video inits\n",
" using_vid_init = anim_args.animation_mode == 'Video Input'\n",
"\n",
" args.n_samples = 1\n",
" prev_sample = None\n",
" color_match_sample = None\n",
" for frame_idx in range(start_frame,anim_args.max_frames):\n",
" print(f\"Rendering animation frame {frame_idx} of {anim_args.max_frames}\")\n",
" \n",
" # resume animation\n",
" if anim_args.resume_from_timestring:\n",
" path = os.path.join(args.outdir,f\"{args.timestring}_{frame_idx-1:05}.png\")\n",
" img = cv2.imread(path)\n",
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
" prev_sample = sample_from_cv2(img)\n",
"\n",
" # apply transforms to previous frame\n",
" if prev_sample is not None:\n",
" if anim_args.key_frames:\n",
" angle = angle_series[frame_idx]\n",
" zoom = zoom_series[frame_idx]\n",
" translation_x = translation_x_series[frame_idx]\n",
" translation_y = translation_y_series[frame_idx]\n",
" noise = noise_schedule_series[frame_idx]\n",
" strength = strength_schedule_series[frame_idx]\n",
" contrast = contrast_schedule_series[frame_idx]\n",
" print(\n",
" f'angle: {angle}',\n",
" f'zoom: {zoom}',\n",
" f'translation_x: {translation_x}',\n",
" f'translation_y: {translation_y}',\n",
" f'noise: {noise}',\n",
" f'strength: {strength}',\n",
" f'contrast: {contrast}',\n",
" )\n",
" xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)\n",
"\n",
" # transform previous frame\n",
" prev_img = sample_to_cv2(prev_sample)\n",
" prev_img = cv2.warpPerspective(\n",
" prev_img,\n",
" xform,\n",
" (prev_img.shape[1], prev_img.shape[0]),\n",
" borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE\n",
" )\n",
"\n",
" # apply color matching\n",
" if anim_args.color_coherence != 'None':\n",
" if color_match_sample is None:\n",
" color_match_sample = prev_img.copy()\n",
" else:\n",
" prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n",
"\n",
" # apply scaling\n",
" contrast_sample = prev_img * contrast\n",
" # apply frame noising\n",
" noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)\n",
"\n",
" # use transformed previous frame as init for current\n",
" args.use_init = True\n",
" args.init_sample = noised_sample.half().to(device)\n",
" args.strength = max(0.0, min(1.0, strength))\n",
"\n",
" # grab prompt for current frame\n",
" args.prompt = prompt_series[frame_idx]\n",
" print(f\"{args.prompt} {args.seed}\")\n",
"\n",
" # grab init image for current frame\n",
" if using_vid_init:\n",
" init_frame = os.path.join(args.outdir, 'inputframes', f\"{frame_idx+1:04}.jpg\") \n",
" print(f\"Using video init frame {init_frame}\")\n",
" args.init_image = init_frame\n",
"\n",
" # sample the diffusion model\n",
" results = generate(args, return_latent=False, return_sample=True)\n",
" sample, image = results[0], results[1]\n",
" \n",
" filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
" image.save(os.path.join(args.outdir, filename))\n",
" if not using_vid_init:\n",
" prev_sample = sample\n",
" \n",
" display.clear_output(wait=True)\n",
" display.display(image)\n",
"\n",
" args.seed = next_seed(args)\n",
"\n",
"def render_input_video(args, anim_args):\n",
" # create a folder for the video input frames to live in\n",
" video_in_frame_path = os.path.join(args.outdir, 'inputframes') \n",
" os.makedirs(os.path.join(args.outdir, video_in_frame_path), exist_ok=True)\n",
" \n",
" # save the video frames from input video\n",
" print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...\")\n",
" try:\n",
" for f in pathlib.Path(video_in_frame_path).glob('*.jpg'):\n",
" f.unlink()\n",
" except:\n",
" pass\n",
" vf = r'select=not(mod(n\\,'+str(anim_args.extract_nth_frame)+'))'\n",
" subprocess.run([\n",
" 'ffmpeg', '-i', f'{anim_args.video_init_path}', \n",
" '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', \n",
" '-loglevel', 'error', '-stats', \n",
" os.path.join(video_in_frame_path, '%04d.jpg')\n",
" ], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"\n",
" # determine max frames from length of input frames\n",
" anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])\n",
"\n",
" args.use_init = True\n",
" print(f\"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}\")\n",
" render_animation(args, anim_args)\n",
"\n",
"def render_interpolation(args, anim_args):\n",
" # animations use key framed prompts\n",
" args.prompts = animation_prompts\n",
"\n",
" # create output folder for the batch\n",
" os.makedirs(args.outdir, exist_ok=True)\n",
" print(f\"Saving animation frames to {args.outdir}\")\n",
"\n",
" # save settings for the batch\n",
" settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
" with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
" s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
" json.dump(s, f, ensure_ascii=False, indent=4)\n",
" \n",
" # Interpolation Settings\n",
" args.n_samples = 1\n",
" args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available\n",
" prompts_c_s = [] # cache all the text embeddings\n",
"\n",
" print(f\"Preparing for interpolation of the following...\")\n",
"\n",
" for i, prompt in animation_prompts.items():\n",
" args.prompt = prompt\n",
"\n",
" # sample the diffusion model\n",
" results = generate(args, return_c=True)\n",
" c, image = results[0], results[1]\n",
" prompts_c_s.append(c) \n",
" \n",
" # display.clear_output(wait=True)\n",
" display.display(image)\n",
" \n",
" args.seed = next_seed(args)\n",
"\n",
" display.clear_output(wait=True)\n",
" print(f\"Interpolation start...\")\n",
"\n",
" frame_idx = 0\n",
"\n",
" if anim_args.interpolate_key_frames:\n",
" for i in range(len(prompts_c_s)-1):\n",
" dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n",
" if dist_frames <= 0:\n",
" print(\"key frames duplicated or reversed. interpolation skipped.\")\n",
" return\n",
" else:\n",
" for j in range(dist_frames):\n",
" # interpolate the text embedding\n",
" prompt1_c = prompts_c_s[i]\n",
" prompt2_c = prompts_c_s[i+1] \n",
" args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n",
"\n",
" # sample the diffusion model\n",
" results = generate(args)\n",
" image = results[0]\n",
"\n",
" filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
" image.save(os.path.join(args.outdir, filename))\n",
" frame_idx += 1\n",
"\n",
" display.clear_output(wait=True)\n",
" display.display(image)\n",
"\n",
" args.seed = next_seed(args)\n",
"\n",
" else:\n",
" for i in range(len(prompts_c_s)-1):\n",
" for j in range(anim_args.interpolate_x_frames+1):\n",
" # interpolate the text embedding\n",
" prompt1_c = prompts_c_s[i]\n",
" prompt2_c = prompts_c_s[i+1] \n",
" args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
"\n",
" # sample the diffusion model\n",
" results = generate(args)\n",
" image = results[0]\n",
"\n",
" filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
" image.save(os.path.join(args.outdir, filename))\n",
" frame_idx += 1\n",
"\n",
" display.clear_output(wait=True)\n",
" display.display(image)\n",
"\n",
" args.seed = next_seed(args)\n",
"\n",
" # generate the last prompt\n",
" args.init_c = prompts_c_s[-1]\n",
" results = generate(args)\n",
" image = results[0]\n",
" filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
" image.save(os.path.join(args.outdir, filename))\n",
"\n",
" display.clear_output(wait=True)\n",
" display.display(image)\n",
" args.seed = next_seed(args)\n",
"\n",
" #clear init_c\n",
" args.init_c = None\n",
"\n",
"if anim_args.animation_mode == '2D':\n",
" render_animation(args, anim_args)\n",
"elif anim_args.animation_mode == 'Video Input':\n",
" render_input_video(args, anim_args)\n",
"elif anim_args.animation_mode == 'Interpolation':\n",
" render_interpolation(args, anim_args)\n",
"else:\n",
" render_image_batch(args) "
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "4zV0J_YbMCTx"
},
"source": [
"# Create video from frames"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "no2jP8HTMBM0"
},
"source": [
"skip_video_for_run_all = True #@param {type: 'boolean'}\n",
"fps = 12#@param {type:\"number\"}\n",
"\n",
"if skip_video_for_run_all == True:\n",
" print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
"else:\n",
" import subprocess\n",
" from base64 import b64encode\n",
"\n",
" image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n",
" mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n",
"\n",
" print(f\"{image_path} -> {mp4_path}\")\n",
"\n",
" # make video\n",
" cmd = [\n",
" 'ffmpeg',\n",
" '-y',\n",
" '-vcodec', 'png',\n",
" '-r', str(fps),\n",
" '-start_number', str(0),\n",
" '-i', image_path,\n",
" '-frames:v', str(anim_args.max_frames),\n",
" '-c:v', 'libx264',\n",
" '-vf',\n",
" f'fps={fps}',\n",
" '-pix_fmt', 'yuv420p',\n",
" '-crf', '17',\n",
" '-preset', 'veryfast',\n",
" mp4_path\n",
" ]\n",
" process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
" stdout, stderr = process.communicate()\n",
" if process.returncode != 0:\n",
" print(stderr)\n",
" raise RuntimeError(stderr)\n",
"\n",
" mp4 = open(mp4_path,'rb').read()\n",
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" display.display( display.HTML(f'') )"
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Deforum_Stable_Diffusion.ipynb",
"provenance": [],
"private_outputs": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}