From aeab8d8a9d35dbba3780c76e4d4dcaccfb475954 Mon Sep 17 00:00:00 2001 From: deforum <110359942+deforum@users.noreply.github.com> Date: Tue, 30 Aug 2022 20:09:49 -0700 Subject: [PATCH] v0.2 --- Deforum_Stable_Diffusion.ipynb | 139 +++++++++++++++------------------ Deforum_Stable_Diffusion.py | 12 +-- 2 files changed, 68 insertions(+), 83 deletions(-) diff --git a/Deforum_Stable_Diffusion.ipynb b/Deforum_Stable_Diffusion.ipynb index 0bfff80e..4ed43234 100644 --- a/Deforum_Stable_Diffusion.ipynb +++ b/Deforum_Stable_Diffusion.ipynb @@ -12,6 +12,15 @@ "Notebook by [deforum](https://discord.gg/upmXXsrwZc)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "T4knibRpAQ06" + }, + "source": [ + "# Setup" + ] + }, { "cell_type": "code", "metadata": { @@ -34,7 +43,7 @@ "id": "TxIOPT0G5Lx1" }, "source": [ - "#@markdown **Model Path Variables**\n", + "#@markdown **Model and Output Paths**\n", "# ask for the link\n", "print(\"Local Path Variables:\\n\")\n", "\n", @@ -165,7 +174,7 @@ " return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n", "\n", "def get_output_folder(output_path, batch_folder):\n", - " out_path = os.path.join(output_path,time.strftime('%Y-%m/'))\n", + " out_path = os.path.join(output_path,time.strftime('%Y-%m'))\n", " if batch_folder != \"\":\n", " out_path = os.path.join(out_path, batch_folder)\n", " os.makedirs(out_path, exist_ok=True)\n", @@ -355,8 +364,7 @@ "id": "CIUJ7lWI4v53" }, "source": [ - "#@markdown **Select Model**\n", - "print(\"\\nSelect Model:\\n\")\n", + "#@markdown **Select and Load Model**\n", "\n", "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n", @@ -365,6 +373,9 @@ "\n", "check_sha256 = True #@param {type:\"boolean\"}\n", "\n", + "load_on_run_all = True #@param {type: 'boolean'}\n", + "half_precision = True # needs to be fixed\n", + "\n", "model_map = {\n", " \"sd-v1-4-full-ema.ckpt\": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},\n", " \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n", @@ -376,34 +387,27 @@ " \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n", "}\n", "\n", - "def wget(url, outputdir):\n", - " res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", - " print(res)\n", - "\n", - "def download_model(model_checkpoint):\n", - " download_link = model_map[model_checkpoint][\"link\"][0]\n", - " print(f\"!wget -O {models_path}/{model_checkpoint} {download_link}\")\n", - " wget(download_link, models_path)\n", - " return\n", - "\n", "# config path\n", - "if os.path.exists(models_path+'/'+model_config):\n", - " print(f\"{models_path+'/'+model_config} exists\")\n", + "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", + "if os.path.exists(ckpt_config_path):\n", + " print(f\"{ckpt_config_path} exists\")\n", "else:\n", - " print(\"cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.\")\n", - " shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path)\n", + " ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", + "print(f\"Using config: {ckpt_config_path}\")\n", "\n", "# checkpoint path or download\n", - "if os.path.exists(models_path+'/'+model_checkpoint):\n", - " print(f\"{models_path+'/'+model_checkpoint} exists\")\n", + "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", + "ckpt_valid = True\n", + "if os.path.exists(ckpt_path):\n", + " print(f\"{ckpt_path} exists\")\n", "else:\n", - " print(f\"download model checkpoint and place in {models_path+'/'+model_checkpoint}\")\n", - " #download_model(model_checkpoint)\n", + " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n", + " ckpt_valid = False\n", "\n", - "if check_sha256 and model_checkpoint != \"custom\":\n", + "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n", " import hashlib\n", " print(\"\\n...checking sha256\")\n", - " with open(models_path+'/'+model_checkpoint, \"rb\") as f:\n", + " with open(ckpt_path, \"rb\") as f:\n", " bytes = f.read() \n", " hash = hashlib.sha256(bytes).hexdigest()\n", " del bytes\n", @@ -411,31 +415,10 @@ " print(\"hash is correct\\n\")\n", " else:\n", " print(\"hash in not correct\\n\")\n", + " ckpt_valid = False\n", "\n", - "if model_config == \"custom\":\n", - " config = custom_config_path\n", - "else:\n", - " config = models_path+'/'+model_config\n", - "\n", - "if model_checkpoint == \"custom\":\n", - " ckpt = custom_checkpoint_path\n", - "else:\n", - " ckpt = models_path+'/'+model_checkpoint\n", - "\n", - "print(f\"config: {config}\")\n", - "print(f\"ckpt: {ckpt}\")" - ], - "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "IJiMgz_96nr3" - }, - "source": [ - "#@markdown **Load Stable Diffusion**\n", + "if ckpt_valid:\n", + " print(f\"Using ckpt: {ckpt_path}\")\n", "\n", "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", @@ -453,7 +436,6 @@ " print(\"unexpected keys:\")\n", " print(u)\n", "\n", - " #model.cuda()\n", " if half_precision:\n", " model = model.half().to(device)\n", " else:\n", @@ -461,15 +443,11 @@ " model.eval()\n", " return model\n", "\n", - "load_on_run_all = True #@param {type: 'boolean'}\n", - "half_precision = True # needs to be fixed\n", - "\n", - "if load_on_run_all:\n", - "\n", - " local_config = OmegaConf.load(f\"{config}\")\n", - " model = load_model_from_config(local_config, f\"{ckpt}\",half_precision=half_precision)\n", - " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", - " model = model.to(device)" + "if load_on_run_all and ckpt_valid:\n", + " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n", + " model = load_model_from_config(local_config, f\"{ckpt_path}\",half_precision=half_precision)\n", + " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " model = model.to(device)" ], "outputs": [], "execution_count": null @@ -516,10 +494,10 @@ " translation_y = \"0: (0)\"#@param {type:\"string\"}\n", " noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n", " strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n", - " scale_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n", + " contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n", "\n", " #@markdown ####**Coherence:**\n", - " color_coherence = 'Match Frame 0 HSV' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", + " color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", "\n", " #@markdown ####**Video Input:**\n", " video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", @@ -591,7 +569,7 @@ " translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))\n", " noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))\n", " strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))\n", - " scale_schedule_series = get_inbetweens(parse_key_frames(anim_args.scale_schedule))" + " contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule))" ], "outputs": [], "execution_count": null @@ -649,7 +627,6 @@ " #@markdown **Save & Display Settings**\n", " batch_name = \"StableFun\" #@param {type:\"string\"}\n", " outdir = get_output_folder(output_path, batch_name)\n", - " save_grid = False\n", " save_settings = True #@param {type:\"boolean\"}\n", " save_samples = True #@param {type:\"boolean\"}\n", " display_samples = True #@param {type:\"boolean\"}\n", @@ -695,15 +672,6 @@ "\n", " return locals()\n", "\n", - "def next_seed(args):\n", - " if args.seed_behavior == 'iter':\n", - " args.seed += 1\n", - " elif args.seed_behavior == 'fixed':\n", - " pass # always keep seed the same\n", - " else:\n", - " args.seed = random.randint(0, 2**32)\n", - " return args.seed\n", - "\n", "\n", "args = SimpleNamespace(**DeforumArgs())\n", "args.timestring = time.strftime('%Y%m%d%H%M%S')\n", @@ -724,6 +692,15 @@ " args.ddim_eta = 0\n", "\n", "\n", + "def next_seed(args):\n", + " if args.seed_behavior == 'iter':\n", + " args.seed += 1\n", + " elif args.seed_behavior == 'fixed':\n", + " pass # always keep seed the same\n", + " else:\n", + " args.seed = random.randint(0, 2**32)\n", + " return args.seed\n", + "\n", "def render_image_batch(args):\n", " args.prompts = prompts\n", " \n", @@ -758,12 +735,17 @@ " else:\n", " init_array = [\"\"]\n", "\n", + " # when doing large batches don't flood browser with images\n", + " clear_between_batches = args.n_batch >= 32\n", + "\n", " for iprompt, prompt in enumerate(prompts): \n", " args.prompt = prompt\n", "\n", " all_images = []\n", "\n", " for batch_index in range(args.n_batch):\n", + " if clear_between_batches: \n", + " display.clear_output(wait=True) \n", " print(f\"Batch {batch_index+1} of {args.n_batch}\")\n", " \n", " for image in init_array: # iterates the init images\n", @@ -785,7 +767,10 @@ " grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n", " grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", " filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n", - " Image.fromarray(grid.astype(np.uint8)).save(os.path.join(args.outdir, filename))\n", + " grid_image = Image.fromarray(grid.astype(np.uint8))\n", + " grid_image.save(os.path.join(args.outdir, filename))\n", + " display.clear_output(wait=True) \n", + " display.display(grid_image)\n", "\n", "\n", "def render_animation(args, anim_args):\n", @@ -845,7 +830,7 @@ " translation_y = translation_y_series[frame_idx]\n", " noise = noise_schedule_series[frame_idx]\n", " strength = strength_schedule_series[frame_idx]\n", - " scale = scale_schedule_series[frame_idx]\n", + " contrast = contrast_schedule_series[frame_idx]\n", " print(\n", " f'angle: {angle}',\n", " f'zoom: {zoom}',\n", @@ -853,7 +838,7 @@ " f'translation_y: {translation_y}',\n", " f'noise: {noise}',\n", " f'strength: {strength}',\n", - " f'scale: {scale}',\n", + " f'contrast: {contrast}',\n", " )\n", " xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom)\n", "\n", @@ -874,9 +859,9 @@ " prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n", "\n", " # apply scaling\n", - " scaled_sample = prev_img * scale\n", + " contrast_sample = prev_img * contrast\n", " # apply frame noising\n", - " noised_sample = add_noise(sample_from_cv2(scaled_sample), noise)\n", + " noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)\n", "\n", " # use transformed previous frame as init for current\n", " args.use_init = True\n", @@ -1126,4 +1111,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/Deforum_Stable_Diffusion.py b/Deforum_Stable_Diffusion.py index 6060f3ed..2a495985 100644 --- a/Deforum_Stable_Diffusion.py +++ b/Deforum_Stable_Diffusion.py @@ -462,7 +462,7 @@ def DeforumAnimArgs(): translation_y = "0: (0)"#@param {type:"string"} noise_schedule = "0: (0.02)"#@param {type:"string"} strength_schedule = "0: (0.65)"#@param {type:"string"} - scale_schedule = "0: (1.0)"#@param {type:"string"} + contrast_schedule = "0: (1.0)"#@param {type:"string"} #@markdown ####**Coherence:** color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} @@ -537,7 +537,7 @@ if anim_args.key_frames: translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y)) noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule)) strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule)) - scale_schedule_series = get_inbetweens(parse_key_frames(anim_args.scale_schedule)) + contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule)) # %% # !! {"metadata":{ @@ -786,7 +786,7 @@ def render_animation(args, anim_args): translation_y = translation_y_series[frame_idx] noise = noise_schedule_series[frame_idx] strength = strength_schedule_series[frame_idx] - scale = scale_schedule_series[frame_idx] + contrast = contrast_schedule_series[frame_idx] print( f'angle: {angle}', f'zoom: {zoom}', @@ -794,7 +794,7 @@ def render_animation(args, anim_args): f'translation_y: {translation_y}', f'noise: {noise}', f'strength: {strength}', - f'scale: {scale}', + f'contrast: {contrast}', ) xform = make_xform_2d(args.W, args.H, translation_x, translation_y, angle, zoom) @@ -815,9 +815,9 @@ def render_animation(args, anim_args): prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) # apply scaling - scaled_sample = prev_img * scale + contrast_sample = prev_img * contrast # apply frame noising - noised_sample = add_noise(sample_from_cv2(scaled_sample), noise) + noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) # use transformed previous frame as init for current args.use_init = True