pull/87/head v0.7
Alexey Borsky 2023-05-07 22:32:30 +03:00
parent a553de32db
commit 4adcaf026b
9 changed files with 997 additions and 562 deletions

143
FloweR/model.py Normal file
View File

@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.functional as F
# Define the model
class FloweR(nn.Module):
    """U-Net style network that maps a short clip of frames to a 3-channel map.

    The input is a tensor laid out as (batch, frames, height, width, colors)
    with exactly ``window_size`` RGB frames of size ``input_size``.  The output
    is laid out as (batch, height, width, 3); the original author annotates
    the three output channels as "(2 + 1)".

    Eight convolution blocks progressively halve the resolution (seven
    ``AvgPool2d(2)`` steps), seven blocks upsample it back, and every
    upsampling block is summed with the encoder output of the same
    resolution.  Because of the seven halvings the skip connections only line
    up when both spatial dimensions are multiples of 128.

    Args:
        input_size: ``(height, width)`` of the frames the model will be fed.
        window_size: Number of consecutive frames stacked along the channel
            axis.
    """

    def __init__(self, input_size = (384, 384), window_size = 4):
        super().__init__()

        self.input_size = input_size
        self.window_size = window_size

        # INPUT: H x W x (window_size * 3)
        # Shape comments below assume the default 384 x 384 input.

        ### DOWNSCALE ###
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(3 * self.window_size, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 384 x 384 x 128

        # NOTE: attribute names and creation order are kept exactly as before
        # so that existing checkpoints (state_dict keys) keep loading.
        self.conv_block_2 = self._down_block() # 192 x 192 x 128
        self.conv_block_3 = self._down_block() # 96 x 96 x 128
        self.conv_block_4 = self._down_block() # 48 x 48 x 128
        self.conv_block_5 = self._down_block() # 24 x 24 x 128
        self.conv_block_6 = self._down_block() # 12 x 12 x 128
        self.conv_block_7 = self._down_block() # 6 x 6 x 128
        self.conv_block_8 = self._down_block() # 3 x 3 x 128

        ### UPSCALE ###
        self.conv_block_9 = self._up_block() # 6 x 6 x 128
        self.conv_block_10 = self._up_block() # 12 x 12 x 128
        self.conv_block_11 = self._up_block() # 24 x 24 x 128
        self.conv_block_12 = self._up_block() # 48 x 48 x 128
        self.conv_block_13 = self._up_block() # 96 x 96 x 128
        self.conv_block_14 = self._up_block() # 192 x 192 x 128
        self.conv_block_15 = self._up_block() # 384 x 384 x 128

        self.conv_block_16 = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding='same')

    @staticmethod
    def _down_block():
        """Return one encoder stage: halve the resolution, then conv + ReLU."""
        return nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )

    @staticmethod
    def _up_block():
        """Return one decoder stage: double the resolution, then conv + ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the network on a clip of frames.

        Args:
            x: Tensor laid out as (batch, frames, height, width, colors) with
                ``frames == window_size``, ``(height, width) == input_size``
                and 3 color channels.

        Returns:
            Tensor laid out as (batch, height, width, 3).

        Raises:
            ValueError: If ``x`` does not have the layout described above.
        """
        if x.dim() != 5:
            raise ValueError(f'Expected a 5D input (batch, frames, height, width, colors), got a {x.dim()}D tensor.')

        if x.size(1) != self.window_size:
            raise ValueError(f'Shape of the input is not compatable. There should be exactly {self.window_size} frames in an input video.')

        # Without this check the reshape below (which uses -1 for the batch
        # axis) can silently fold a wrongly sized clip into extra batch items.
        expected_tail = (self.input_size[0], self.input_size[1], 3)
        if tuple(x.shape[2:]) != expected_tail:
            raise ValueError(f'Expected frames of shape {expected_tail} (height, width, colors), got {tuple(x.shape[2:])}.')

        # batch, frames, height, width, colors
        in_x = x.permute((0, 1, 4, 2, 3))
        # batch, frames, colors, height, width

        # Stack all frames along the channel axis.
        in_x = in_x.reshape(-1, self.window_size * 3, self.input_size[0], self.input_size[1])

        ### DOWNSCALE ###
        block_1_out = self.conv_block_1(in_x) # 384 x 384 x 128
        block_2_out = self.conv_block_2(block_1_out) # 192 x 192 x 128
        block_3_out = self.conv_block_3(block_2_out) # 96 x 96 x 128
        block_4_out = self.conv_block_4(block_3_out) # 48 x 48 x 128
        block_5_out = self.conv_block_5(block_4_out) # 24 x 24 x 128
        block_6_out = self.conv_block_6(block_5_out) # 12 x 12 x 128
        block_7_out = self.conv_block_7(block_6_out) # 6 x 6 x 128
        block_8_out = self.conv_block_8(block_7_out) # 3 x 3 x 128

        ### UPSCALE ###
        # Each decoder stage is added to the encoder output of equal resolution.
        block_9_out = block_7_out + self.conv_block_9(block_8_out) # 6 x 6 x 128
        block_10_out = block_6_out + self.conv_block_10(block_9_out) # 12 x 12 x 128
        block_11_out = block_5_out + self.conv_block_11(block_10_out) # 24 x 24 x 128
        block_12_out = block_4_out + self.conv_block_12(block_11_out) # 48 x 48 x 128
        block_13_out = block_3_out + self.conv_block_13(block_12_out) # 96 x 96 x 128
        block_14_out = block_2_out + self.conv_block_14(block_13_out) # 192 x 192 x 128
        block_15_out = block_1_out + self.conv_block_15(block_14_out) # 384 x 384 x 128

        block_16_out = self.conv_block_16(block_15_out) # 384 x 384 x (2 + 1)
        out = block_16_out.reshape(-1, 3, self.input_size[0], self.input_size[1])

        # batch, colors, height, width
        out = out.permute((0, 2, 3, 1))
        # batch, height, width, colors
        return out

BIN
examples/ui_preview.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 631 KiB

View File

@ -1,6 +1,10 @@
# SD-CN-Animation
This project allows you to automate the video stylization task using Stable Diffusion and ControlNet. In contrast to other current text2video methods, it also allows you to generate completely new videos from text at any resolution and length, using any Stable Diffusion model as a backbone, including custom ones. It uses the '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and create an occlusion mask that is used to generate the next frame. In text-to-video mode it relies on the 'FloweR' method (work in progress), which predicts optical flow from the previous frames.
![sd-cn-animation ui preview](examples/ui_preview.png)
sd-cn-animation ui preview
### Video to Video Examples:
</table>
<table class="center">
@ -46,11 +50,10 @@ Examples presented are generated at 1024x576 resolution using the 'realisticVisi
All examples you can see here are originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. Actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", only the 'subject' part is described in the table above.
## Installing the extension
To install the extension, go to the 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to the 'Install from URL' tab. In the 'URL for extension's git repository' field, enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave the 'Local directory name' field empty. Then just press the 'Install' button. Download the RAFT 'raft-things.pth' model from here: [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT) and place it into the 'stable-diffusion-webui/models/RAFT/' folder. Restart web-ui; a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
## Last version changes: v0.6
* Complete rewrite of the project to make it possible to install as an Automatic1111/Web-ui extension.
* Added flow normalization before resizing it, so the magnitude of the flow computed correctly at the different resolution.
* Less ghosting and color drift in vid2vid mode
* Added "warped styled frame fix" at vid2vid mode that removes duplicates from the parts of the image that cannot be relocated from the optical flow.
To install the extension, go to the 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to the 'Install from URL' tab. In the 'URL for extension's git repository' field, enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave the 'Local directory name' field empty. Then just press the 'Install' button. Restart web-ui; a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
## Last version changes: v0.7
* Text to Video mode added to the extension
* 'Generate' button is now automatically disabled while the video is generated
* Added an 'Interrupt' button that allows you to stop the video generation process
* Now all necessary models are automatically downloaded. No need for manual preparation.

View File

@ -13,6 +13,7 @@ for basedir in basedirs:
sys.path.extend([scripts_path_fix])
import gradio as gr
import modules
from types import SimpleNamespace
from modules import script_callbacks, shared
@ -26,8 +27,7 @@ import modules.scripts as scripts
from modules.sd_samplers import samplers_for_img2img
from modules.ui import setup_progressbar, create_sampler_and_steps_selection, ordered_ui_categories, create_output_panel
from vid2vid import *
from core import vid2vid, txt2vid, utils
def V2VArgs():
seed = -1
@ -75,41 +75,59 @@ def setup_common_values(mode, d):
def inputs_ui():
v2v_args = SimpleNamespace(**V2VArgs())
t2v_args = SimpleNamespace(**T2VArgs())
with gr.Tab('vid2vid') as tab_vid2vid:
with gr.Row():
gr.HTML('Put your video here')
with gr.Row():
vid2vid_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
#init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", image_mode="RGBA")
#with gr.Row():
# gr.HTML('Alternative: enter the relative (to the webui) path to the file')
#with gr.Row():
# vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
with gr.Tabs():
sdcn_process_mode = gr.State(value='vid2vid')
width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength = setup_common_values('vid2vid', v2v_args)
with gr.Tab('vid2vid') as tab_vid2vid:
with gr.Row():
gr.HTML('Put your video here')
with gr.Row():
v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
#init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", image_mode="RGBA")
#with gr.Row():
# gr.HTML('Alternative: enter the relative (to the webui) path to the file')
#with gr.Row():
# vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
with FormRow(elem_id=f"sampler_selection_v2v"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"v2v_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"v2v_steps", label="Sampling steps", value=15)
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength = setup_common_values('vid2vid', v2v_args)
with FormRow(elem_id="vid2vid_override_settings_row") as row:
override_settings = create_override_settings_dropdown("vid2vid", row)
with FormRow(elem_id=f"sampler_selection_v2v"):
v2v_sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"v2v_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
v2v_steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"v2v_steps", label="Sampling steps", value=15)
with FormGroup(elem_id=f"script_container"):
custom_inputs = scripts.scripts_img2img.setup_ui()
#with gr.Row():
# strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
# vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
with FormRow(elem_id="vid2vid_override_settings_row") as row:
v2v_override_settings = create_override_settings_dropdown("vid2vid", row)
with FormGroup(elem_id=f"script_container"):
v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
#with gr.Row():
# strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
# vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
with gr.Tab('txt2vid') as tab_txt2vid:
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength = setup_common_values('txt2vid', t2v_args)
with gr.Row():
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
with gr.Tab('txt2vid') as tab_txt2vid:
gr.Markdown('Work in progress...')
# width, height, prompt, n_prompt, steps, cfg_scale, seed, processing_strength, fix_frame_strength = setup_common_values('txt2vid', t2v_args)
#with gr.Tab('settings') as tab_setts:
# gr.Markdown('Work in progress...')
tab_vid2vid.select(fn=lambda: 'vid2vid', inputs=[], outputs=[sdcn_process_mode])
tab_txt2vid.select(fn=lambda: 'txt2vid', inputs=[], outputs=[sdcn_process_mode])
return locals()
def process(*args):
    """Forward a 'Generate' click to the pipeline selected by ``args[0]``.

    ``args[0]`` is the processing mode ('vid2vid' or 'txt2vid'); the complete
    argument tuple is handed unchanged to the matching ``start_process``
    generator and everything it yields is re-yielded.

    Raises:
        Exception: If the mode is neither 'vid2vid' nor 'txt2vid'.
    """
    mode = args[0]

    if mode == 'vid2vid':
        yield from vid2vid.start_process(*args)
        return

    if mode == 'txt2vid':
        yield from txt2vid.start_process(*args)
        return

    raise Exception(f"Unsupported processing mode: '{args[0]}'")
def stop_process(*args):
    """Handle a click on the 'Interrupt' button.

    Sets the ``is_interrupted`` flag that the generation loop polls, and
    returns a button update that disables the button that was clicked.
    The positional arguments are accepted but not used.
    """
    # Signal the running generator to stop.
    utils.shared.is_interrupted = True

    disabled_button = gr.Button.update(interactive=False)
    return disabled_button
def on_ui_tabs():
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
@ -118,11 +136,8 @@ def on_ui_tabs():
components = {}
#dv = SimpleNamespace(**T2VOutputArgs())
with gr.Row(elem_id='v2v-core').style(equal_height=False, variant='compact'):
with gr.Row(elem_id='sdcn-core').style(equal_height=False, variant='compact'):
with gr.Column(scale=1, variant='panel'):
with gr.Row(variant='compact'):
run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
with gr.Tabs():
components = inputs_ui()
@ -139,6 +154,10 @@ def on_ui_tabs():
# custom_inputs = scripts.scripts_img2img.setup_ui()
with gr.Column(scale=1, variant='compact'):
with gr.Row(variant='compact'):
run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False)
with gr.Column(variant="panel"):
sp_progress = gr.HTML(elem_id="sp_progress", value="")
sp_progress.update()
@ -157,52 +176,9 @@ def on_ui_tabs():
with gr.Row(variant='compact'):
dummy_component = gr.Label(visible=False)
# Define parameters for the action methods. Not all of them are included yet
method_inputs = [
dummy_component, # send None for task_id
dummy_component, # mode
components['prompt'], # prompt
components['n_prompt'], # negative_prompt
dummy_component, # prompt_styles
components['vid2vid_file'], # input_video
dummy_component, # sketch
dummy_component, # init_img_with_mask
dummy_component, # inpaint_color_sketch
dummy_component, # inpaint_color_sketch_orig
dummy_component, # init_img_inpaint
dummy_component, # init_mask_inpaint
components['steps'], # steps
components['sampler_index'], # sampler_index
dummy_component, # mask_blur
dummy_component, # mask_alpha
dummy_component, # inpainting_fill
dummy_component, # restore_faces
dummy_component, # tiling
dummy_component, # n_iter
dummy_component, # batch_size
components['cfg_scale'], # cfg_scale
dummy_component, # image_cfg_scale
components['processing_strength'], # denoising_strength
components['fix_frame_strength'], # fix_frame_strength
components['seed'], # seed
dummy_component, # subseed
dummy_component, # subseed_strength
dummy_component, # seed_resize_from_h
dummy_component, # seed_resize_from_w
dummy_component, # seed_enable_extras
components['height'], # height
components['width'], # width
dummy_component, # resize_mode
dummy_component, # inpaint_full_res
dummy_component, # inpaint_full_res_padding
dummy_component, # inpainting_mask_invert
dummy_component, # img2img_batch_input_dir
dummy_component, # img2img_batch_output_dir
dummy_component, # img2img_batch_inpaint_mask_dir
components['override_settings'], # override_settings_texts
] + components['custom_inputs']
# Define parameters for the action methods.
method_inputs = [components[name] for name in utils.get_component_names()] + components['v2v_custom_inputs']
method_outputs = [
sp_progress,
@ -211,15 +187,23 @@ def on_ui_tabs():
img_preview_prev_warp,
img_preview_processed,
html_log,
run_button,
stop_button
]
run_button.click(
fn=start_process, #wrap_gradio_gpu_call(start_process, extra_outputs=[None, '', '']),
fn=process, #wrap_gradio_gpu_call(start_process, extra_outputs=[None, '', '']),
inputs=method_inputs,
outputs=method_outputs,
show_progress=True,
)
stop_button.click(
fn=stop_process,
outputs=[stop_button],
show_progress=False
)
modules.scripts.scripts_current = None
# define queue - required for generators

View File

@ -42,6 +42,15 @@ def RAFT_clear_memory():
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
global RAFT_model
model_path = ph.models_path + '/RAFT/raft-things.pth'
remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'
if not os.path.isfile(model_path):
from basicsr.utils.download_util import load_file_from_url
os.makedirs(os.path.dirname(model_path), exist_ok=True)
load_file_from_url(remote_model_path, file_name=model_path)
if RAFT_model is None:
args = argparse.Namespace(**{
'model': ph.models_path + '/RAFT/raft-things.pth',

150
scripts/core/txt2vid.py Normal file
View File

@ -0,0 +1,150 @@
import sys, os
# Make sure the web-ui root and this extension's `scripts` folder (in either
# of the two directory-name spellings) are importable, without adding
# duplicate entries to sys.path.
basedirs = [os.getcwd()]

for basedir in basedirs:
    paths_to_ensure = [
        basedir,
        basedir + '/extensions/sd-cn-animation/scripts',
        basedir + '/extensions/SD-CN-Animation/scripts'
    ]

    for scripts_path_fix in paths_to_ensure:
        if scripts_path_fix not in sys.path:
            sys.path.append(scripts_path_fix)
import torch
import gc
import numpy as np
from PIL import Image
import modules.paths as ph
from modules.shared import devices
from core import utils, flow_utils
from FloweR.model import FloweR
import skimage
import datetime
import cv2
import gradio as gr
# Module-wide FloweR network instance. Starts as None, is assigned by
# FloweR_load_model() and reset to None by FloweR_clear_memory().
FloweR_model = None
# Device that the model and the input frames are moved to. This initial value
# is overwritten by FloweR_load_model() with devices.get_optimal_device().
DEVICE = 'cuda'
def FloweR_clear_memory():
    """Drop the module-wide FloweR model and ask Python and CUDA to free memory.

    After the call the global ``FloweR_model`` is ``None`` again.
    """
    global FloweR_model

    # Remove the only module-level reference, then leave the name defined.
    del FloweR_model
    FloweR_model = None

    # Collect the released objects and hand cached CUDA blocks back.
    gc.collect()
    torch.cuda.empty_cache()
def FloweR_load_model(w, h):
    """Create the FloweR network for ``w`` x ``h`` frames and load its weights.

    The checkpoint is looked up under ``<models_path>/FloweR/`` and downloaded
    first if the file is missing.  The resulting model is stored in the global
    ``FloweR_model`` and the global ``DEVICE`` is updated to the device the
    model was moved to.

    Args:
        w: Frame width in pixels.
        h: Frame height in pixels.
    """
    global DEVICE, FloweR_model

    model_path = ph.models_path + '/FloweR/FloweR_0.1.1.pth'
    remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv'

    if not os.path.isfile(model_path):
        from basicsr.utils.download_util import load_file_from_url
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        load_file_from_url(remote_model_path, file_name=model_path)

    # The model takes (height, width), while this function receives (w, h).
    FloweR_model = FloweR(input_size = (h, w))

    # Load onto the CPU first: without map_location torch.load restores the
    # tensors to the device they were saved from, which fails on machines
    # that do not have that device. The model is moved right below.
    FloweR_model.load_state_dict(torch.load(model_path, map_location='cpu'))

    # Move the model to the device
    DEVICE = devices.get_optimal_device()
    FloweR_model = FloweR_model.to(DEVICE)
def start_process(*args):
    """Generate a video from text, yielding UI updates after every frame.

    The first frame comes from ``utils.txt2img``.  Every following frame is
    produced by predicting a flow / occlusion map from the last four frames
    with FloweR, warping the previous frame with that flow, running
    ``utils.img2img`` with the occlusion map as a mask, and running a second
    ``utils.img2img`` pass with ``fix_frame_strength`` as denoising strength.
    Frames are appended to an mp4 file under
    ``outputs/sd-cn-animation/txt2vid/``.

    Args:
        *args: Raw UI values, in the order expected by ``utils.args_to_dict``.

    Yields:
        8-tuples: status text, current frame, occlusion mask, warped frame,
        processed frame, an empty string, an update for the 'Generate' button
        and an update for the 'Interrupt' button.
    """
    # Local import so that this function does not depend on the module's
    # import block being changed.
    import traceback

    args_dict = utils.args_to_dict(*args)
    args_dict = utils.get_mode_args('t2v', args_dict)

    #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
    processed_frames, _, _, _ = utils.txt2img(args_dict)
    processed_frame = np.array(processed_frames[0])
    processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
    init_frame = processed_frame.copy()

    # Create the output video file with the requested fps, width and height
    output_video_name = f'outputs/sd-cn-animation/txt2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
    output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))

    # The finally clause guarantees that the video file is finalized and the
    # model memory is released on every exit path, including the generator
    # being closed by the caller while it is suspended at a yield.
    try:
        output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))

        stat = f"Frame: 1 / {args_dict['length']}"
        utils.shared.is_interrupted = False
        yield stat, init_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)

        FloweR_load_model(args_dict['width'], args_dict['height'])

        # Sliding window with the four most recent frames (FloweR's default window).
        clip_frames = np.zeros((4, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)

        prev_frame = init_frame
        try:
            # The first frame is already written, so only length - 1 frames
            # remain; this also keeps the "Frame: i / length" counter in range.
            for ind in range(args_dict['length'] - 1):
                if utils.shared.is_interrupted: break

                # Re-read the arguments: the dict is modified below
                # (mode, init_img, mask_img, denoising_strength).
                args_dict = utils.args_to_dict(*args)
                args_dict = utils.get_mode_args('t2v', args_dict)

                clip_frames = np.roll(clip_frames, -1, axis=0)
                clip_frames[-1] = prev_frame
                clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))

                with torch.no_grad():
                    pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]

                # Channels 0-1 are treated as flow, channel 2 as occlusion.
                pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
                pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)

                # Compress large flow magnitudes, then smooth both maps.
                pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
                # borderType is passed by keyword: the 4th positional
                # parameter of cv2.GaussianBlur is `dst`, not the border mode.
                pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, borderType=cv2.BORDER_REFLECT_101)

                pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, borderType=cv2.BORDER_REFLECT_101)
                pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
                pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)

                # Turn relative flow into absolute sampling coordinates for remap.
                flow_map = pred_flow.copy()
                flow_map[:,:,0] += np.arange(args_dict['width'])
                flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]

                warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)

                curr_frame = warped_frame.copy()

                # First pass: img2img on the warped frame with the occlusion mask.
                args_dict['init_img'] = Image.fromarray(curr_frame)
                args_dict['mask_img'] = Image.fromarray(pred_occl)

                #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(args_dict)
                processed_frame = np.array(processed_frames[0])
                # NOTE(review): `multichannel` and `channel_axis` are passed
                # together; how this is handled depends on the installed
                # scikit-image version - confirm against the pinned version.
                processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
                processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

                # Second pass: plain img2img (mode 0, no mask) with fix_frame_strength.
                args_dict['mode'] = 0
                args_dict['init_img'] = Image.fromarray(processed_frame)
                args_dict['mask_img'] = None
                args_dict['denoising_strength'] = args_dict['fix_frame_strength']

                #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(args_dict)
                processed_frame = np.array(processed_frames[0])
                processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
                processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

                output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
                prev_frame = processed_frame.copy()

                stat = f"Frame: {ind + 2} / {args_dict['length']}"
                yield stat, curr_frame, pred_occl, warped_frame, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
        except Exception:
            # Best effort, as before: stop generating but still finish the
            # video and restore the buttons. The failure is now reported
            # instead of being silently discarded, and GeneratorExit /
            # KeyboardInterrupt are no longer swallowed.
            traceback.print_exc()
    finally:
        output_video.release()
        FloweR_clear_memory()

    # Leave the preview images untouched and restore the button states.
    curr_frame = gr.Image.update()
    occlusion_mask = gr.Image.update()
    warped_styled_frame_ = gr.Image.update()
    processed_frame = gr.Image.update()

    yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)

370
scripts/core/utils.py Normal file
View File

@ -0,0 +1,370 @@
class shared:
    """Holder for state shared between the UI callbacks and the generators.

    NOTE(review): further down, this file runs
    ``from modules import shared, sd_hijack, lowvram``, which rebinds the
    module-level name ``shared`` - confirm which object ``utils.shared``
    is meant to refer to.
    """

    # Set to True to ask the running generation loop to stop.
    is_interrupted = False
def get_component_names():
    """Return the ordered names of the UI components used as process inputs.

    The order defines the positional layout of the values passed to
    ``args_to_dict``. A new list is returned on every call.
    """
    return [
        # selected tab
        'sdcn_process_mode',
        # video to video
        'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
        'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
        # text to video
        't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
        't2v_length', 't2v_fps',
    ]
def args_to_dict(*args): # converts list of arguments into dictionary for better handling of it
    """Map positional UI values onto named parameters, filling in defaults.

    The first ``len(get_component_names())`` values are matched by position
    to the component names. A value of ``None`` keeps the default for that
    name when one exists; any other value (or a ``None`` for a name that has
    no default) is stored as is. All remaining positional values are stored
    as a list under ``'v2v_script_inputs'``.

    Returns:
        dict: parameter name (prefixed with 'v2v_' / 't2v_') -> value.
    """
    component_names = get_component_names()

    # default values for params that were not specified
    args_dict = {
        # video to video params
        'v2v_mode': 0,
        'v2v_prompt': '',
        'v2v_n_prompt': '',
        'v2v_prompt_styles': [],
        'v2v_init_video': None, # Always required

        'v2v_steps': 15,
        'v2v_sampler_index': 0, # 'Euler a'
        'v2v_mask_blur': 0,

        'v2v_inpainting_fill': 1, # original
        'v2v_restore_faces': False,
        'v2v_tiling': False,
        'v2v_n_iter': 1,
        'v2v_batch_size': 1,
        'v2v_cfg_scale': 5.5,
        'v2v_image_cfg_scale': 1.5,
        'v2v_denoising_strength': 0.75,
        'v2v_fix_frame_strength': 0.15,
        'v2v_seed': -1,
        'v2v_subseed': -1,
        'v2v_subseed_strength': 0,
        'v2v_seed_resize_from_h': 512,
        'v2v_seed_resize_from_w': 512,
        'v2v_seed_enable_extras': False,
        'v2v_height': 512,
        'v2v_width': 512,
        'v2v_resize_mode': 1,
        'v2v_inpaint_full_res': True,
        'v2v_inpaint_full_res_padding': 0,
        'v2v_inpainting_mask_invert': False,

        # text to video params
        't2v_mode': 4,
        't2v_prompt': '',
        't2v_n_prompt': '',
        't2v_prompt_styles': [],
        't2v_init_img': None,
        't2v_mask_img': None,

        't2v_steps': 15,
        't2v_sampler_index': 0, # 'Euler a'
        't2v_mask_blur': 0,

        't2v_inpainting_fill': 1, # original
        't2v_restore_faces': False,
        't2v_tiling': False,
        't2v_n_iter': 1,
        't2v_batch_size': 1,
        't2v_cfg_scale': 5.5,
        't2v_image_cfg_scale': 1.5,
        't2v_denoising_strength': 0.75,
        't2v_fix_frame_strength': 0.15,
        't2v_seed': -1,
        't2v_subseed': -1,
        't2v_subseed_strength': 0,
        't2v_seed_resize_from_h': 512,
        't2v_seed_resize_from_w': 512,
        't2v_seed_enable_extras': False,
        't2v_height': 512,
        't2v_width': 512,
        't2v_resize_mode': 1,
        't2v_inpaint_full_res': True,
        't2v_inpaint_full_res_padding': 0,
        't2v_inpainting_mask_invert': False,

        't2v_override_settings': [],
        't2v_script_inputs': [0],

        't2v_fps': 12,
    }

    values = list(args)

    for position, name in enumerate(component_names):
        value = values[position]
        # A missing (None) value keeps the default, if there is one.
        if value is None and name in args_dict:
            continue
        args_dict[name] = value

    # Everything after the named components belongs to the custom scripts.
    args_dict['v2v_script_inputs'] = values[len(component_names):]
    return args_dict
def get_mode_args(mode, args_dict):
    """Select the entries of ``args_dict`` that belong to one mode.

    A key belongs to ``mode`` when its first three characters equal ``mode``
    (e.g. 't2v'). The returned dict uses the key with its first four
    characters - the mode and the separator - removed.
    """
    return {
        key[4:]: value
        for key, value in args_dict.items()
        if key[:3] == mode
    }
def set_CNs_input_image(args_dict, image):
    """Give ``image`` to every ControlNet unit found among the script inputs.

    Units are recognised by their class name ('UiControlNetUnit'); each one
    gets ``batch_images`` set to a one-element list holding ``image``.
    All other script inputs are left untouched.
    """
    for unit in args_dict['script_inputs']:
        if type(unit).__name__ != 'UiControlNetUnit':
            continue
        unit.batch_images = [image]
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from types import SimpleNamespace
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, process_images
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
from modules.shared import opts, devices, state
from modules import devices, sd_samplers, img2img
from modules import shared, sd_hijack, lowvram
# TODO: Refactor all the code below
def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
    """Run the img2img processing object ``p`` on a single input image.

    Args:
        p: Processing object; its seed, saving flags and ``init_images`` are
            modified in place.
        input_img: Image to process.
        output_dir: Samples are saved by the pipeline only when this is ''.
        inpaint_mask_dir: Accepted but not used.
        args: Script arguments forwarded to the img2img script runner.

    Returns:
        list: The first output image of each processed input (empty if the
        job was interrupted before processing started).
    """
    processing.fix_seed(p)

    sources = [input_img]
    total = len(sources)

    p.do_not_save_grid = True
    # Only save samples when the output directory is the empty string.
    p.do_not_save_samples = output_dir != ''

    state.job_count = total * p.n_iter

    results = []
    for index, source in enumerate(sources):
        state.job = f"{index+1} out of {total}"

        # A pending "skip" request is consumed without skipping anything.
        if state.skipped:
            state.skipped = False

        if state.interrupted:
            break

        # Use the EXIF orientation of photos taken by smartphones.
        oriented = ImageOps.exif_transpose(source)
        p.init_images = [oriented] * p.batch_size

        # Let a selected script handle the job; otherwise run the default pipeline.
        outcome = modules.scripts.scripts_img2img.run(p, *args)
        if outcome is None:
            outcome = process_images(p)

        results.append(outcome.images[0])

    return results
def _select_img2img_source(args):
    """Return the ``(image, mask)`` pair that corresponds to ``args.mode``."""
    mode = args.mode

    if mode == 0: # img2img
        return args.init_img.convert("RGB"), None

    if mode == 1: # img2img sketch
        return args.sketch.convert("RGB"), None

    if mode == 2: # inpaint
        image, mask = args.init_img_with_mask["image"], args.init_img_with_mask["mask"]
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        return image.convert("RGB"), mask

    if mode == 3: # inpaint sketch
        image = args.inpaint_color_sketch
        orig = args.inpaint_color_sketch_orig or args.inpaint_color_sketch
        # Every pixel that differs from the original sketch becomes part of the mask.
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - args.mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(args.mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        return image.convert("RGB"), mask

    if mode == 4: # inpaint upload mask
        return args.init_img.convert("RGB"), args.mask_img.convert("L")

    return None, None


def img2img(args_dict):
    """Run one img2img generation described by ``args_dict``.

    Args:
        args_dict: Parameter dict (mode-prefix already removed); its keys are
            read as attributes, e.g. ``mode``, ``init_img``, ``mask_img``,
            ``prompt``, ``denoising_strength`` and ``script_inputs``.

    Returns:
        tuple: (generated images, generation info as JSON text,
        info as HTML, comments as HTML).

    Raises:
        AssertionError: If ``denoising_strength`` is outside [0.0, 1.0].
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    image, mask = _select_img2img_source(args)

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= args.denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=args.prompt,
        negative_prompt=args.n_prompt,
        styles=args.prompt_styles,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        init_images=[image],
        mask=mask,
        mask_blur=args.mask_blur,
        inpainting_fill=args.inpainting_fill,
        resize_mode=args.resize_mode,
        denoising_strength=args.denoising_strength,
        image_cfg_scale=args.image_cfg_scale,
        inpaint_full_res=args.inpaint_full_res,
        inpaint_full_res_padding=args.inpaint_full_res_padding,
        inpainting_mask_invert=args.inpainting_mask_invert,
        override_settings=override_settings,
    )

    # NOTE(review): the txt2img script runner is attached here, while
    # process_img() runs the img2img runner - confirm this is intended.
    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args.script_inputs

    if mask:
        p.extra_generation_params["Mask blur"] = args.mask_blur

    generated_images = process_img(p, image, None, '', args.script_inputs)
    # Empty result object, used only to produce the info / comments output.
    processed = Processed(p, [], p.seed, "")
    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    info_html = plaintext_to_html(processed.info)
    comments_html = plaintext_to_html(processed.comments)

    return generated_images, generation_info_js, info_html, comments_html
def txt2img(args_dict):
    """Run one txt2img generation described by a flat argument dictionary.

    Args:
        args_dict: Mapping whose entries are read as attributes below
            (``prompt``, ``n_prompt``, ``prompt_styles``, the seed, sampler
            and size settings, ``override_settings`` and ``script_inputs``).

    Returns:
        Tuple ``(images, generation_info_js, info_html, comments_html)``.
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    # Hires-fix options (enable_hr, hr_scale, hr_upscaler, ...) are not
    # forwarded to the processing object.
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=args.prompt,
        styles=args.prompt_styles,
        negative_prompt=args.n_prompt,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        override_settings=override_settings,
    )

    try:
        p.scripts = modules.scripts.scripts_txt2img
        p.script_args = args.script_inputs

        # A script may produce the result itself; otherwise fall back to the
        # regular processing pipeline.
        processed = modules.scripts.scripts_txt2img.run(p, *args.script_inputs)
        if processed is None:
            processed = process_images(p)
    finally:
        # Close the processing object even when generation raises.
        p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)

252
scripts/core/vid2vid.py Normal file
View File

@ -0,0 +1,252 @@
import os
import sys

# Make the working directory and both spellings of the extension's scripts
# folder importable, so that ``core.*`` imports below resolve.
basedirs = [os.getcwd()]

for basedir in basedirs:
    paths_to_ensure = [
        basedir,
        basedir + '/extensions/sd-cn-animation/scripts',
        basedir + '/extensions/SD-CN-Animation/scripts'
    ]

    for scripts_path_fix in paths_to_ensure:
        # Only add entries that are missing, so sys.path gets no duplicates.
        if scripts_path_fix not in sys.path:
            sys.path.append(scripts_path_fix)
import math
import os
import sys
import traceback
import numpy as np
from PIL import Image
from modules import devices, sd_samplers
from modules import shared, sd_hijack, lowvram
from modules.shared import devices
import modules.shared as shared
import gc
import cv2
import gradio as gr
import time
import skimage
import datetime
from core.flow_utils import RAFT_estimate_flow, RAFT_clear_memory, compute_diff_map
from core import utils
class sdcn_anim_tmp:
    """Class-level scratch state shared by the vid2vid functions in this module."""

    # Counters read by get_cur_stat() and advanced by start_process().
    prepear_counter = process_counter = 0

    # Video capture / writer handles assigned by start_process().
    input_video = output_video = None

    # Properties read from the source video by start_process().
    fps = total_frames = None

    # Frames carried over between steps of start_process().
    curr_frame = prev_frame = None
    prev_frame_styled = prev_frame_alpha_mask = None

    # Buffers of prepared frames and flows, allocated by start_process().
    prepared_frames = prepared_next_flows = prepared_prev_flows = None

    # True once a batch has been prepared and is waiting to be processed.
    frames_prepared = False
def read_frame_from_video():
    """Read the next frame from ``sdcn_anim_tmp.input_video`` as RGB.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when the capture is
        not opened (it is released in that case) or the read gave no frame.
    """
    capture = sdcn_anim_tmp.input_video

    if not capture.isOpened():
        capture.release()
        return None

    _, frame = capture.read()
    if frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def get_cur_stat():
    """Return a progress summary string built from the shared counters."""
    total = sdcn_anim_tmp.total_frames
    prepared = sdcn_anim_tmp.prepear_counter + 1
    processed = sdcn_anim_tmp.process_counter + 1
    return (
        f'Frames prepared: {prepared} / {total}; '
        f'Frames processed: {processed} / {total}; '
    )
def clear_memory_from_sd():
    """Unload ``shared.sd_model`` if one is set, then run the garbage collectors.

    ``gc.collect()`` and ``devices.torch_gc()`` are called whether or not a
    model was loaded.
    """
    if shared.sd_model is not None:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)

        try:
            lowvram.send_everything_to_cpu()
        except Exception:
            # Best effort: a failure here must not stop the unload.
            pass

        del shared.sd_model
        shared.sd_model = None

    gc.collect()
    devices.torch_gc()
def start_process(*args):
    """Stylise the input video frame by frame, yielding UI updates as it goes.

    This is a generator. The first frame is stylised directly; afterwards
    each loop step either prepares one frame (optical flow against the
    previous frame) or processes one prepared frame, in batches of
    ``prepare_steps``.

    Args:
        *args: Raw UI values, converted with ``utils.args_to_dict`` and
            ``utils.get_mode_args('v2v', ...)``.

    Yields:
        8-tuples ``(status, current frame, occlusion mask, warped styled
        frame, processed frame, '', button update, button update)``.
    """
    args_dict = utils.args_to_dict(*args)
    args_dict = utils.get_mode_args('v2v', args_dict)

    # Number of frames prepared before switching to processing.
    prepare_steps = 10

    # Reset per-run state. frames_prepared must be cleared too: a run that
    # was interrupted right after a preparation batch would otherwise make
    # this run start by processing the freshly zeroed buffers.
    sdcn_anim_tmp.process_counter = 0
    sdcn_anim_tmp.prepear_counter = 0
    sdcn_anim_tmp.frames_prepared = False

    # Open the input video file
    sdcn_anim_tmp.input_video = cv2.VideoCapture(args_dict['file'].name)

    # Get useful info from the source video
    sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
    sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))

    # Create an output video file with the source fps and the requested width and height
    output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
    sdcn_anim_tmp.output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), sdcn_anim_tmp.fps, (args_dict['width'], args_dict['height']))

    curr_frame = read_frame_from_video()
    curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))

    # One extra frame slot: slot 0 holds the frame preceding the batch.
    sdcn_anim_tmp.prepared_frames = np.zeros((prepare_steps + 1, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
    sdcn_anim_tmp.prepared_next_flows = np.zeros((prepare_steps, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_prev_flows = np.zeros((prepare_steps, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_frames[0] = curr_frame

    # Stylise the very first frame on its own.
    args_dict['init_img'] = Image.fromarray(curr_frame)
    utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
    processed_frames, _, _, _ = utils.img2img(args_dict)
    processed_frame = np.array(processed_frames[0])
    # channel_axis=None is what the former ``multichannel=False`` keyword
    # resolved to; the deprecated keyword is gone in newer scikit-image.
    processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
    processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

    sdcn_anim_tmp.curr_frame = curr_frame
    sdcn_anim_tmp.prev_frame = curr_frame.copy()
    sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
    utils.shared.is_interrupted = False

    yield get_cur_stat(), sdcn_anim_tmp.curr_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    try:
        # Every frame after the first needs one prepare step and one process step.
        for step in range((sdcn_anim_tmp.total_frames - 1) * 2):
            if utils.shared.is_interrupted:
                break

            args_dict = utils.args_to_dict(*args)
            args_dict = utils.get_mode_args('v2v', args_dict)

            occlusion_mask = None
            prev_frame = None
            curr_frame = sdcn_anim_tmp.curr_frame
            warped_styled_frame_ = gr.Image.update()
            processed_frame = gr.Image.update()

            if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared:
                # prepare next batch of frames for processing
                device = devices.get_optimal_device()

                curr_frame = read_frame_from_video()
                if curr_frame is not None:
                    curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
                    prev_frame = sdcn_anim_tmp.prev_frame.copy()

                    next_flow, prev_flow, occlusion_mask, _, _ = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
                    occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)

                    cn = sdcn_anim_tmp.prepear_counter % prepare_steps
                    if cn == 0:
                        # Start of a batch: seed slot 0 with the previous frame.
                        sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
                    sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
                    sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
                    sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()

                    sdcn_anim_tmp.prev_frame = curr_frame.copy()

                sdcn_anim_tmp.prepear_counter += 1
                batch_full = sdcn_anim_tmp.prepear_counter % prepare_steps == 0
                video_done = sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1
                if batch_full or video_done or curr_frame is None:
                    # Remove RAFT from memory
                    RAFT_clear_memory()
                    sdcn_anim_tmp.frames_prepared = True
            else:
                # process frame
                sdcn_anim_tmp.frames_prepared = False

                cn = sdcn_anim_tmp.process_counter % prepare_steps
                curr_frame = sdcn_anim_tmp.prepared_frames[cn + 1]
                prev_frame = sdcn_anim_tmp.prepared_frames[cn]
                next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
                prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]

                # First pass: stylise the current frame with a random seed.
                args_dict['init_img'] = Image.fromarray(curr_frame)
                args_dict['seed'] = -1
                utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(args_dict)
                processed_frame = np.array(processed_frames[0])

                alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
                warped_styled_frame_ = warped_styled_frame.copy()

                # Carry half of the previous (unclipped) mask into this one.
                if sdcn_anim_tmp.process_counter > 0:
                    alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
                sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
                alpha_mask = np.clip(alpha_mask, 0, 1)

                # Flow buffers are allocated as (height, width, 2); the two
                # channels are divided by width and height respectively
                # (presumably x then y displacement - TODO confirm).
                flow_h, flow_w = prev_flow.shape[:2]
                prev_flow_n = prev_flow / np.array([flow_w, flow_h])
                flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[..., None], 0, 1)

                # Fix duplicated content in the warped styled frame that occurs
                # where the flow is zero only because there is nowhere to take
                # the colour from.
                warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)

                occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)

                # normalizing the colors
                processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
                processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

                processed_frame = processed_frame * 0.9 + curr_frame * 0.1
                processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
                sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

                # Second pass: low-strength fix-up of the blended frame with a fixed seed.
                args_dict['init_img'] = Image.fromarray(processed_frame)
                args_dict['denoising_strength'] = args_dict['fix_frame_strength']
                args_dict['seed'] = 8888
                utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(args_dict)
                processed_frame = np.array(processed_frames[0])

                processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
                warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)

                # Write the frame to the output video
                frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
                frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
                sdcn_anim_tmp.output_video.write(frame_out)

                sdcn_anim_tmp.process_counter += 1
                if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
                    sdcn_anim_tmp.input_video.release()
                    sdcn_anim_tmp.output_video.release()
                    sdcn_anim_tmp.prev_frame = None

            yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
    except Exception:
        # Keep the UI alive on failure, but report what went wrong. Catching
        # only Exception lets GeneratorExit propagate when the generator is
        # closed while suspended at the yield above.
        traceback.print_exc()
    finally:
        RAFT_clear_memory()

        sdcn_anim_tmp.input_video.release()
        sdcn_anim_tmp.output_video.release()

    # Final update: leave the images untouched and flip the two buttons back.
    curr_frame = gr.Image.update()
    occlusion_mask = gr.Image.update()
    warped_styled_frame_ = gr.Image.update()
    processed_frame = gr.Image.update()

    yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)

View File

@ -1,476 +0,0 @@
import os
import sys

# Make the working directory and both spellings of the extension's scripts
# folder importable, so that the ``flow_utils`` import below resolves.
basedirs = [os.getcwd()]

for basedir in basedirs:
    paths_to_ensure = [
        basedir,
        basedir + '/extensions/sd-cn-animation/scripts',
        basedir + '/extensions/SD-CN-Animation/scripts'
    ]

    for scripts_path_fix in paths_to_ensure:
        # Only add entries that are missing, so sys.path gets no duplicates.
        if scripts_path_fix not in sys.path:
            sys.path.append(scripts_path_fix)
import math
import os
import sys
import traceback
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices, sd_samplers, img2img
from modules import shared, sd_hijack, lowvram
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, devices, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
import gc
import cv2
import gradio as gr
import time
import skimage
import datetime
from flow_utils import RAFT_estimate_flow, RAFT_clear_memory, compute_diff_map
from types import SimpleNamespace
class sdcn_anim_tmp:
    """Class-level scratch state shared by the vid2vid functions in this module."""

    # Counters read by get_cur_stat() and advanced by start_process().
    prepear_counter = process_counter = 0

    # Video capture / writer handles assigned by start_process().
    input_video = output_video = None

    # Properties read from the source video by start_process().
    fps = total_frames = None

    # Frames carried over between steps of start_process().
    curr_frame = prev_frame = None
    prev_frame_styled = prev_frame_alpha_mask = None

    # Buffers of prepared frames and flows, allocated by start_process().
    prepared_frames = prepared_next_flows = prepared_prev_flows = None

    # True once a batch has been prepared and is waiting to be processed.
    frames_prepared = False
def read_frame_from_video():
    """Read the next frame from ``sdcn_anim_tmp.input_video`` as RGB.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when the capture is
        not opened (it is released in that case) or the read gave no frame.
    """
    capture = sdcn_anim_tmp.input_video

    if not capture.isOpened():
        capture.release()
        return None

    _, frame = capture.read()
    if frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def get_cur_stat():
    """Return a progress summary string built from the shared counters."""
    total = sdcn_anim_tmp.total_frames
    prepared = sdcn_anim_tmp.prepear_counter + 1
    processed = sdcn_anim_tmp.process_counter + 1
    return (
        f'Frames prepared: {prepared} / {total}; '
        f'Frames processed: {processed} / {total}; '
    )
def clear_memory_from_sd():
    """Unload ``shared.sd_model`` if one is set, then run the garbage collectors.

    ``gc.collect()`` and ``devices.torch_gc()`` are called whether or not a
    model was loaded.
    """
    if shared.sd_model is not None:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)

        try:
            lowvram.send_everything_to_cpu()
        except Exception:
            # Best effort: a failure here must not stop the unload.
            pass

        del shared.sd_model
        shared.sd_model = None

    gc.collect()
    devices.torch_gc()
def get_device():
    """Return whatever ``devices.get_optimal_device()`` selects."""
    return devices.get_optimal_device()
def args_to_dict(*args):  # converts list of arguments into dictionary for better handling of it
    """Map positional UI values onto named settings, filling in defaults.

    The first ``len(args_list)`` positional values are matched to the names
    in ``args_list`` in order. A value of ``None`` is replaced by the default
    for that name when one exists; otherwise the given value (possibly
    ``None``) is stored. Values past the named ones are collected under
    ``'script_inputs'``. If fewer values than names are passed, the missing
    ones are treated as ``None``.

    Returns:
        Tuple ``(args_dict, args)`` where ``args`` is the positional list
        with defaults substituted in.
    """
    args_list = [
        'id_task', 'mode', 'prompt', 'negative_prompt', 'prompt_styles',
        'init_video', 'sketch', 'init_img_with_mask', 'inpaint_color_sketch',
        'inpaint_color_sketch_orig', 'init_img_inpaint', 'init_mask_inpaint',
        'steps', 'sampler_index', 'mask_blur', 'mask_alpha', 'inpainting_fill',
        'restore_faces', 'tiling', 'n_iter', 'batch_size', 'cfg_scale',
        'image_cfg_scale', 'denoising_strength', 'fix_frame_strength', 'seed',
        'subseed', 'subseed_strength', 'seed_resize_from_h',
        'seed_resize_from_w', 'seed_enable_extras', 'height', 'width',
        'resize_mode', 'inpaint_full_res', 'inpaint_full_res_padding',
        'inpainting_mask_invert', 'img2img_batch_input_dir',
        'img2img_batch_output_dir', 'img2img_batch_inpaint_mask_dir',
        'override_settings_texts',
    ]

    # set default values for params that were not specified
    args_dict = {
        'mode': 0,
        'prompt': '',
        'negative_prompt': '',
        'prompt_styles': [],
        'init_video': None,  # Always required
        'steps': 15,
        'sampler_index': 0,  # 'Euler a'
        'mask_blur': 0,
        'inpainting_fill': 1,  # original
        'restore_faces': False,
        'tiling': False,
        'n_iter': 1,
        'batch_size': 1,
        'cfg_scale': 5.5,
        'image_cfg_scale': 1.5,
        'denoising_strength': 0.75,
        'fix_frame_strength': 0.15,
        'seed': -1,
        'subseed': -1,
        'subseed_strength': 0,
        'seed_resize_from_h': 512,
        'seed_resize_from_w': 512,
        'seed_enable_extras': False,
        'height': 512,
        'width': 512,
        'resize_mode': 1,
        'inpaint_full_res': True,
        'inpaint_full_res_padding': 0,
    }

    args = list(args)
    # Pad a short call with None so it falls back to defaults instead of
    # raising IndexError.
    missing = len(args_list) - len(args)
    if missing > 0:
        args.extend([None] * missing)

    for i, name in enumerate(args_list):
        if args[i] is None and name in args_dict:
            args[i] = args_dict[name]
        else:
            args_dict[name] = args[i]

    args_dict['script_inputs'] = args[len(args_list):]
    return args_dict, args
# TODO: Refactor all the code below
def start_process(*args):
    """Stylise the input video frame by frame, yielding UI updates as it goes.

    This is a generator. The first frame is stylised directly; afterwards
    each loop step either prepares one frame (optical flow against the
    previous frame) or processes one prepared frame, in batches of
    ``prepare_steps``.

    Args:
        *args: Raw UI values, converted with ``args_to_dict``.

    Yields:
        6-tuples ``(status, current frame, occlusion mask, warped styled
        frame, processed frame, '')``.
    """
    args_dict, _ = args_to_dict(*args)

    # Number of frames prepared before switching to processing.
    prepare_steps = 10

    # Reset per-run state. frames_prepared must be cleared too: a previous
    # run that stopped right after a preparation batch would otherwise make
    # this run start by processing the freshly zeroed buffers.
    sdcn_anim_tmp.process_counter = 0
    sdcn_anim_tmp.prepear_counter = 0
    sdcn_anim_tmp.frames_prepared = False

    # Open the input video file
    sdcn_anim_tmp.input_video = cv2.VideoCapture(args_dict['init_video'].name)

    # Get useful info from the source video
    sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
    sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))

    # Create an output video file with the source fps and the requested width and height
    output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
    sdcn_anim_tmp.output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), sdcn_anim_tmp.fps, (args_dict['width'], args_dict['height']))

    curr_frame = read_frame_from_video()
    curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))

    # One extra frame slot: slot 0 holds the frame preceding the batch.
    sdcn_anim_tmp.prepared_frames = np.zeros((prepare_steps + 1, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
    sdcn_anim_tmp.prepared_next_flows = np.zeros((prepare_steps, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_prev_flows = np.zeros((prepare_steps, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_frames[0] = curr_frame

    # Stylise the very first frame on its own.
    args_dict['init_img'] = Image.fromarray(curr_frame)
    processed_frames, _, _, _ = img2img(args_dict)
    processed_frame = np.array(processed_frames[0])
    # channel_axis=None is what the former ``multichannel=False`` keyword
    # resolved to; the deprecated keyword is gone in newer scikit-image.
    processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
    processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

    sdcn_anim_tmp.curr_frame = curr_frame
    sdcn_anim_tmp.prev_frame = curr_frame.copy()
    sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

    yield get_cur_stat(), sdcn_anim_tmp.curr_frame, None, None, processed_frame, ''

    # Bound up front so the final return is valid even when the loop body
    # never runs (a video with at most one frame).
    occlusion_mask = None
    warped_styled_frame = None

    # Every frame after the first needs one prepare step and one process step.
    for step in range((sdcn_anim_tmp.total_frames - 1) * 2):
        args_dict, _ = args_to_dict(*args)

        occlusion_mask = None
        prev_frame = None
        curr_frame = sdcn_anim_tmp.curr_frame
        warped_styled_frame = gr.Image.update()
        processed_frame = gr.Image.update()

        if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared:
            # prepare next batch of frames for processing
            device = get_device()

            curr_frame = read_frame_from_video()
            if curr_frame is not None:
                curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
                prev_frame = sdcn_anim_tmp.prev_frame.copy()

                next_flow, prev_flow, occlusion_mask, _, _ = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
                occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)

                cn = sdcn_anim_tmp.prepear_counter % prepare_steps
                if cn == 0:
                    # Start of a batch: seed slot 0 with the previous frame.
                    sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
                sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
                sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
                sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()

                sdcn_anim_tmp.prev_frame = curr_frame.copy()

            sdcn_anim_tmp.prepear_counter += 1
            batch_full = sdcn_anim_tmp.prepear_counter % prepare_steps == 0
            video_done = sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1
            if batch_full or video_done or curr_frame is None:
                # Remove RAFT from memory
                RAFT_clear_memory()
                sdcn_anim_tmp.frames_prepared = True
        else:
            # process frame
            sdcn_anim_tmp.frames_prepared = False

            cn = sdcn_anim_tmp.process_counter % prepare_steps
            curr_frame = sdcn_anim_tmp.prepared_frames[cn + 1]
            prev_frame = sdcn_anim_tmp.prepared_frames[cn]
            next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
            prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]

            # First pass: stylise the current frame with a random seed.
            args_dict['init_img'] = Image.fromarray(curr_frame)
            args_dict['seed'] = -1
            processed_frames, _, _, _ = img2img(args_dict)
            processed_frame = np.array(processed_frames[0])

            alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)

            # Carry half of the previous (unclipped) mask into this one.
            if sdcn_anim_tmp.process_counter > 0:
                alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
            sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
            alpha_mask = np.clip(alpha_mask, 0, 1)

            # Flow buffers are allocated as (height, width, 2); the two
            # channels are divided by width and height respectively
            # (presumably x then y displacement - TODO confirm).
            flow_h, flow_w = prev_flow.shape[:2]
            prev_flow_n = prev_flow / np.array([flow_w, flow_h])
            flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[..., None], 0, 1)

            # Fix duplicated content in the warped styled frame that occurs
            # where the flow is zero only because there is nowhere to take
            # the colour from.
            warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)

            occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)

            # normalizing the colors
            processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
            processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

            processed_frame = processed_frame * 0.9 + curr_frame * 0.1
            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

            # Second pass: low-strength fix-up of the blended frame with a fixed seed.
            args_dict['init_img'] = Image.fromarray(processed_frame)
            args_dict['denoising_strength'] = args_dict['fix_frame_strength']
            args_dict['seed'] = 8888
            processed_frames, _, _, _ = img2img(args_dict)
            processed_frame = np.array(processed_frames[0])

            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            warped_styled_frame = np.clip(warped_styled_frame, 0, 255).astype(np.uint8)

            # Write the frame to the output video
            frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
            frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
            sdcn_anim_tmp.output_video.write(frame_out)

            sdcn_anim_tmp.process_counter += 1
            if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
                sdcn_anim_tmp.input_video.release()
                sdcn_anim_tmp.output_video.release()
                sdcn_anim_tmp.prev_frame = None

        yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame, processed_frame, ''

    # NOTE(review): inside a generator this value only becomes
    # StopIteration.value; kept for backward compatibility.
    return get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame, processed_frame, ''
def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
    """Run img2img processing for a single in-memory image.

    Args:
        p: Processing object; its seed is fixed and its save flags,
            ``init_images`` and job state are set here.
        input_img: Image to process (EXIF orientation is applied first).
        output_dir: Samples are saved normally only when this is ``''``.
        inpaint_mask_dir: Unused.
        args: Script arguments forwarded to ``scripts_img2img.run``.

    Returns:
        List with the first generated image, or an empty list when
        ``state.interrupted`` is set before processing starts.
    """
    processing.fix_seed(p)

    sources = [input_img]
    total = len(sources)

    p.do_not_save_grid = True
    p.do_not_save_samples = output_dir != ''

    state.job_count = total * p.n_iter

    results = []
    for position, source in enumerate(sources, start=1):
        state.job = f"{position} out of {total}"

        if state.skipped:
            state.skipped = False
        if state.interrupted:
            break

        # Use the EXIF orientation of photos taken by smartphones.
        p.init_images = [ImageOps.exif_transpose(source)] * p.batch_size

        # A script may produce the result itself; otherwise fall back to the
        # regular processing pipeline.
        outcome = modules.scripts.scripts_img2img.run(p, *args)
        if outcome is None:
            outcome = process_images(p)
        results.append(outcome.images[0])

    return results
# id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles: list, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args
def img2img(args_dict):
    """Run one img2img generation described by a flat argument dictionary.

    Args:
        args_dict: Mapping whose entries are read as attributes below. The
            source image and mask are picked according to ``mode``
            (0 img2img, 1 sketch, 2 inpaint, 3 inpaint sketch, 4 inpaint
            with uploaded mask).

    Returns:
        Tuple ``(generated_images, generation_info_js, info_html,
        comments_html)``.

    Raises:
        AssertionError: If ``denoising_strength`` is outside ``[0.0, 1.0]``.
    """
    args = SimpleNamespace(**args_dict)
    print('override_settings:', args.override_settings_texts)
    override_settings = create_override_settings_dict(args.override_settings_texts)

    if args.mode == 0:  # img2img
        image = args.init_img.convert("RGB")
        mask = None
    elif args.mode == 1:  # img2img sketch
        image = args.sketch.convert("RGB")
        mask = None
    elif args.mode == 2:  # inpaint
        image, mask = args.init_img_with_mask["image"], args.init_img_with_mask["mask"]
        # Combine the image's own transparency with the drawn mask.
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        image = image.convert("RGB")
    elif args.mode == 3:  # inpaint sketch
        image = args.inpaint_color_sketch
        orig = args.inpaint_color_sketch_orig or args.inpaint_color_sketch
        # The mask is every pixel where the sketch differs from the original.
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - args.mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(args.mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        image = image.convert("RGB")
    elif args.mode == 4:  # inpaint upload mask
        image = args.init_img_inpaint
        mask = args.init_mask_inpaint
    else:
        image = None
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= args.denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        styles=args.prompt_styles,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        init_images=[image],
        mask=mask,
        mask_blur=args.mask_blur,
        inpainting_fill=args.inpainting_fill,
        resize_mode=args.resize_mode,
        denoising_strength=args.denoising_strength,
        image_cfg_scale=args.image_cfg_scale,
        inpaint_full_res=args.inpaint_full_res,
        inpaint_full_res_padding=args.inpaint_full_res_padding,
        inpainting_mask_invert=args.inpainting_mask_invert,
        override_settings=override_settings,
    )

    try:
        # NOTE(review): the txt2img script runner is attached here although
        # process_img() runs scripts_img2img - confirm this is intended.
        p.scripts = modules.scripts.scripts_txt2img
        p.script_args = args.script_inputs

        if mask:
            p.extra_generation_params["Mask blur"] = args.mask_blur

        # Batch mode is not supported; the single image always goes through
        # process_img().
        generated_images = process_img(p, image, None, '', args.script_inputs)
        processed = Processed(p, [], p.seed, "")
    finally:
        # Close the processing object even when generation raises.
        p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()

    return generated_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)