{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c442uQJ_gUgy"
      },
      "source": [
        "# **Deforum Stable Diffusion**\n",
        "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team\n",
        "\n",
        "Notebook by [deforum](https://twitter.com/deforum_art)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2g-f7cQmf2Nt",
        "cellView": "form"
      },
      "source": [
        "#@markdown **NVIDIA GPU**\n",
        "import subprocess\n",
        "sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
        "print(sub_p_res)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VRNl2mfepEIe",
        "cellView": "form"
      },
      "source": [
        "#@markdown **Setup Environment**\n",
        "\n",
        "setup_environment = True #@param {type:\"boolean\"}\n",
        "print_subprocess = False #@param {type:\"boolean\"}\n",
        "\n",
        "if setup_environment:\n",
        "    import subprocess\n",
        "    print(\"...setting up environment\")\n",
        "    all_process = [\n",
        "        ['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
        "        ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'],\n",
        "        ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],\n",
        "        ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n",
        "        ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
        "        ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq'],\n",
        "    ]\n",
        "    for process in all_process:\n",
        "        running = subprocess.run(process, stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
        "        if print_subprocess:\n",
        "            print(running)\n",
        "\n",
        "    print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
        "    with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n",
        "        f.write('')\n",
        "\n",
        "    # make the cloned repos (and the editable installs under ./src) importable\n",
        "    import sys\n",
        "    sys.path.append('./src/taming-transformers')\n",
        "    sys.path.append('./src/clip')\n",
        "    sys.path.append('./stable-diffusion/')\n",
        "    sys.path.append('./k-diffusion')"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "81qmVZbrm4uu"
      },
      "source": [
        "#@markdown **Python Definitions**\n",
        "import json\n",
        "from IPython import display\n",
        "\n",
        "import sys, os\n",
        "import argparse, glob\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import requests\n",
        "import shutil\n",
        "from types import SimpleNamespace\n",
        "from omegaconf import OmegaConf\n",
        "from PIL import Image\n",
        "from tqdm import tqdm, trange\n",
        "from itertools import islice\n",
        "from einops import rearrange, repeat\n",
        "from torchvision.utils import make_grid\n",
        "import time\n",
        "from pytorch_lightning import seed_everything\n",
        "from torch import autocast\n",
        "from contextlib import contextmanager, nullcontext\n",
        "\n",
        "from helpers import save_samples\n",
        "from ldm.util import instantiate_from_config\n",
        "from ldm.models.diffusion.ddim import DDIMSampler\n",
        "from ldm.models.diffusion.plms import PLMSSampler\n",
        "\n",
        "import accelerate\n",
        "from k_diffusion import sampling\n",
        "from k_diffusion.external import CompVisDenoiser\n",
        "\n",
        "def chunk(it, size):\n",
        "    it = iter(it)\n",
        "    return iter(lambda: tuple(islice(it, size)), ())\n",
        "\n",
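        "# nested output folder: <output_path>/<YYYY-MM>/<batch_folder>/\n",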
"def get_output_folder(output_path,batch_folder=None):\n",
        "    yearMonth = time.strftime('%Y-%m/')\n",
        "    out_path = output_path+\"/\"+yearMonth\n",
        "    if batch_folder:  # skip when batch_folder is None or empty\n",
        "        out_path += batch_folder\n",
        "        if out_path[-1] != \"/\":\n",
        "            out_path += \"/\"\n",
        "    os.makedirs(out_path, exist_ok=True)\n",
        "    return out_path\n",
        "\n",
        "def load_img(path, shape):\n",
        "    if path.startswith('http://') or path.startswith('https://'):\n",
        "        image = Image.open(requests.get(path, stream=True).raw).convert('RGB')\n",
        "    else:\n",
        "        image = Image.open(path).convert('RGB')\n",
        "\n",
        "    image = image.resize(shape, resample=Image.LANCZOS)\n",
        "    image = np.array(image).astype(np.float16) / 255.0\n",
        "    image = image[None].transpose(0, 3, 1, 2)\n",
        "    image = torch.from_numpy(image)\n",
        "    return 2.*image - 1.\n",
        "\n",
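        "# Classifier-free guidance wrapper: the conditional and unconditional predictions\n",
        "# are computed in one batch and blended as uncond + (cond - uncond) * cond_scale.\n",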
"class CFGDenoiser(nn.Module):\n",
        "    def __init__(self, model):\n",
        "        super().__init__()\n",
        "        self.inner_model = model\n",
        "\n",
        "    def forward(self, x, sigma, uncond, cond, cond_scale):\n",
        "        x_in = torch.cat([x] * 2)\n",
        "        sigma_in = torch.cat([sigma] * 2)\n",
        "        cond_in = torch.cat([uncond, cond])\n",
        "        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)\n",
        "        return uncond + (cond - uncond) * cond_scale\n",
        "\n",
        "def make_callback(sampler, dynamic_threshold=None, static_threshold=None):\n",
        "    # Creates the callback function to be passed into the samplers\n",
        "    # The callback function is applied to the image after each step\n",
        "    def dynamic_thresholding_(img, threshold):\n",
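        "        # clamp each sample's latent to the given percentile of its absolute\n",
        "        # values (floored at 1.0), then rescale into [-1, 1] in place\n",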
" # Dynamic thresholding from Imagen paper (May 2022)\n",
        "        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n",
        "        s = np.max(np.append(s,1.0))\n",
        "        torch.clamp_(img, -1*s, s)\n",
        "        torch.FloatTensor.div_(img, s)\n",
        "\n",
        "    # Callback for samplers in the k-diffusion repo, called thus:\n",
        "    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n",
        "    def k_callback(args_dict):\n",
        "        if static_threshold is not None:\n",
        "            torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold)\n",
        "        if dynamic_threshold is not None:\n",
        "            dynamic_thresholding_(args_dict['x'], dynamic_threshold)\n",
        "\n",
        "    # Function that is called on the image (img) and step (i) at each step\n",
        "    def img_callback(img, i):\n",
        "        # Thresholding functions\n",
        "        if dynamic_threshold is not None:\n",
        "            dynamic_thresholding_(img, dynamic_threshold)\n",
        "        if static_threshold is not None:\n",
        "            torch.clamp_(img, -1*static_threshold, static_threshold)\n",
        "\n",
        "    if sampler in [\"plms\",\"ddim\"]:\n",
        "        # Callback function formatted for CompVis latent diffusion samplers\n",
        "        callback = img_callback\n",
        "    else:\n",
        "        # Default callback function uses k-diffusion sampler variables\n",
        "        callback = k_callback\n",
        "\n",
        "    return callback\n",
        "\n",
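        "# Generate images for args.prompts: pick the sampler, handle the optional\n",
        "# img2img init, run the diffusion loop, then save/display the results.\n",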
"def run(args, local_seed):\n",
        "\n",
        "    # load settings\n",
        "    accelerator = accelerate.Accelerator()\n",
        "    device = accelerator.device\n",
        "    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])\n",
        "    torch.manual_seed(seeds[accelerator.process_index].item())\n",
        "\n",
        "    # plms\n",
        "    if args.sampler==\"plms\":\n",
        "        args.eta = 0\n",
        "        sampler = PLMSSampler(model)\n",
        "    else:\n",
        "        sampler = DDIMSampler(model)\n",
        "\n",
        "    model_wrap = CompVisDenoiser(model)\n",
        "    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()\n",
        "\n",
        "    batch_size = args.n_samples\n",
        "    n_rows = args.n_rows if args.n_rows > 0 else batch_size\n",
        "\n",
        "    data = list(chunk(args.prompts, batch_size))\n",
        "    sample_index = 0\n",
        "\n",
        "    start_code = None\n",
        "\n",
        "    # init image\n",
        "    if args.use_init:\n",
        "        assert os.path.isfile(args.init_image)\n",
        "        init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)\n",
        "        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n",
        "        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space\n",
        "\n",
        "        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.eta, verbose=False)\n",
        "\n",
        "        assert 0. <= args.strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
        "        t_enc = int(args.strength * args.steps)\n",
        "        print(f\"target t_enc is {t_enc} steps\")\n",
        "\n",
        "    # no init image\n",
        "    else:\n",
        "        if args.fixed_code:\n",
        "            start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
        "\n",
        "    precision_scope = autocast if args.precision==\"autocast\" else nullcontext\n",
        "    with torch.no_grad():\n",
        "        with precision_scope(\"cuda\"):\n",
        "            with model.ema_scope():\n",
        "                tic = time.time()\n",
        "                for prompt_index, prompts in enumerate(data):\n",
        "                    print(prompts)\n",
        "                    prompt_seed = local_seed + prompt_index\n",
        "                    seed_everything(prompt_seed)\n",
        "\n",
        "                    callback = make_callback(sampler=args.sampler,\n",
        "                                             dynamic_threshold=args.dynamic_threshold,\n",
        "                                             static_threshold=args.static_threshold)\n",
        "\n",
        "                    uc = None\n",
        "                    if args.scale != 1.0:\n",
        "                        uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
        "                    if isinstance(prompts, tuple):\n",
        "                        prompts = list(prompts)\n",
        "                    c = model.get_learned_conditioning(prompts)\n",
        "\n",
        "                    if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
        "                        shape = [args.C, args.H // args.f, args.W // args.f]\n",
        "                        sigmas = model_wrap.get_sigmas(args.steps)\n",
        "                        torch.manual_seed(prompt_seed)\n",
        "                        if args.use_init:\n",
        "                            # only the remaining (lower) noise levels are kept: the init latent\n",
        "                            # is noised to sigmas[0] and denoised from there, so a higher\n",
        "                            # strength skips more of the schedule\n",
        "                            sigmas = sigmas[t_enc:]\n",
        "                            x = init_latent + torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n",
        "                        else:\n",
        "                            x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n",
        "                        model_wrap_cfg = CFGDenoiser(model_wrap)\n",
        "                        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': args.scale}\n",
        "                        if args.sampler==\"klms\":\n",
        "                            samples = sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "                        elif args.sampler==\"dpm2\":\n",
        "                            samples = sampling.sample_dpm_2(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "                        elif args.sampler==\"dpm2_ancestral\":\n",
        "                            samples = sampling.sample_dpm_2_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "                        elif args.sampler==\"heun\":\n",
        "                            samples = sampling.sample_heun(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "                        elif args.sampler==\"euler\":\n",
        "                            samples = sampling.sample_euler(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "                        elif args.sampler==\"euler_ancestral\":\n",
        "                            samples = sampling.sample_euler_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n",
        "\n",
        "                        x_samples = model.decode_first_stage(samples)\n",
        "                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
        "                        x_samples = accelerator.gather(x_samples)\n",
        "\n",
        "                    else:\n",
        "\n",
        "                        # no init image\n",
        "                        if not args.use_init:\n",
        "                            shape = [args.C, args.H // args.f, args.W // args.f]\n",
        "\n",
        "                            samples, _ = sampler.sample(S=args.steps,\n",
        "                                                        conditioning=c,\n",
        "                                                        batch_size=args.n_samples,\n",
        "                                                        shape=shape,\n",
        "                                                        verbose=False,\n",
        "                                                        unconditional_guidance_scale=args.scale,\n",
        "                                                        unconditional_conditioning=uc,\n",
        "                                                        eta=args.eta,\n",
        "                                                        x_T=start_code,\n",
        "                                                        img_callback=callback)\n",
        "\n",
        "                        # init image\n",
        "                        else:\n",
        "                            # encode (scaled latent)\n",
        "                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
        "                            # decode it\n",
        "                            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,\n",
        "                                                     unconditional_conditioning=uc,)\n",
        "\n",
        "                        x_samples = model.decode_first_stage(samples)\n",
        "                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
        "\n",
        "\n",
        "                    grid, images = save_samples(\n",
        "                        args, x_samples=x_samples, seed=prompt_seed, n_rows=n_rows\n",
        "                    )\n",
        "                    if args.display_samples:\n",
        "                        for im in images:\n",
        "                            display.display(im)\n",
        "                    if args.display_grid:\n",
        "                        display.display(grid)\n",
        "\n",
        "                # stop timer\n",
        "                toc = time.time()\n",
        "\n",
        "                #print(f\"Your samples are ready and waiting for you here: \\n{outpath} \\n\" f\" \\nEnjoy.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "TxIOPT0G5Lx1"
      },
      "source": [
        "#@markdown **Model Path Variables**\n",
        "# ask for the link\n",
        "print(\"Local Path Variables:\\n\")\n",
        "\n",
        "models_path = \"/content/models\" #@param {type:\"string\"}\n",
        "output_path = \"/content/output\" #@param {type:\"string\"}\n",
        "\n",
        "#@markdown **Google Drive Path Variables (Optional)**\n",
        "mount_google_drive = True #@param {type:\"boolean\"}\n",
        "force_remount = False\n",
        "\n",
        "if mount_google_drive:\n",
        "    from google.colab import drive\n",
        "    try:\n",
        "        drive_path = \"/content/drive\"\n",
        "        drive.mount(drive_path, force_remount=force_remount)\n",
        "        models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n",
        "        output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n",
        "        models_path = models_path_gdrive\n",
        "        output_path = output_path_gdrive\n",
        "    except:\n",
        "        print(\"...error mounting drive or with drive path variables\")\n",
        "        print(\"...reverting to default path variables\")\n",
        "\n",
        "os.makedirs(models_path, exist_ok=True)\n",
        "os.makedirs(output_path, exist_ok=True)\n",
        "\n",
        "print(f\"models_path: {models_path}\")\n",
        "print(f\"output_path: {output_path}\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "CIUJ7lWI4v53"
      },
      "source": [
        "#@markdown **Select Model**\n",
        "print(\"\\nSelect Model:\\n\")\n",
        "\n",
        "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
        "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n",
        "custom_config_path = \"\" #@param {type:\"string\"}\n",
        "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n",
        "\n",
        "check_sha256 = True #@param {type:\"boolean\"}\n",
        "\n",
        "model_map = {\n",
        "    \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n",
        "    \"sd-v1-3-full-ema.ckpt\": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},\n",
        "    \"sd-v1-3.ckpt\": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},\n",
        "    \"sd-v1-2-full-ema.ckpt\": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},\n",
        "    \"sd-v1-2.ckpt\": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},\n",
        "    \"sd-v1-1-full-ema.ckpt\": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},\n",
        "    \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n",
        "}\n",
        "\n",
        "def wget(url, outputdir):\n",
        "    res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
        "    print(res)\n",
        "\n",
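        "# note: download_model expects a 'link' entry in model_map, which is not populated\n",
        "# above, so the call further down is left commented out (place the checkpoint manually)\n",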
"def download_model(model_checkpoint):\n",
        "    download_link = model_map[model_checkpoint][\"link\"][0]\n",
        "    print(f\"!wget -O {models_path}/{model_checkpoint} {download_link}\")\n",
        "    wget(download_link, models_path)\n",
        "    return\n",
        "\n",
        "# config path\n",
        "if os.path.exists(models_path+'/'+model_config):\n",
        "    print(f\"{models_path+'/'+model_config} exists\")\n",
        "else:\n",
        "    print(\"cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.\")\n",
        "    shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path)\n",
        "\n",
        "# checkpoint path or download\n",
        "if os.path.exists(models_path+'/'+model_checkpoint):\n",
        "    print(f\"{models_path+'/'+model_checkpoint} exists\")\n",
        "else:\n",
        "    print(f\"download model checkpoint and place in {models_path+'/'+model_checkpoint}\")\n",
        "    #download_model(model_checkpoint)\n",
        "\n",
        "if check_sha256:\n",
        "    import hashlib\n",
        "    print(\"\\n...checking sha256\")\n",
        "    with open(models_path+'/'+model_checkpoint, \"rb\") as f:\n",
        "        bytes = f.read()\n",
        "    hash = hashlib.sha256(bytes).hexdigest()\n",
        "    del bytes\n",
        "    if model_map[model_checkpoint][\"sha256\"] == hash:\n",
        "        print(\"hash is correct\\n\")\n",
        "    else:\n",
        "        print(\"hash is not correct\\n\")\n",
        "\n",
        "if model_config == \"custom\":\n",
        "    config = custom_config_path\n",
        "else:\n",
        "    config = models_path+'/'+model_config\n",
        "\n",
        "if model_checkpoint == \"custom\":\n",
        "    ckpt = custom_checkpoint_path\n",
        "else:\n",
        "    ckpt = models_path+'/'+model_checkpoint\n",
        "\n",
        "print(f\"config: {config}\")\n",
        "print(f\"ckpt: {ckpt}\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "IJiMgz_96nr3"
      },
      "source": [
        "#@markdown **Load Stable Diffusion**\n",
        "\n",
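        "# Load the Stable Diffusion checkpoint described by the config and cast it to\n",
        "# half precision to reduce VRAM usage.\n",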
"def load_model_from_config(config, ckpt, verbose=False, device='cuda'):\n",
        "    map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n",
        "    print(f\"Loading model from {ckpt}\")\n",
        "    pl_sd = torch.load(ckpt, map_location=map_location)\n",
        "    if \"global_step\" in pl_sd:\n",
        "        print(f\"Global Step: {pl_sd['global_step']}\")\n",
        "    sd = pl_sd[\"state_dict\"]\n",
        "    model = instantiate_from_config(config.model)\n",
        "    m, u = model.load_state_dict(sd, strict=False)\n",
        "    if len(m) > 0 and verbose:\n",
        "        print(\"missing keys:\")\n",
        "        print(m)\n",
        "    if len(u) > 0 and verbose:\n",
        "        print(\"unexpected keys:\")\n",
        "        print(u)\n",
        "\n",
        "    #model.cuda()\n",
        "    model = model.half().to(device)\n",
        "    model.eval()\n",
        "    return model\n",
        "\n",
        "load_on_run_all = True #@param {type: 'boolean'}\n",
        "\n",
        "if load_on_run_all:\n",
        "\n",
        "    local_config = OmegaConf.load(f\"{config}\")\n",
        "    model = load_model_from_config(local_config, f\"{ckpt}\")\n",
        "    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "    model = model.to(device)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ov3r4RD1tzsT"
      },
      "source": [
        "# **Run**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qH74gBWDd2oq"
      },
      "source": [
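        "# All run settings in one place; returned as a dict (locals()) and wrapped in a\n",
        "# SimpleNamespace below.\n",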
"def DeforumArgs():\n",
        "    #@markdown **Save & Display Settings**\n",
        "    batchdir = \"test\" #@param {type:\"string\"}\n",
        "    outdir = get_output_folder(output_path, batchdir)\n",
        "    save_grid = False\n",
        "    save_samples = True #@param {type:\"boolean\"}\n",
        "    save_settings = True #@param {type:\"boolean\"}\n",
        "    display_grid = False\n",
        "    display_samples = True #@param {type:\"boolean\"}\n",
        "\n",
        "    #@markdown **Image Settings**\n",
        "    n_samples = 1 #@param\n",
        "    n_rows = 1 #@param\n",
        "    W = 512 #@param\n",
        "    H = 576 #@param\n",
        "    W, H = map(lambda x: x - x % 64, (W, H))  # resize to integer multiple of 64\n",
        "\n",
        "    #@markdown **Init Settings**\n",
        "    use_init = False #@param {type:\"boolean\"}\n",
        "    init_image = \"/content/drive/MyDrive/AI/escape.jpg\" #@param {type:\"string\"}\n",
        "    strength = 0.5 #@param {type:\"number\"}\n",
        "\n",
        "    #@markdown **Sampling Settings**\n",
        "    seed = 1 #@param\n",
        "    # (set seed = -1 to draw a new random seed for every run in the batch)\n",
        "    sampler = 'euler_ancestral' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n",
        "    steps = 50 #@param\n",
        "    scale = 7 #@param\n",
        "    eta = 0.0 #@param\n",
        "    dynamic_threshold = None\n",
        "    static_threshold = None\n",
        "\n",
        "    #@markdown **Batch Settings**\n",
        "    n_batch = 2 #@param\n",
        "\n",
        "    precision = 'autocast'\n",
        "    fixed_code = True\n",
        "    C = 4\n",
        "    f = 8\n",
        "    prompts = globals()['prompts']\n",
        "    timestring = \"\"\n",
        "\n",
        "    return locals()\n",
        "\n",
        "args = SimpleNamespace(**DeforumArgs())"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2ujwkGZTcGev"
      },
      "source": [
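        "# prompts for this run; comment out a line to skip it\n",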
"prompts = [\n",
        "    \"a beautiful forest by Asher Brown Durand, trending on Artstation\",  # the first prompt I want\n",
        "    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\",  # the second prompt I want\n",
        "    #\"a third prompt that I don't want, so I commented it out\",\n",
        "]"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "cxx8BzxjiaXg"
      },
      "source": [
        "#@markdown **Run**\n",
        "args = SimpleNamespace(**DeforumArgs())\n",
        "args.filename = None\n",
        "args.prompts = prompts\n",
        "\n",
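        "# Writes the settings file once, then calls run() n_batch times; when seed == -1\n",
        "# a new random seed is drawn for every run.\n",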
"def do_batch_run():\n",
        "    # create output folder\n",
        "    os.makedirs(args.outdir, exist_ok=True)\n",
        "\n",
        "    # current timestring for filenames\n",
        "    args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
        "\n",
        "    # save settings for the batch\n",
        "    if args.save_settings:\n",
        "        filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n",
        "        with open(filename, \"w+\", encoding=\"utf-8\") as f:\n",
        "            json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n",
        "\n",
        "    for batch_index in range(args.n_batch):\n",
        "\n",
        "        # random seed\n",
        "        if args.seed == -1:\n",
        "            local_seed = np.random.randint(0,4294967295)\n",
        "        else:\n",
        "            local_seed = args.seed\n",
        "\n",
        "        print(f\"run {batch_index+1} of {args.n_batch}\")\n",
        "        run(args, local_seed)\n",
        "\n",
        "do_batch_run()"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Deforum_Stable_Diffusion.ipynb",
      "provenance": [],
      "private_outputs": true
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
} |