1058 lines
50 KiB
Plaintext
1058 lines
50 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "c442uQJ_gUgy"
|
|
},
|
|
"source": [
|
|
"# **Deforum Stable Diffusion v0.1**\n",
|
|
"[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).\n",
|
|
"\n",
|
|
"Notebook by [deforum](https://discord.gg/upmXXsrwZc)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "2g-f7cQmf2Nt",
|
|
"cellView": "form"
|
|
},
|
|
"source": [
|
|
"#@markdown **NVIDIA GPU**\n",
|
|
"import subprocess\n",
|
|
"sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
"print(sub_p_res)"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "TxIOPT0G5Lx1"
|
|
},
|
|
"source": [
|
|
"#@markdown **Model Path Variables**\n",
|
|
"# ask for the link\n",
|
|
"print(\"Local Path Variables:\\n\")\n",
|
|
"\n",
|
|
"models_path = \"/content/models\" #@param {type:\"string\"}\n",
|
|
"output_path = \"/content/output\" #@param {type:\"string\"}\n",
|
|
"\n",
|
|
"#@markdown **Google Drive Path Variables (Optional)**\n",
|
|
"mount_google_drive = True #@param {type:\"boolean\"}\n",
|
|
"force_remount = False\n",
|
|
"\n",
|
|
"if mount_google_drive:\n",
|
|
" from google.colab import drive\n",
|
|
" try:\n",
|
|
" drive_path = \"/content/drive\"\n",
|
|
" drive.mount(drive_path,force_remount=force_remount)\n",
|
|
" models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n",
|
|
" output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n",
|
|
" models_path = models_path_gdrive\n",
|
|
" output_path = output_path_gdrive\n",
|
|
"  except Exception:\n",
|
|
" print(\"...error mounting drive or with drive path variables\")\n",
|
|
" print(\"...reverting to default path variables\")\n",
|
|
"\n",
|
|
"import os\n",
|
|
"os.makedirs(models_path, exist_ok=True)\n",
|
|
"os.makedirs(output_path, exist_ok=True)\n",
|
|
"\n",
|
|
"print(f\"models_path: {models_path}\")\n",
|
|
"print(f\"output_path: {output_path}\")"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "VRNl2mfepEIe",
|
|
"cellView": "form"
|
|
},
|
|
"source": [
|
|
"#@markdown **Setup Environment**\n",
|
|
"\n",
|
|
"setup_environment = True #@param {type:\"boolean\"}\n",
|
|
"print_subprocess = False #@param {type:\"boolean\"}\n",
|
|
"\n",
|
|
"if setup_environment:\n",
|
|
" import subprocess\n",
|
|
" print(\"...setting up environment\")\n",
|
|
" all_process = [['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
|
|
" ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'],\n",
|
|
" ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],\n",
|
|
" ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n",
|
|
" ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
|
|
" ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq'],\n",
|
|
" ]\n",
|
|
" for process in all_process:\n",
|
|
" running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
" if print_subprocess:\n",
|
|
" print(running)\n",
|
|
" \n",
|
|
" print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
|
|
" with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n",
|
|
" f.write('')"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "81qmVZbrm4uu",
|
|
"cellView": "form"
|
|
},
|
|
"source": [
|
|
"#@markdown **Python Definitions**\n",
|
|
"import json\n",
|
|
"from IPython import display\n",
|
|
"\n",
|
|
"import argparse, glob, os, pathlib, subprocess, sys, time\n",
|
|
"import cv2\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import random\n",
|
|
"import requests\n",
|
|
"import shutil\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torchvision.transforms as T\n",
|
|
"import torchvision.transforms.functional as TF\n",
|
|
"from contextlib import contextmanager, nullcontext\n",
|
|
"from einops import rearrange, repeat\n",
|
|
"from itertools import islice\n",
|
|
"from omegaconf import OmegaConf\n",
|
|
"from PIL import Image\n",
|
|
"from pytorch_lightning import seed_everything\n",
|
|
"from skimage.exposure import match_histograms\n",
|
|
"from torchvision.utils import make_grid\n",
|
|
"from tqdm import tqdm, trange\n",
|
|
"from types import SimpleNamespace\n",
|
|
"from torch import autocast\n",
|
|
"\n",
|
|
"sys.path.append('./src/taming-transformers')\n",
|
|
"sys.path.append('./src/clip')\n",
|
|
"sys.path.append('./stable-diffusion/')\n",
|
|
"sys.path.append('./k-diffusion')\n",
|
|
"\n",
|
|
"from helpers import save_samples\n",
|
|
"from ldm.util import instantiate_from_config\n",
|
|
"from ldm.models.diffusion.ddim import DDIMSampler\n",
|
|
"from ldm.models.diffusion.plms import PLMSSampler\n",
|
|
"\n",
|
|
"from k_diffusion import sampling\n",
|
|
"from k_diffusion.external import CompVisDenoiser\n",
|
|
"\n",
|
|
"class CFGDenoiser(nn.Module):\n",
|
|
" def __init__(self, model):\n",
|
|
" super().__init__()\n",
|
|
" self.inner_model = model\n",
|
|
"\n",
|
|
" def forward(self, x, sigma, uncond, cond, cond_scale):\n",
|
|
" x_in = torch.cat([x] * 2)\n",
|
|
" sigma_in = torch.cat([sigma] * 2)\n",
|
|
" cond_in = torch.cat([uncond, cond])\n",
|
|
" uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)\n",
|
|
" return uncond + (cond - uncond) * cond_scale\n",
|
|
"\n",
|
|
"def add_noise(sample: torch.Tensor, noise_amt: float):\n",
|
|
" return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
|
|
"\n",
|
|
"def get_output_folder(output_path,batch_folder=None):\n",
|
|
" yearMonth = time.strftime('%Y-%m/')\n",
|
|
" out_path = os.path.join(output_path,yearMonth)\n",
|
|
"    if batch_folder:\n",
|
|
" out_path = os.path.join(out_path,batch_folder)\n",
|
|
" # we will also make sure the path suffix is a slash if linux and a backslash if windows\n",
|
|
" if out_path[-1] != os.path.sep:\n",
|
|
" out_path += os.path.sep\n",
|
|
" os.makedirs(out_path, exist_ok=True)\n",
|
|
" return out_path\n",
|
|
"\n",
|
|
"def load_img(path, shape):\n",
|
|
" if path.startswith('http://') or path.startswith('https://'):\n",
|
|
" image = Image.open(requests.get(path, stream=True).raw).convert('RGB')\n",
|
|
" else:\n",
|
|
" image = Image.open(path).convert('RGB')\n",
|
|
"\n",
|
|
" image = image.resize(shape, resample=Image.LANCZOS)\n",
|
|
" image = np.array(image).astype(np.float16) / 255.0\n",
|
|
" image = image[None].transpose(0, 3, 1, 2)\n",
|
|
" image = torch.from_numpy(image)\n",
|
|
" return 2.*image - 1.\n",
|
|
"\n",
|
|
"def maintain_colors(prev_img, color_match_sample, hsv=False):\n",
|
|
" if hsv:\n",
|
|
" prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
|
|
" color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
|
|
" matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
|
|
" return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
|
|
" else:\n",
|
|
" return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
|
|
"\n",
|
|
"def make_callback(sampler, dynamic_threshold=None, static_threshold=None): \n",
|
|
" # Creates the callback function to be passed into the samplers\n",
|
|
" # The callback function is applied to the image after each step\n",
|
|
" def dynamic_thresholding_(img, threshold):\n",
|
|
" # Dynamic thresholding from Imagen paper (May 2022)\n",
|
|
" s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n",
|
|
" s = np.max(np.append(s,1.0))\n",
|
|
" torch.clamp_(img, -1*s, s)\n",
|
|
" torch.FloatTensor.div_(img, s)\n",
|
|
"\n",
|
|
" # Callback for samplers in the k-diffusion repo, called thus:\n",
|
|
" # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n",
|
|
" def k_callback(args_dict):\n",
|
|
" if static_threshold is not None:\n",
|
|
" torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold)\n",
|
|
" if dynamic_threshold is not None:\n",
|
|
" dynamic_thresholding_(args_dict['x'], dynamic_threshold)\n",
|
|
"\n",
|
|
" # Function that is called on the image (img) and step (i) at each step\n",
|
|
" def img_callback(img, i):\n",
|
|
" # Thresholding functions\n",
|
|
" if dynamic_threshold is not None:\n",
|
|
" dynamic_thresholding_(img, dynamic_threshold)\n",
|
|
" if static_threshold is not None:\n",
|
|
" torch.clamp_(img, -1*static_threshold, static_threshold)\n",
|
|
"\n",
|
|
" if sampler in [\"plms\",\"ddim\"]: \n",
|
|
"            # Callback function formatted for compvis latent diffusion samplers\n",
|
|
" callback = img_callback\n",
|
|
" else: \n",
|
|
" # Default callback function uses k-diffusion sampler variables\n",
|
|
" callback = k_callback\n",
|
|
"\n",
|
|
" return callback\n",
|
|
"\n",
|
|
"def generate(args, return_latent=False, return_sample=False, return_c=False):\n",
|
|
" seed_everything(args.seed)\n",
|
|
" os.makedirs(args.outdir, exist_ok=True)\n",
|
|
"\n",
|
|
" if args.sampler == 'plms':\n",
|
|
" sampler = PLMSSampler(model)\n",
|
|
" else:\n",
|
|
" sampler = DDIMSampler(model)\n",
|
|
"\n",
|
|
" model_wrap = CompVisDenoiser(model) \n",
|
|
" batch_size = args.n_samples\n",
|
|
" prompt = args.prompt\n",
|
|
" assert prompt is not None\n",
|
|
" data = [batch_size * [prompt]]\n",
|
|
"\n",
|
|
" init_latent = None\n",
|
|
" if args.init_latent is not None:\n",
|
|
" init_latent = args.init_latent\n",
|
|
" elif args.init_sample is not None:\n",
|
|
" init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))\n",
|
|
"    elif args.init_image is not None and args.init_image != '':\n",
|
|
" init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)\n",
|
|
" init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n",
|
|
" init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space \n",
|
|
"\n",
|
|
" sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False)\n",
|
|
"\n",
|
|
" t_enc = int((1.0-args.strength) * args.steps)\n",
|
|
"\n",
|
|
" start_code = None\n",
|
|
"    if args.fixed_code and init_latent is None:\n",
|
|
" start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
|
|
"\n",
|
|
" callback = make_callback(sampler=args.sampler,\n",
|
|
" dynamic_threshold=args.dynamic_threshold, \n",
|
|
" static_threshold=args.static_threshold)\n",
|
|
"\n",
|
|
" results = []\n",
|
|
" precision_scope = autocast if args.precision == \"autocast\" else nullcontext\n",
|
|
" with torch.no_grad():\n",
|
|
" with precision_scope(\"cuda\"):\n",
|
|
" with model.ema_scope():\n",
|
|
" for n in range(args.n_samples):\n",
|
|
" for prompts in data:\n",
|
|
" uc = None\n",
|
|
" if args.scale != 1.0:\n",
|
|
" uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
|
|
" if isinstance(prompts, tuple):\n",
|
|
" prompts = list(prompts)\n",
|
|
" c = model.get_learned_conditioning(prompts)\n",
|
|
"\n",
|
|
"                        if args.init_c is not None:\n",
|
|
" c = args.init_c\n",
|
|
"\n",
|
|
" if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
|
|
" shape = [args.C, args.H // args.f, args.W // args.f]\n",
|
|
" sigmas = model_wrap.get_sigmas(args.steps)\n",
|
|
" if args.use_init:\n",
|
|
" sigmas = sigmas[len(sigmas)-t_enc-1:]\n",
|
|
" x = init_latent + torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n",
|
|
" else:\n",
|
|
" x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n",
|
|
" model_wrap_cfg = CFGDenoiser(model_wrap)\n",
|
|
" extra_args = {'cond': c, 'uncond': uc, 'cond_scale': args.scale}\n",
|
|
" if args.sampler==\"klms\":\n",
|
|
" samples = sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" elif args.sampler==\"dpm2\":\n",
|
|
" samples = sampling.sample_dpm_2(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" elif args.sampler==\"dpm2_ancestral\":\n",
|
|
" samples = sampling.sample_dpm_2_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" elif args.sampler==\"heun\":\n",
|
|
" samples = sampling.sample_heun(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" elif args.sampler==\"euler\":\n",
|
|
" samples = sampling.sample_euler(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" elif args.sampler==\"euler_ancestral\":\n",
|
|
" samples = sampling.sample_euler_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=callback)\n",
|
|
" else:\n",
|
|
"\n",
|
|
"                            if init_latent is not None:\n",
|
|
" z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
|
|
" samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,\n",
|
|
" unconditional_conditioning=uc,)\n",
|
|
" else:\n",
|
|
" if args.sampler == 'plms' or args.sampler == 'ddim':\n",
|
|
" shape = [args.C, args.H // args.f, args.W // args.f]\n",
|
|
" samples, _ = sampler.sample(S=args.steps,\n",
|
|
" conditioning=c,\n",
|
|
" batch_size=args.n_samples,\n",
|
|
" shape=shape,\n",
|
|
" verbose=False,\n",
|
|
" unconditional_guidance_scale=args.scale,\n",
|
|
" unconditional_conditioning=uc,\n",
|
|
" eta=args.ddim_eta,\n",
|
|
" x_T=start_code,\n",
|
|
" img_callback=callback)\n",
|
|
"\n",
|
|
" if return_latent:\n",
|
|
" results.append(samples.clone())\n",
|
|
"\n",
|
|
" x_samples = model.decode_first_stage(samples)\n",
|
|
" if return_sample:\n",
|
|
" results.append(x_samples.clone())\n",
|
|
"\n",
|
|
" x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
|
|
"\n",
|
|
" if return_c:\n",
|
|
" results.append(c.clone())\n",
|
|
"\n",
|
|
" for x_sample in x_samples:\n",
|
|
" x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
|
|
" image = Image.fromarray(x_sample.astype(np.uint8))\n",
|
|
" results.append(image)\n",
|
|
" return results\n",
|
|
"\n",
|
|
"def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
|
|
" sample = ((sample.astype(float) / 255.0) * 2) - 1\n",
|
|
" sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)\n",
|
|
" sample = torch.from_numpy(sample)\n",
|
|
" return sample\n",
|
|
"\n",
|
|
"def sample_to_cv2(sample: torch.Tensor) -> np.ndarray:\n",
|
|
" sample_f32 = rearrange(sample.squeeze().cpu().numpy(), \"c h w -> h w c\").astype(np.float32)\n",
|
|
" sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)\n",
|
|
" sample_int8 = (sample_f32 * 255).astype(np.uint8)\n",
|
|
" return sample_int8"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "CIUJ7lWI4v53"
|
|
},
|
|
"source": [
|
|
"#@markdown **Select Model**\n",
|
|
"print(\"\\nSelect Model:\\n\")\n",
|
|
"\n",
|
|
"model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
|
|
"model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n",
|
|
"custom_config_path = \"\" #@param {type:\"string\"}\n",
|
|
"custom_checkpoint_path = \"\" #@param {type:\"string\"}\n",
|
|
"\n",
|
|
"check_sha256 = True #@param {type:\"boolean\"}\n",
|
|
"\n",
|
|
"model_map = {\n",
|
|
" \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n",
|
|
" \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n",
|
|
" \"sd-v1-3-full-ema.ckpt\": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},\n",
|
|
" \"sd-v1-3.ckpt\": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},\n",
|
|
" \"sd-v1-2-full-ema.ckpt\": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},\n",
|
|
" \"sd-v1-2.ckpt\": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},\n",
|
|
" \"sd-v1-1-full-ema.ckpt\": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},\n",
|
|
" \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n",
|
|
"}\n",
|
|
"\n",
|
|
"def wget(url, outputdir):\n",
|
|
" res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
" print(res)\n",
|
|
"\n",
|
|
"def download_model(model_checkpoint):\n",
|
|
" download_link = model_map[model_checkpoint][\"link\"][0]\n",
|
|
" print(f\"!wget -O {models_path}/{model_checkpoint} {download_link}\")\n",
|
|
" wget(download_link, models_path)\n",
|
|
" return\n",
|
|
"\n",
|
|
"# config path\n",
|
|
"if os.path.exists(models_path+'/'+model_config):\n",
|
|
" print(f\"{models_path+'/'+model_config} exists\")\n",
|
|
"else:\n",
|
|
" print(\"cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.\")\n",
|
|
" shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path)\n",
|
|
"\n",
|
|
"# checkpoint path or download\n",
|
|
"if os.path.exists(models_path+'/'+model_checkpoint):\n",
|
|
" print(f\"{models_path+'/'+model_checkpoint} exists\")\n",
|
|
"else:\n",
|
|
" print(f\"download model checkpoint and place in {models_path+'/'+model_checkpoint}\")\n",
|
|
" #download_model(model_checkpoint)\n",
|
|
"\n",
|
|
"if check_sha256:\n",
|
|
" import hashlib\n",
|
|
" print(\"\\n...checking sha256\")\n",
|
|
" with open(models_path+'/'+model_checkpoint, \"rb\") as f:\n",
|
|
" bytes = f.read() \n",
|
|
" hash = hashlib.sha256(bytes).hexdigest()\n",
|
|
" del bytes\n",
|
|
" if model_map[model_checkpoint][\"sha256\"] == hash:\n",
|
|
" print(\"hash is correct\\n\")\n",
|
|
" else:\n",
|
|
"        print(\"hash is not correct\\n\")\n",
|
|
"\n",
|
|
"if model_config == \"custom\":\n",
|
|
" config = custom_config_path\n",
|
|
"else:\n",
|
|
" config = models_path+'/'+model_config\n",
|
|
"\n",
|
|
"if model_checkpoint == \"custom\":\n",
|
|
" ckpt = custom_checkpoint_path\n",
|
|
"else:\n",
|
|
" ckpt = models_path+'/'+model_checkpoint\n",
|
|
"\n",
|
|
"print(f\"config: {config}\")\n",
|
|
"print(f\"ckpt: {ckpt}\")"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "IJiMgz_96nr3"
|
|
},
|
|
"source": [
|
|
"#@markdown **Load Stable Diffusion**\n",
|
|
"\n",
|
|
"def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n",
|
|
" map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n",
|
|
" print(f\"Loading model from {ckpt}\")\n",
|
|
" pl_sd = torch.load(ckpt, map_location=map_location)\n",
|
|
" if \"global_step\" in pl_sd:\n",
|
|
" print(f\"Global Step: {pl_sd['global_step']}\")\n",
|
|
" sd = pl_sd[\"state_dict\"]\n",
|
|
" model = instantiate_from_config(config.model)\n",
|
|
" m, u = model.load_state_dict(sd, strict=False)\n",
|
|
" if len(m) > 0 and verbose:\n",
|
|
" print(\"missing keys:\")\n",
|
|
" print(m)\n",
|
|
" if len(u) > 0 and verbose:\n",
|
|
" print(\"unexpected keys:\")\n",
|
|
" print(u)\n",
|
|
"\n",
|
|
" #model.cuda()\n",
|
|
" if half_precision:\n",
|
|
" model = model.half().to(device)\n",
|
|
" else:\n",
|
|
" model = model.to(device)\n",
|
|
" model.eval()\n",
|
|
" return model\n",
|
|
"\n",
|
|
"load_on_run_all = True #@param {type: 'boolean'}\n",
|
|
"half_precision = True # needs to be fixed\n",
|
|
"\n",
|
|
"if load_on_run_all:\n",
|
|
"\n",
|
|
" local_config = OmegaConf.load(f\"{config}\")\n",
|
|
" model = load_model_from_config(local_config, f\"{ckpt}\",half_precision=half_precision)\n",
|
|
" device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
|
" model = model.to(device)"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ov3r4RD1tzsT"
|
|
},
|
|
"source": [
|
|
"# Settings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0j7rgxvLvfay"
|
|
},
|
|
"source": [
|
|
"### Animation Settings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "8HJN2TE3vh-J"
|
|
},
|
|
"source": [
|
|
"\n",
|
|
"def DeforumAnimArgs():\n",
|
|
"\n",
|
|
" #@markdown ####**Animation:**\n",
|
|
" animation_mode = 'None' #@param ['None', '2D', 'Video Input', 'Interpolation'] {type:'string'}\n",
|
|
" max_frames = 1000#@param {type:\"number\"}\n",
|
|
" border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'}\n",
|
|
"\n",
|
|
" #@markdown ####**Motion Parameters:**\n",
|
|
" key_frames = True #@param {type:\"boolean\"}\n",
|
|
" interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n",
|
|
" angle = \"0:(0)\"#@param {type:\"string\"}\n",
|
|
" zoom = \"0: (1.04)\"#@param {type:\"string\"}\n",
|
|
" translation_x = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
" translation_y = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
"\n",
|
|
" #@markdown ####**Coherence:**\n",
|
|
" color_coherence = 'MatchFrame0' #@param ['None', 'MatchFrame0'] {type:'string'}\n",
|
|
" previous_frame_noise = 0.02#@param {type:\"number\"}\n",
|
|
" previous_frame_strength = 0.65 #@param {type:\"number\"}\n",
|
|
"\n",
|
|
" #@markdown ####**Video Input:**\n",
|
|
" video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
|
|
" extract_nth_frame = 1#@param {type:\"number\"}\n",
|
|
"\n",
|
|
" #@markdown ####**Interpolation:**\n",
|
|
" interpolate_x_frames = 4 #@param {type:\"number\"}\n",
|
|
"\n",
|
|
" return locals()\n",
|
|
"\n",
|
|
"anim_args = SimpleNamespace(**DeforumAnimArgs())\n",
|
|
"\n",
|
|
"def make_xform_2d(width, height, translation_x, translation_y, angle, scale):\n",
|
|
" center = (width // 2, height // 2)\n",
|
|
" trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])\n",
|
|
" rot_mat = cv2.getRotationMatrix2D(center, angle, scale)\n",
|
|
" trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
|
|
" rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
|
|
" return np.matmul(rot_mat, trans_mat)\n",
|
|
"\n",
|
|
"def parse_key_frames(string, prompt_parser=None):\n",
|
|
" import re\n",
|
|
" pattern = r'((?P<frame>[0-9]+):[\\s]*[\\(](?P<param>[\\S\\s]*?)[\\)])'\n",
|
|
" frames = dict()\n",
|
|
" for match_object in re.finditer(pattern, string):\n",
|
|
" frame = int(match_object.groupdict()['frame'])\n",
|
|
" param = match_object.groupdict()['param']\n",
|
|
" if prompt_parser:\n",
|
|
" frames[frame] = prompt_parser(param)\n",
|
|
" else:\n",
|
|
" frames[frame] = param\n",
|
|
" if frames == {} and len(string) != 0:\n",
|
|
" raise RuntimeError('Key Frame string not correctly formatted')\n",
|
|
" return frames\n",
|
|
"\n",
|
|
"def get_inbetweens(key_frames, integer=False):\n",
|
|
" key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n",
|
|
"\n",
|
|
" for i, value in key_frames.items():\n",
|
|
" key_frame_series[i] = value\n",
|
|
" key_frame_series = key_frame_series.astype(float)\n",
|
|
" \n",
|
|
" interp_method = anim_args.interp_spline\n",
|
|
" if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
|
|
" interp_method = 'Quadratic' \n",
|
|
" if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
|
|
" interp_method = 'Linear'\n",
|
|
" \n",
|
|
" key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
|
|
" key_frame_series[anim_args.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
|
|
" key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
|
|
" if integer:\n",
|
|
" return key_frame_series.astype(int)\n",
|
|
" return key_frame_series\n",
|
|
"\n",
|
|
"\n",
|
|
"if anim_args.animation_mode == 'None':\n",
|
|
" anim_args.max_frames = 1\n",
|
|
"\n",
|
|
"if anim_args.key_frames:\n",
|
|
" angle_series = get_inbetweens(parse_key_frames(anim_args.angle))\n",
|
|
" zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom))\n",
|
|
" translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x))\n",
|
|
" translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "63UOJvU3xdPS"
|
|
},
|
|
"source": [
|
|
"### Prompts\n",
|
|
"`animation_mode: None` batches on list of *prompts*. `animation_mode: 2D` uses *animation_prompts* key frame sequence"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "2ujwkGZTcGev"
|
|
},
|
|
"source": [
|
|
"prompts = [\n",
|
|
" \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n",
|
|
" \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n",
|
|
" #\"the third prompt I don't want it I commented it with an\",\n",
|
|
"]\n",
|
|
"\n",
|
|
"animation_prompts = {\n",
|
|
" 0: \"a beautiful apple, trending on Artstation\",\n",
|
|
" 10: \"a beautiful banana, trending on Artstation\",\n",
|
|
" 100: \"a beautiful coconut, trending on Artstation\",\n",
|
|
" 101: \"a beautiful durian, trending on Artstation\",\n",
|
|
"}"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "s8RAo2zI-vQm"
|
|
},
|
|
"source": [
|
|
"# Run"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "qH74gBWDd2oq",
|
|
"cellView": "form"
|
|
},
|
|
"source": [
|
|
"def DeforumArgs():\n",
|
|
" #@markdown **Save & Display Settings**\n",
|
|
" batch_name = \"StableFun\" #@param {type:\"string\"}\n",
|
|
" outdir = get_output_folder(output_path, batch_name)\n",
|
|
" save_grid = False\n",
|
|
" save_settings = True #@param {type:\"boolean\"}\n",
|
|
" save_samples = True #@param {type:\"boolean\"}\n",
|
|
" display_samples = True #@param {type:\"boolean\"}\n",
|
|
"\n",
|
|
" #@markdown **Image Settings**\n",
|
|
" n_samples = 1 #@param\n",
|
|
" W = 512 #@param\n",
|
|
" H = 512 #@param\n",
|
|
" W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n",
|
|
"\n",
|
|
" #@markdown **Init Settings**\n",
|
|
" use_init = False #@param {type:\"boolean\"}\n",
|
|
" strength = 0.5 #@param {type:\"number\"}\n",
|
|
" init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n",
|
|
"\n",
|
|
" #@markdown **Sampling Settings**\n",
|
|
" seed = -1 #@param\n",
|
|
" sampler = 'klms' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n",
|
|
" steps = 10 #@param\n",
|
|
" scale = 7 #@param\n",
|
|
" ddim_eta = 0.0 #@param\n",
|
|
" dynamic_threshold = None\n",
|
|
" static_threshold = None \n",
|
|
"\n",
|
|
" #@markdown **Batch Settings**\n",
|
|
" n_batch = 1 #@param\n",
|
|
" seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
|
|
"\n",
|
|
" precision = 'autocast' \n",
|
|
" fixed_code = True\n",
|
|
" C = 4\n",
|
|
" f = 8\n",
|
|
"\n",
|
|
" prompt = \"\"\n",
|
|
" timestring = \"\"\n",
|
|
" init_latent = None\n",
|
|
" init_sample = None\n",
|
|
" init_c = None\n",
|
|
"\n",
|
|
" return locals()\n",
|
|
"\n",
|
|
"def next_seed(args):\n",
|
|
" if args.seed_behavior == 'iter':\n",
|
|
" args.seed += 1\n",
|
|
" elif args.seed_behavior == 'fixed':\n",
|
|
" pass # always keep seed the same\n",
|
|
" else:\n",
|
|
" args.seed = random.randint(0, 2**32)\n",
|
|
" return args.seed\n",
|
|
"\n",
|
|
"\n",
|
|
"args = SimpleNamespace(**DeforumArgs())\n",
|
|
"args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
|
|
"args.strength = max(0.0, min(1.0, args.strength))\n",
|
|
"\n",
|
|
"if args.seed == -1:\n",
|
|
" args.seed = random.randint(0, 2**32)\n",
|
|
"if anim_args.animation_mode == 'Video Input':\n",
|
|
" args.use_init = True\n",
|
|
"if not args.use_init:\n",
|
|
" args.init_image = None\n",
|
|
" args.strength = 0\n",
|
|
"if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n",
|
|
"    print(\"Init images aren't supported with PLMS yet, switching to KLMS\")\n",
|
|
" args.sampler = 'klms'\n",
|
|
"if args.sampler != 'ddim':\n",
|
|
" args.ddim_eta = 0\n",
|
|
"\n",
|
|
"def render_image_batch(args):\n",
|
|
" args.prompts = prompts\n",
|
|
" \n",
|
|
" # create output folder for the batch\n",
|
|
" os.makedirs(args.outdir, exist_ok=True)\n",
|
|
" if args.save_settings or args.save_samples:\n",
|
|
" print(f\"Saving to {os.path.join(args.outdir, args.timestring)}_*\")\n",
|
|
"\n",
|
|
" # save settings for the batch\n",
|
|
" if args.save_settings:\n",
|
|
" filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
|
|
" with open(filename, \"w+\", encoding=\"utf-8\") as f:\n",
|
|
" json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n",
|
|
"\n",
|
|
" index = 0\n",
|
|
" \n",
|
|
" # function for init image batching\n",
|
|
" init_array = []\n",
|
|
" if args.use_init:\n",
|
|
" if args.init_image == \"\":\n",
|
|
" raise FileNotFoundError(\"No path was given for init_image\")\n",
|
|
" if args.init_image.startswith('http://') or args.init_image.startswith('https://'):\n",
|
|
" init_array.append(args.init_image)\n",
|
|
" elif not os.path.isfile(args.init_image):\n",
|
|
" if args.init_image[-1] != \"/\": # avoids path error by adding / to end if not there\n",
|
|
" args.init_image += \"/\" \n",
|
|
" for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array\n",
|
|
" if image.split(\".\")[-1] in (\"png\", \"jpg\", \"jpeg\"):\n",
|
|
" init_array.append(args.init_image + image)\n",
|
|
" else:\n",
|
|
" init_array.append(args.init_image)\n",
|
|
" else:\n",
|
|
" init_array = [\"\"] \n",
|
|
"\n",
|
|
" for batch_index in range(args.n_batch):\n",
|
|
" print(f\"Batch {batch_index+1} of {args.n_batch}\")\n",
|
|
" \n",
|
|
" for image in init_array: # iterates the init images\n",
|
|
" args.init_image = image\n",
|
|
" for prompt in prompts:\n",
|
|
" args.prompt = prompt\n",
|
|
" results = generate(args)\n",
|
|
" for image in results:\n",
|
|
" if args.save_samples:\n",
|
|
" filename = f\"{args.timestring}_{index:05}_{args.seed}.png\"\n",
|
|
" image.save(os.path.join(args.outdir, filename))\n",
|
|
" if args.display_samples:\n",
|
|
" display.display(image)\n",
|
|
" index += 1\n",
|
|
" args.seed = next_seed(args)\n",
|
|
"\n",
|
|
"\n",
|
|
"def render_animation(args, anim_args):\n",
"    \"\"\"Render a 2D animation: each frame is diffusion-sampled, optionally\n",
"    seeded from a transformed + noised copy of the previous frame.\n",
"\n",
"    Mutates `args` in place (prompt, seed, init fields) and reads globals\n",
"    defined in earlier cells: animation_prompts, the *_series keyframe\n",
"    tables, device, and helpers (generate, make_xform_2d, sample_to_cv2,\n",
"    sample_from_cv2, add_noise, maintain_colors, next_seed).\n",
"    \"\"\"\n",
"    # animations use key framed prompts\n",
"    args.prompts = animation_prompts\n",
"\n",
"    # create output folder for the batch\n",
"    os.makedirs(args.outdir, exist_ok=True)\n",
"    print(f\"Saving animation frames to {args.outdir}\")\n",
"\n",
"    # save settings for the batch (args and anim_args merged into one dict;\n",
"    # on a key collision the anim_args value wins)\n",
"    settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
"    with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
"        s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
"        json.dump(s, f, ensure_ascii=False, indent=4)\n",
"\n",
"    # expand prompts out to per-frame: keyed frames hold their prompt, gaps\n",
"    # are forward-filled (and back-filled before the first key)\n",
"    prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n",
"    for i, prompt in animation_prompts.items():\n",
"        prompt_series[i] = prompt\n",
"    prompt_series = prompt_series.ffill().bfill()\n",
"\n",
"    # check for video inits\n",
"    using_vid_init = anim_args.animation_mode == 'Video Input'\n",
"\n",
"    args.n_samples = 1\n",
"    prev_sample = None\n",
"    color_match_sample = None  # first transformed frame, kept as color reference\n",
"    for frame_idx in range(anim_args.max_frames):\n",
"        print(f\"Rendering animation frame {frame_idx} of {anim_args.max_frames}\")\n",
"\n",
"        # apply transforms to previous frame\n",
"        if prev_sample is not None:\n",
"            if anim_args.key_frames:\n",
"                angle = angle_series[frame_idx]\n",
"                zoom = zoom_series[frame_idx]\n",
"                translation_x = translation_x_series[frame_idx]\n",
"                translation_y = translation_y_series[frame_idx]\n",
"                print(\n",
"                    f'angle: {angle}',\n",
"                    f'zoom: {zoom}',\n",
"                    f'translation_x: {translation_x}',\n",
"                    f'translation_y: {translation_y}',\n",
"                )\n",
"            # NOTE(review): if anim_args.key_frames is False, angle/zoom/\n",
"            # translation_x/translation_y are undefined here and the next\n",
"            # line raises NameError -- presumably non-keyframed values should\n",
"            # fall back to scalars on anim_args; confirm against the settings cell.\n",
"            xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)\n",
"\n",
"            # transform previous frame\n",
"            prev_img = sample_to_cv2(prev_sample)\n",
"            prev_img = cv2.warpPerspective(\n",
"                prev_img,\n",
"                xform,\n",
"                (prev_img.shape[1], prev_img.shape[0]),\n",
"                borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE\n",
"            )\n",
"\n",
"            # apply color matching against the first transformed frame\n",
"            if anim_args.color_coherence == 'MatchFrame0':\n",
"                if color_match_sample is None:\n",
"                    color_match_sample = prev_img.copy()\n",
"                else:\n",
"                    # third argument toggles every other frame -- presumably\n",
"                    # alternates the color space used for matching; confirm\n",
"                    # in maintain_colors\n",
"                    prev_img = maintain_colors(prev_img, color_match_sample, (frame_idx%2) == 0)\n",
"\n",
"            # apply frame noising\n",
"            noised_sample = add_noise(sample_from_cv2(prev_img), anim_args.previous_frame_noise)\n",
"\n",
"            # use transformed previous frame as init for current\n",
"            args.use_init = True\n",
"            args.init_sample = noised_sample.half().to(device)\n",
"            # clamp strength into the valid [0, 1] range\n",
"            args.strength = max(0.0, min(1.0, anim_args.previous_frame_strength))\n",
"\n",
"        # grab prompt for current frame\n",
"        args.prompt = prompt_series[frame_idx]\n",
"        print(f\"{args.prompt} {args.seed}\")\n",
"\n",
"        # grab init image for current frame (frames on disk are 1-based)\n",
"        if using_vid_init:\n",
"            init_frame = os.path.join(args.outdir, 'inputframes', f\"{frame_idx+1:04}.jpg\") \n",
"            print(f\"Using video init frame {init_frame}\")\n",
"            args.init_image = init_frame\n",
"\n",
"        # sample the diffusion model\n",
"        results = generate(args, return_latent=False, return_sample=True)\n",
"        sample, image = results[0], results[1]\n",
"    \n",
"        filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"        image.save(os.path.join(args.outdir, filename))\n",
"        # video-init mode always inits from the next extracted frame, so the\n",
"        # previous sample is only carried forward in 2D mode\n",
"        if not using_vid_init:\n",
"            prev_sample = sample\n",
"    \n",
"        display.clear_output(wait=True)\n",
"        display.display(image)\n",
"\n",
"        args.seed = next_seed(args)\n",
|
|
"\n",
|
|
"def render_input_video(args, anim_args):\n",
"    \"\"\"Extract frames from the init video with ffmpeg, then delegate to\n",
"    render_animation, which uses one extracted frame as init per output frame.\n",
"\n",
"    Mutates args.use_init and anim_args.max_frames in place.\n",
"    \"\"\"\n",
"    # create a folder for the video input frames to live in.\n",
"    # FIX: video_in_frame_path already includes args.outdir; joining it with\n",
"    # args.outdir a second time (as before) duplicated the directory when\n",
"    # outdir was a relative path.\n",
"    video_in_frame_path = os.path.join(args.outdir, 'inputframes') \n",
"    os.makedirs(video_in_frame_path, exist_ok=True)\n",
"    \n",
"    # save the video frames from input video\n",
"    print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...\")\n",
"    # remove stale frames from a previous run so the frame count is accurate;\n",
"    # best-effort only -- a missing/locked file should not abort the run\n",
"    try:\n",
"        for f in pathlib.Path(video_in_frame_path).glob('*.jpg'):\n",
"            f.unlink()\n",
"    except OSError:\n",
"        pass\n",
"    # select every Nth frame; the comma inside the ffmpeg expression must be escaped\n",
"    vf = f'select=not(mod(n\\,{anim_args.extract_nth_frame}))'\n",
"    subprocess.run([\n",
"        'ffmpeg', '-i', f'{anim_args.video_init_path}', \n",
"        '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', \n",
"        '-loglevel', 'error', '-stats', \n",
"        os.path.join(video_in_frame_path, '%04d.jpg')\n",
"    ], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"\n",
"    # determine max frames from length of input frames\n",
"    anim_args.max_frames = len(list(pathlib.Path(video_in_frame_path).glob('*.jpg')))\n",
"\n",
"    args.use_init = True\n",
"    print(f\"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}\")\n",
"    render_animation(args, anim_args)\n",
|
|
"\n",
|
|
"def render_interpolation(args, anim_args):\n",
"    \"\"\"Render frames that interpolate linearly between the text embeddings\n",
"    of consecutive animation prompts.\n",
"\n",
"    Pass 1 samples one image per prompt and caches its conditioning\n",
"    embedding; pass 2 blends each consecutive embedding pair and samples the\n",
"    blends. Mutates `args` in place and reads globals from earlier cells\n",
"    (animation_prompts, generate, next_seed, display).\n",
"    \"\"\"\n",
"    # animations use key framed prompts\n",
"    args.prompts = animation_prompts\n",
"\n",
"    # create output folder for the batch\n",
"    os.makedirs(args.outdir, exist_ok=True)\n",
"    print(f\"Saving animation frames to {args.outdir}\")\n",
"\n",
"    # save settings for the batch (anim_args values win on key collision)\n",
"    settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
"    with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n",
"        s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n",
"        json.dump(s, f, ensure_ascii=False, indent=4)\n",
"    \n",
"    # Interpolation Settings\n",
"    args.n_samples = 1\n",
"    args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available\n",
"    prompts_c_s = [] # cache all the text embeddings\n",
"\n",
"    print(f\"Preparing for interpolation of the following...\")\n",
"\n",
"    # pass 1: sample each prompt once, keeping its conditioning embedding\n",
"    for i, prompt in animation_prompts.items():\n",
"        args.prompt = prompt\n",
"\n",
"        # sample the diffusion model\n",
"        results = generate(args, return_c=True)\n",
"        c, image = results[0], results[1]\n",
"        prompts_c_s.append(c) \n",
"      \n",
"        # display.clear_output(wait=True)\n",
"        display.display(image)\n",
"      \n",
"        args.seed = next_seed(args)\n",
"\n",
"    display.clear_output(wait=True)\n",
"    print(f\"Interpolation start...\")\n",
"\n",
"    frame_idx = 0\n",
"\n",
"    # pass 2: for each consecutive embedding pair, render the blends\n",
"    for i in range(len(prompts_c_s)-1):\n",
"        for j in range(anim_args.interpolate_x_frames+1):\n",
"            # interpolate the text embedding; j only reaches\n",
"            # interpolate_x_frames, so the blend fraction stays below 1.0 --\n",
"            # the exact second embedding appears as the start of the next\n",
"            # pair (or via the final-prompt block below)\n",
"            prompt1_c = prompts_c_s[i]\n",
"            prompt2_c = prompts_c_s[i+1] \n",
"            # assumes the cached embeddings support .add/.sub/.mul\n",
"            # (tensor-like) -- TODO confirm against generate()\n",
"            args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
"\n",
"            # sample the diffusion model\n",
"            results = generate(args)\n",
"            image = results[0]\n",
"\n",
"            filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"            image.save(os.path.join(args.outdir, filename))\n",
"            frame_idx += 1\n",
"\n",
"            display.clear_output(wait=True)\n",
"            display.display(image)\n",
"\n",
"            args.seed = next_seed(args)\n",
"\n",
"    # generate the last prompt exactly (the loop above never reaches it)\n",
"    args.init_c = prompts_c_s[-1]\n",
"    results = generate(args)\n",
"    image = results[0]\n",
"    filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
"    image.save(os.path.join(args.outdir, filename))\n",
"\n",
"    display.clear_output(wait=True)\n",
"    display.display(image)\n",
"    args.seed = next_seed(args)\n",
"\n",
"    # clear init_c so later cells/modes are not affected by the cached embedding\n",
"    args.init_c = None\n",
|
|
"\n",
|
|
"# dispatch on the selected animation mode; any unrecognized mode falls\n",
"# back to plain image batch rendering\n",
"mode_handlers = {\n",
"    '2D': lambda: render_animation(args, anim_args),\n",
"    'Video Input': lambda: render_input_video(args, anim_args),\n",
"    'Interpolation': lambda: render_interpolation(args, anim_args),\n",
"}\n",
"mode_handlers.get(anim_args.animation_mode, lambda: render_image_batch(args))()"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "4zV0J_YbMCTx"
|
|
},
|
|
"source": [
|
|
"# Create video from frames"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "no2jP8HTMBM0"
|
|
},
|
|
"source": [
|
|
"skip_video_for_run_all = True #@param {type: 'boolean'}\n",
"fps = 12#@param {type:\"number\"}\n",
"\n",
"# Encode the rendered PNG frame sequence into an H.264 mp4 and display it\n",
"# inline. Reads args/anim_args from the render cell above.\n",
"if skip_video_for_run_all:\n",
"    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
"else:\n",
"    import subprocess\n",
"    from base64 import b64encode\n",
"\n",
"    # frames were saved as {timestring}_00000.png, _00001.png, ... by the\n",
"    # render functions above\n",
"    image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n",
"    mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n",
"\n",
"    print(f\"{image_path} -> {mp4_path}\")\n",
"\n",
"    # make video\n",
"    cmd = [\n",
"        'ffmpeg',\n",
"        '-y',\n",
"        '-vcodec', 'png',\n",
"        '-r', str(fps),\n",
"        '-start_number', str(0),\n",
"        '-i', image_path,\n",
"        '-frames:v', str(anim_args.max_frames),\n",
"        '-c:v', 'libx264',\n",
"        '-vf',\n",
"        f'fps={fps}',\n",
"        '-pix_fmt', 'yuv420p',\n",
"        '-crf', '17',\n",
"        '-preset', 'veryfast',\n",
"        mp4_path\n",
"    ]\n",
"    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
"    stdout, stderr = process.communicate()\n",
"    if process.returncode != 0:\n",
"        # decode ffmpeg's stderr so the error is readable (was printed as a\n",
"        # raw bytes repr) before failing\n",
"        print(stderr.decode('utf-8', errors='replace'))\n",
"        raise RuntimeError(stderr)\n",
"\n",
"    # read the finished video with a context manager so the handle is closed\n",
"    # promptly (was a bare open().read())\n",
"    with open(mp4_path, 'rb') as mp4_file:\n",
"        mp4 = mp4_file.read()\n",
"    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"    display.display( display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>') )"
|
|
],
|
|
"outputs": [],
|
|
"execution_count": null
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "Deforum_Stable_Diffusion + Interpolation.ipynb",
|
|
"provenance": []
|
|
},
|
|
"gpuClass": "standard",
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
} |