border
deforum 2022-08-30 20:09:49 -07:00 committed by GitHub
parent d58294e65e
commit aeab8d8a9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 83 deletions

View File

@ -12,6 +12,15 @@
"Notebook by [deforum](https://discord.gg/upmXXsrwZc)" "Notebook by [deforum](https://discord.gg/upmXXsrwZc)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "T4knibRpAQ06"
},
"source": [
"# Setup"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
@ -34,7 +43,7 @@
"id": "TxIOPT0G5Lx1" "id": "TxIOPT0G5Lx1"
}, },
"source": [ "source": [
"#@markdown **Model Path Variables**\n", "#@markdown **Model and Output Paths**\n",
"# ask for the link\n", "# ask for the link\n",
"print(\"Local Path Variables:\\n\")\n", "print(\"Local Path Variables:\\n\")\n",
"\n", "\n",
@ -165,7 +174,7 @@
" return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n", " return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
"\n", "\n",
"def get_output_folder(output_path, batch_folder):\n", "def get_output_folder(output_path, batch_folder):\n",
" out_path = os.path.join(output_path,time.strftime('%Y-%m/'))\n", " out_path = os.path.join(output_path,time.strftime('%Y-%m'))\n",
" if batch_folder != \"\":\n", " if batch_folder != \"\":\n",
" out_path = os.path.join(out_path, batch_folder)\n", " out_path = os.path.join(out_path, batch_folder)\n",
" os.makedirs(out_path, exist_ok=True)\n", " os.makedirs(out_path, exist_ok=True)\n",
@ -355,8 +364,7 @@
"id": "CIUJ7lWI4v53" "id": "CIUJ7lWI4v53"
}, },
"source": [ "source": [
"#@markdown **Select Model**\n", "#@markdown **Select and Load Model**\n",
"print(\"\\nSelect Model:\\n\")\n",
"\n", "\n",
"model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
"model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n", "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n",
@ -365,6 +373,9 @@
"\n", "\n",
"check_sha256 = True #@param {type:\"boolean\"}\n", "check_sha256 = True #@param {type:\"boolean\"}\n",
"\n", "\n",
"load_on_run_all = True #@param {type: 'boolean'}\n",
"half_precision = True # needs to be fixed\n",
"\n",
"model_map = {\n", "model_map = {\n",
" \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n", " \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n",
" \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n", " \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n",
@ -376,34 +387,27 @@
" \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n", " \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n",
"}\n", "}\n",
"\n", "\n",
"def wget(url, outputdir):\n",
" res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
" print(res)\n",
"\n",
"def download_model(model_checkpoint):\n",
" download_link = model_map[model_checkpoint][\"link\"][0]\n",
" print(f\"!wget -O {models_path}/{model_checkpoint} {download_link}\")\n",
" wget(download_link, models_path)\n",
" return\n",
"\n",
"# config path\n", "# config path\n",
"if os.path.exists(models_path+'/'+model_config):\n", "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n",
" print(f\"{models_path+'/'+model_config} exists\")\n", "if os.path.exists(ckpt_config_path):\n",
" print(f\"{ckpt_config_path} exists\")\n",
"else:\n", "else:\n",
" print(\"cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.\")\n", " ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n",
" shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path)\n", "print(f\"Using config: {ckpt_config_path}\")\n",
"\n", "\n",
"# checkpoint path or download\n", "# checkpoint path or download\n",
"if os.path.exists(models_path+'/'+model_checkpoint):\n", "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n",
" print(f\"{models_path+'/'+model_checkpoint} exists\")\n", "ckpt_valid = True\n",
"if os.path.exists(ckpt_path):\n",
" print(f\"{ckpt_path} exists\")\n",
"else:\n", "else:\n",
" print(f\"download model checkpoint and place in {models_path+'/'+model_checkpoint}\")\n", " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n",
" #download_model(model_checkpoint)\n", " ckpt_valid = False\n",
"\n", "\n",
"if check_sha256 and model_checkpoint != \"custom\":\n", "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n",
" import hashlib\n", " import hashlib\n",
" print(\"\\n...checking sha256\")\n", " print(\"\\n...checking sha256\")\n",
" with open(models_path+'/'+model_checkpoint, \"rb\") as f:\n", " with open(ckpt_path, \"rb\") as f:\n",
" bytes = f.read() \n", " bytes = f.read() \n",
" hash = hashlib.sha256(bytes).hexdigest()\n", " hash = hashlib.sha256(bytes).hexdigest()\n",
" del bytes\n", " del bytes\n",
@ -411,31 +415,10 @@
" print(\"hash is correct\\n\")\n", " print(\"hash is correct\\n\")\n",
" else:\n", " else:\n",
" print(\"hash in not correct\\n\")\n", " print(\"hash in not correct\\n\")\n",
" ckpt_valid = False\n",
"\n", "\n",
"if model_config == \"custom\":\n", "if ckpt_valid:\n",
" config = custom_config_path\n", " print(f\"Using ckpt: {ckpt_path}\")\n",
"else:\n",
" config = models_path+'/'+model_config\n",
"\n",
"if model_checkpoint == \"custom\":\n",
" ckpt = custom_checkpoint_path\n",
"else:\n",
" ckpt = models_path+'/'+model_checkpoint\n",
"\n",
"print(f\"config: {config}\")\n",
"print(f\"ckpt: {ckpt}\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "IJiMgz_96nr3"
},
"source": [
"#@markdown **Load Stable Diffusion**\n",
"\n", "\n",
"def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n",
" map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n",
@ -453,7 +436,6 @@
" print(\"unexpected keys:\")\n", " print(\"unexpected keys:\")\n",
" print(u)\n", " print(u)\n",
"\n", "\n",
" #model.cuda()\n",
" if half_precision:\n", " if half_precision:\n",
" model = model.half().to(device)\n", " model = model.half().to(device)\n",
" else:\n", " else:\n",
@ -461,13 +443,9 @@
" model.eval()\n", " model.eval()\n",
" return model\n", " return model\n",
"\n", "\n",
"load_on_run_all = True #@param {type: 'boolean'}\n", "if load_on_run_all and ckpt_valid:\n",
"half_precision = True # needs to be fixed\n", " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n",
"\n", " model = load_model_from_config(local_config, f\"{ckpt_path}\",half_precision=half_precision)\n",
"if load_on_run_all:\n",
"\n",
" local_config = OmegaConf.load(f\"{config}\")\n",
" model = load_model_from_config(local_config, f\"{ckpt}\",half_precision=half_precision)\n",
" device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
" model = model.to(device)" " model = model.to(device)"
], ],
@ -516,10 +494,10 @@
" translation_y = \"0: (0)\"#@param {type:\"string\"}\n", " translation_y = \"0: (0)\"#@param {type:\"string\"}\n",
" noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n", " noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n",
" strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n", " strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n",
" scale_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n", " contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
"\n", "\n",
" #@markdown ####**Coherence:**\n", " #@markdown ####**Coherence:**\n",
" color_coherence = 'Match Frame 0 HSV' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", " color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
"\n", "\n",
" #@markdown ####**Video Input:**\n", " #@markdown ####**Video Input:**\n",
" video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", " video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
@ -591,7 +569,7 @@
" translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))\n", " translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))\n",
" noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))\n", " noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))\n",
" strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))\n", " strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))\n",
" scale_schedule_series = get_inbetweens(parse_key_frames(anim_args.scale_schedule))" " contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule))"
], ],
"outputs": [], "outputs": [],
"execution_count": null "execution_count": null
@ -649,7 +627,6 @@
" #@markdown **Save & Display Settings**\n", " #@markdown **Save & Display Settings**\n",
" batch_name = \"StableFun\" #@param {type:\"string\"}\n", " batch_name = \"StableFun\" #@param {type:\"string\"}\n",
" outdir = get_output_folder(output_path, batch_name)\n", " outdir = get_output_folder(output_path, batch_name)\n",
" save_grid = False\n",
" save_settings = True #@param {type:\"boolean\"}\n", " save_settings = True #@param {type:\"boolean\"}\n",
" save_samples = True #@param {type:\"boolean\"}\n", " save_samples = True #@param {type:\"boolean\"}\n",
" display_samples = True #@param {type:\"boolean\"}\n", " display_samples = True #@param {type:\"boolean\"}\n",
@ -695,15 +672,6 @@
"\n", "\n",
" return locals()\n", " return locals()\n",
"\n", "\n",
"def next_seed(args):\n",
" if args.seed_behavior == 'iter':\n",
" args.seed += 1\n",
" elif args.seed_behavior == 'fixed':\n",
" pass # always keep seed the same\n",
" else:\n",
" args.seed = random.randint(0, 2**32)\n",
" return args.seed\n",
"\n",
"\n", "\n",
"args = SimpleNamespace(**DeforumArgs())\n", "args = SimpleNamespace(**DeforumArgs())\n",
"args.timestring = time.strftime('%Y%m%d%H%M%S')\n", "args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
@ -724,6 +692,15 @@
" args.ddim_eta = 0\n", " args.ddim_eta = 0\n",
"\n", "\n",
"\n", "\n",
"def next_seed(args):\n",
" if args.seed_behavior == 'iter':\n",
" args.seed += 1\n",
" elif args.seed_behavior == 'fixed':\n",
" pass # always keep seed the same\n",
" else:\n",
" args.seed = random.randint(0, 2**32)\n",
" return args.seed\n",
"\n",
"def render_image_batch(args):\n", "def render_image_batch(args):\n",
" args.prompts = prompts\n", " args.prompts = prompts\n",
" \n", " \n",
@ -758,12 +735,17 @@
" else:\n", " else:\n",
" init_array = [\"\"]\n", " init_array = [\"\"]\n",
"\n", "\n",
" # when doing large batches don't flood browser with images\n",
" clear_between_batches = args.n_batch >= 32\n",
"\n",
" for iprompt, prompt in enumerate(prompts): \n", " for iprompt, prompt in enumerate(prompts): \n",
" args.prompt = prompt\n", " args.prompt = prompt\n",
"\n", "\n",
" all_images = []\n", " all_images = []\n",
"\n", "\n",
" for batch_index in range(args.n_batch):\n", " for batch_index in range(args.n_batch):\n",
" if clear_between_batches: \n",
" display.clear_output(wait=True) \n",
" print(f\"Batch {batch_index+1} of {args.n_batch}\")\n", " print(f\"Batch {batch_index+1} of {args.n_batch}\")\n",
" \n", " \n",
" for image in init_array: # iterates the init images\n", " for image in init_array: # iterates the init images\n",
@ -785,7 +767,10 @@
" grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n", " grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n",
" grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", " grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
" filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n", " filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n",
" Image.fromarray(grid.astype(np.uint8)).save(os.path.join(args.outdir, filename))\n", " grid_image = Image.fromarray(grid.astype(np.uint8))\n",
" grid_image.save(os.path.join(args.outdir, filename))\n",
" display.clear_output(wait=True) \n",
" display.display(grid_image)\n",
"\n", "\n",
"\n", "\n",
"def render_animation(args, anim_args):\n", "def render_animation(args, anim_args):\n",
@ -845,7 +830,7 @@
" translation_y = translation_y_series[frame_idx]\n", " translation_y = translation_y_series[frame_idx]\n",
" noise = noise_schedule_series[frame_idx]\n", " noise = noise_schedule_series[frame_idx]\n",
" strength = strength_schedule_series[frame_idx]\n", " strength = strength_schedule_series[frame_idx]\n",
" scale = scale_schedule_series[frame_idx]\n", " contrast = contrast_schedule_series[frame_idx]\n",
" print(\n", " print(\n",
" f'angle: {angle}',\n", " f'angle: {angle}',\n",
" f'zoom: {zoom}',\n", " f'zoom: {zoom}',\n",
@ -853,7 +838,7 @@
" f'translation_y: {translation_y}',\n", " f'translation_y: {translation_y}',\n",
" f'noise: {noise}',\n", " f'noise: {noise}',\n",
" f'strength: {strength}',\n", " f'strength: {strength}',\n",
" f'scale: {scale}',\n", " f'contrast: {contrast}',\n",
" )\n", " )\n",
" xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)\n", " xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)\n",
"\n", "\n",
@ -874,9 +859,9 @@
" prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n", " prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n",
"\n", "\n",
" # apply scaling\n", " # apply scaling\n",
" scaled_sample = prev_img * scale\n", " contrast_sample = prev_img * contrast\n",
" # apply frame noising\n", " # apply frame noising\n",
" noised_sample = add_noise(sample_from_cv2(scaled_sample), noise)\n", " noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)\n",
"\n", "\n",
" # use transformed previous frame as init for current\n", " # use transformed previous frame as init for current\n",
" args.use_init = True\n", " args.use_init = True\n",

View File

@ -462,7 +462,7 @@ def DeforumAnimArgs():
translation_y = "0: (0)"#@param {type:"string"} translation_y = "0: (0)"#@param {type:"string"}
noise_schedule = "0: (0.02)"#@param {type:"string"} noise_schedule = "0: (0.02)"#@param {type:"string"}
strength_schedule = "0: (0.65)"#@param {type:"string"} strength_schedule = "0: (0.65)"#@param {type:"string"}
scale_schedule = "0: (1.0)"#@param {type:"string"} contrast_schedule = "0: (1.0)"#@param {type:"string"}
#@markdown ####**Coherence:** #@markdown ####**Coherence:**
color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
@ -537,7 +537,7 @@ if anim_args.key_frames:
translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y)) translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))
noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule)) noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))
strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule)) strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))
scale_schedule_series = get_inbetweens(parse_key_frames(anim_args.scale_schedule)) contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule))
# %% # %%
# !! {"metadata":{ # !! {"metadata":{
@ -786,7 +786,7 @@ def render_animation(args, anim_args):
translation_y = translation_y_series[frame_idx] translation_y = translation_y_series[frame_idx]
noise = noise_schedule_series[frame_idx] noise = noise_schedule_series[frame_idx]
strength = strength_schedule_series[frame_idx] strength = strength_schedule_series[frame_idx]
scale = scale_schedule_series[frame_idx] contrast = contrast_schedule_series[frame_idx]
print( print(
f'angle: {angle}', f'angle: {angle}',
f'zoom: {zoom}', f'zoom: {zoom}',
@ -794,7 +794,7 @@ def render_animation(args, anim_args):
f'translation_y: {translation_y}', f'translation_y: {translation_y}',
f'noise: {noise}', f'noise: {noise}',
f'strength: {strength}', f'strength: {strength}',
f'scale: {scale}', f'contrast: {contrast}',
) )
xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom) xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)
@ -815,9 +815,9 @@ def render_animation(args, anim_args):
prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)
# apply scaling # apply scaling
scaled_sample = prev_img * scale contrast_sample = prev_img * contrast
# apply frame noising # apply frame noising
noised_sample = add_noise(sample_from_cv2(scaled_sample), noise) noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)
# use transformed previous frame as init for current # use transformed previous frame as init for current
args.use_init = True args.use_init = True