pull/87/head v0.8
Alexey Borsky 2023-05-11 07:34:29 +03:00
parent c987f645d6
commit 14534d3174
7 changed files with 239 additions and 214 deletions

View File

@ -18,43 +18,43 @@ class FloweR(nn.Module):
nn.ReLU(),
) # 384 x 384 x 128
self.conv_block_2 = nn.Sequential(
self.conv_block_2 = nn.Sequential( # x128
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 192 x 192 x 128
self.conv_block_3 = nn.Sequential(
self.conv_block_3 = nn.Sequential( # x64
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 96 x 96 x 128
self.conv_block_4 = nn.Sequential(
self.conv_block_4 = nn.Sequential( # x32
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 48 x 48 x 128
self.conv_block_5 = nn.Sequential(
self.conv_block_5 = nn.Sequential( # x16
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 24 x 24 x 128
self.conv_block_6 = nn.Sequential(
self.conv_block_6 = nn.Sequential( # x8
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 12 x 12 x 128
self.conv_block_7 = nn.Sequential(
self.conv_block_7 = nn.Sequential( # x4
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),
) # 6 x 6 x 128
self.conv_block_8 = nn.Sequential(
self.conv_block_8 = nn.Sequential( # x2
nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(),

View File

@ -54,8 +54,12 @@ All examples you can see here are originally generated at 512x512 resolution usi
## Installing the extension
To install the extension go to 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to 'Install from URL' tab. In 'URL for extension's git repository' field enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave 'Local directory name' field empty. Then just press 'Install' button. Restart web-ui, a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
## Last version changes: v0.7
* Text to Video mode added to the extension
* 'Generate' button is now automatically disabled while the video is being generated
* Added 'Interrupt' button that allows stopping the video generation process
* Now all necessary models are automatically downloaded. No need for manual preparation.
## Last version changes: v0.8
* Better error handling. Fixes an issue where errors might not appear in the console.
* Fixed an issue with deprecated variables. This should resolve problems with running the extension on other webui forks.
* Slight improvements in vid2vid processing pipeline.
* Video preview added to the UI. It will become available at the end of the processing.
* Time elapsed/left indication added.
* Fixed an issue with color drifting on some models.
* Sampler type and sampling steps settings added to text2video mode.
* Added automatic resizing before processing with RAFT and FloweR models.

View File

@ -28,6 +28,7 @@ from modules.sd_samplers import samplers_for_img2img
from modules.ui import setup_progressbar, create_sampler_and_steps_selection, ordered_ui_categories, create_output_panel
from core import vid2vid, txt2vid, utils
import traceback
def V2VArgs():
seed = -1
@ -62,15 +63,17 @@ def setup_common_values(mode, d):
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
with gr.Row():
#steps = gr.Slider(label='Steps', minimum=1, maximum=100, step=1, value=d.steps, interactive=True)
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
with gr.Row():
seed = gr.Number(label='Seed (this parameter controls how the first frame looks like and the color distribution of the consecutive frames as they are dependent on the first one)', value = d.seed, Interactive = True, precision=0)
with gr.Row():
processing_strength = gr.Slider(label="Processing strength", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
fix_frame_strength = gr.Slider(label="Fix frame strength", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
with gr.Row():
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{mode}_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index", interactive=True)
steps = gr.Slider(label="Sampling steps", minimum=1, maximum=150, step=1, elem_id=f"{mode}_steps", value=d.steps, interactive=True)
return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength
return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength, sampler_index, steps
def inputs_ui():
v2v_args = SimpleNamespace(**V2VArgs())
@ -83,29 +86,17 @@ def inputs_ui():
gr.HTML('Put your video here')
with gr.Row():
v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
#init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", image_mode="RGBA")
#with gr.Row():
# gr.HTML('Alternative: enter the relative (to the webui) path to the file')
#with gr.Row():
# vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength = setup_common_values('vid2vid', v2v_args)
with FormRow(elem_id=f"sampler_selection_v2v"):
v2v_sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"v2v_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
v2v_steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"v2v_steps", label="Sampling steps", value=15)
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength, v2v_sampler_index, v2v_steps = setup_common_values('vid2vid', v2v_args)
with FormRow(elem_id="vid2vid_override_settings_row") as row:
v2v_override_settings = create_override_settings_dropdown("vid2vid", row)
with FormGroup(elem_id=f"script_container"):
v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
#with gr.Row():
# strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
# vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
with gr.Tab('txt2vid') as tab_txt2vid:
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength = setup_common_values('txt2vid', t2v_args)
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)
with gr.Row():
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
@ -117,12 +108,22 @@ def inputs_ui():
return locals()
def process(*args):
if args[0] == 'vid2vid':
yield from vid2vid.start_process(*args)
elif args[0] == 'txt2vid':
yield from txt2vid.start_process(*args)
else:
raise Exception(f"Unsupported processing mode: '{args[0]}'")
msg = 'Done'
try:
if args[0] == 'vid2vid':
yield from vid2vid.start_process(*args)
elif args[0] == 'txt2vid':
yield from txt2vid.start_process(*args)
else:
msg = f"Unsupported processing mode: '{args[0]}'"
raise Exception(msg)
except Exception as error:
# handle the exception
msg = f"An exception occurred while trying to process the frame: {error}"
print(msg)
traceback.print_exc()
yield msg, gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Video.update(), gr.Button.update(interactive=True), gr.Button.update(interactive=False)
def stop_process(*args):
    """Request interruption of the current video generation run.

    Sets the shared interruption flag; the vid2vid/txt2vid processing
    loops poll this flag once per frame and break out when it is set.
    """
    utils.shared.is_interrupted = True
@ -141,18 +142,6 @@ def on_ui_tabs():
with gr.Tabs():
components = inputs_ui()
#for category in ordered_ui_categories():
# if category == "sampler":
# steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "vid2vid")
# elif category == "override_settings":
# with FormRow(elem_id="vid2vid_override_settings_row") as row:
# override_settings = create_override_settings_dropdown("vid2vid", row)
# elif category == "scripts":
# with FormGroup(elem_id=f"script_container"):
# custom_inputs = scripts.scripts_img2img.setup_ui()
with gr.Column(scale=1, variant='compact'):
with gr.Row(variant='compact'):
run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
@ -172,7 +161,8 @@ def on_ui_tabs():
img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil').style(height=240)
img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil').style(height=240)
html_log = gr.HTML(elem_id=f'html_log_vid2vid')
# html_log = gr.HTML(elem_id=f'html_log_vid2vid')
video_preview = gr.Video(interactive=False)
with gr.Row(variant='compact'):
dummy_component = gr.Label(visible=False)
@ -186,9 +176,9 @@ def on_ui_tabs():
img_preview_curr_occl,
img_preview_prev_warp,
img_preview_processed,
html_log,
video_preview,
run_button,
stop_button
stop_button,
]
run_button.click(

View File

@ -43,6 +43,11 @@ def RAFT_clear_memory():
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
global RAFT_model
org_size = frame1.shape[1], frame1.shape[0]
size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
frame1 = cv2.resize(frame1, size)
frame2 = cv2.resize(frame2, size)
model_path = ph.models_path + '/RAFT/raft-things.pth'
remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'
@ -67,9 +72,9 @@ def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
RAFT_model.to(device)
RAFT_model.eval()
if subtract_background:
frame1 = background_subtractor(frame1, fgbg)
frame2 = background_subtractor(frame2, fgbg)
#if subtract_background:
# frame1 = background_subtractor(frame1, fgbg)
# frame2 = background_subtractor(frame2, fgbg)
with torch.no_grad():
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
@ -90,10 +95,10 @@ def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)
return next_flow, prev_flow, occlusion_mask, frame1, frame2
# ... rest of the file ...
next_flow = cv2.resize(next_flow, org_size)
prev_flow = cv2.resize(prev_flow, org_size)
return next_flow, prev_flow, occlusion_mask #, frame1, frame2
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
h, w = cur_frame.shape[:2]
@ -144,7 +149,7 @@ def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_sty
#diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
#diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 3) #, diff_mask_stl * 2
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4) #, diff_mask_stl * 2
alpha_mask = alpha_mask.repeat(3, axis = -1)
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))

View File

@ -27,9 +27,10 @@ import skimage
import datetime
import cv2
import gradio as gr
import time
FloweR_model = None
DEVICE = 'cuda'
DEVICE = 'cpu'
def FloweR_clear_memory():
global FloweR_model
del FloweR_model
@ -39,6 +40,8 @@ def FloweR_clear_memory():
def FloweR_load_model(w, h):
global DEVICE, FloweR_model
DEVICE = devices.get_optimal_device()
model_path = ph.models_path + '/FloweR/FloweR_0.1.1.pth'
remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv'
@ -47,13 +50,16 @@ def FloweR_load_model(w, h):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
load_file_from_url(remote_model_path, file_name=model_path)
FloweR_model = FloweR(input_size = (h, w))
FloweR_model.load_state_dict(torch.load(model_path))
FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Move the model to the device
DEVICE = devices.get_optimal_device()
FloweR_model = FloweR_model.to(DEVICE)
def start_process(*args):
processing_start_time = time.time()
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('t2v', args_dict)
@ -69,75 +75,79 @@ def start_process(*args):
output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
stat = f"Frame: 1 / {args_dict['length']}"
stat = f"Frame: 1 / {args_dict['length']}; " + utils.get_time_left(1, args_dict['length'], processing_start_time)
utils.shared.is_interrupted = False
yield stat, init_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
yield stat, init_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
FloweR_load_model(args_dict['width'], args_dict['height'])
org_size = args_dict['width'], args_dict['height']
size = args_dict['width'] // 128 * 128, args_dict['height'] // 128 * 128
FloweR_load_model(size[0], size[1])
clip_frames = np.zeros((4, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
clip_frames = np.zeros((4, size[1], size[0], 3), dtype=np.uint8)
prev_frame = init_frame
try:
for ind in range(args_dict['length']):
if utils.shared.is_interrupted: break
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('t2v', args_dict)
for ind in range(args_dict['length'] - 1):
if utils.shared.is_interrupted: break
clip_frames = np.roll(clip_frames, -1, axis=0)
clip_frames[-1] = prev_frame
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('t2v', args_dict)
with torch.no_grad():
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
clip_frames = np.roll(clip_frames, -1, axis=0)
clip_frames[-1] = cv2.resize(prev_frame, size)
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
with torch.no_grad():
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
pred_flow = cv2.resize(pred_flow, org_size)
pred_occl = cv2.resize(pred_occl, org_size)
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
flow_map = pred_flow.copy()
flow_map[:,:,0] += np.arange(args_dict['width'])
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)
curr_frame = warped_frame.copy()
args_dict['mode'] = 4
args_dict['init_img'] = Image.fromarray(curr_frame)
args_dict['mask_img'] = Image.fromarray(pred_occl)
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
args_dict['mode'] = 0
args_dict['init_img'] = Image.fromarray(processed_frame)
args_dict['mask_img'] = None
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
prev_frame = processed_frame.copy()
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
flow_map = pred_flow.copy()
flow_map[:,:,0] += np.arange(args_dict['width'])
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)
curr_frame = warped_frame.copy()
args_dict['init_img'] = Image.fromarray(curr_frame)
args_dict['mask_img'] = Image.fromarray(pred_occl)
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
args_dict['mode'] = 0
args_dict['init_img'] = Image.fromarray(processed_frame)
args_dict['mask_img'] = None
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
prev_frame = processed_frame.copy()
stat = f"Frame: {ind + 2} / {args_dict['length']}"
yield stat, curr_frame, pred_occl, warped_frame, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
except: pass
stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
yield stat, curr_frame, pred_occl, warped_frame, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
output_video.release()
FloweR_clear_memory()
@ -147,4 +157,4 @@ def start_process(*args):
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)

View File

@ -7,7 +7,7 @@ def get_component_names():
'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
't2v_length', 't2v_fps'
'v2v_sampler_index', 'v2v_steps', 't2v_length', 't2v_fps'
]
return components_list
@ -114,6 +114,15 @@ def set_CNs_input_image(args_dict, image):
if type(script_input).__name__ == 'UiControlNetUnit':
script_input.batch_images = [image]
import time
import datetime
def get_time_left(ind, length, processing_start_time):
    """Return a progress string with elapsed time and a linear ETA.

    Args:
        ind: Number of items processed so far. Values <= 0 yield an
            'unknown' remaining time instead of raising ZeroDivisionError.
        length: Total number of items to process.
        processing_start_time: ``time.time()`` timestamp taken when
            processing started.

    Returns:
        A string of the form "Time elapsed: H:MM:SS; Time left: H:MM:SS;".
    """
    s_passed = int(time.time() - processing_start_time)
    time_passed = datetime.timedelta(seconds=s_passed)
    if ind <= 0:
        # No completed items yet -> cannot extrapolate; avoid division by zero.
        return f"Time elapsed: {time_passed}; Time left: unknown;"
    # Linear extrapolation: average seconds per item times items remaining,
    # clamped so an overshoot (ind > length) never reports a negative ETA.
    s_left = int(s_passed / ind * max(length - ind, 0))
    time_left = datetime.timedelta(seconds=s_left)
    return f"Time elapsed: {time_passed}; Time left: {time_left};"
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops

View File

@ -83,6 +83,7 @@ def clear_memory_from_sd():
devices.torch_gc()
def start_process(*args):
processing_start_time = time.time()
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('v2v', args_dict)
@ -95,6 +96,7 @@ def start_process(*args):
# Get useful info from the source video
sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
loop_iterations = (sdcn_anim_tmp.total_frames-1) * 2
# Create an output video file with the same fps, width, and height as the input video
output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
@ -112,7 +114,7 @@ def start_process(*args):
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, multichannel=False, channel_axis=-1)
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
#print('Processed frame ', 0)
@ -120,124 +122,129 @@ def start_process(*args):
sdcn_anim_tmp.prev_frame = curr_frame.copy()
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
utils.shared.is_interrupted = False
yield get_cur_stat(), sdcn_anim_tmp.curr_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
stat = get_cur_stat() + utils.get_time_left(1, loop_iterations, processing_start_time)
yield stat, sdcn_anim_tmp.curr_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
try:
for step in range((sdcn_anim_tmp.total_frames-1) * 2):
if utils.shared.is_interrupted: break
for step in range(loop_iterations):
if utils.shared.is_interrupted: break
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('v2v', args_dict)
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('v2v', args_dict)
occlusion_mask = None
prev_frame = None
curr_frame = sdcn_anim_tmp.curr_frame
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
occlusion_mask = None
prev_frame = None
curr_frame = sdcn_anim_tmp.curr_frame
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
prepare_steps = 10
if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
#clear_memory_from_sd()
device = devices.get_optimal_device()
prepare_steps = 10
if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
#clear_memory_from_sd()
device = devices.get_optimal_device()
curr_frame = read_frame_from_video()
if curr_frame is not None:
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
prev_frame = sdcn_anim_tmp.prev_frame.copy()
curr_frame = read_frame_from_video()
if curr_frame is not None:
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
prev_frame = sdcn_anim_tmp.prev_frame.copy()
next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)
next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)
cn = sdcn_anim_tmp.prepear_counter % 10
if sdcn_anim_tmp.prepear_counter % 10 == 0:
sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
#print('Prepared frame ', cn+1)
cn = sdcn_anim_tmp.prepear_counter % 10
if sdcn_anim_tmp.prepear_counter % 10 == 0:
sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
#print('Prepared frame ', cn+1)
sdcn_anim_tmp.prev_frame = curr_frame.copy()
sdcn_anim_tmp.prev_frame = curr_frame.copy()
sdcn_anim_tmp.prepear_counter += 1
if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
curr_frame is None:
# Remove RAFT from memory
RAFT_clear_memory()
sdcn_anim_tmp.frames_prepared = True
else:
# process frame
sdcn_anim_tmp.frames_prepared = False
sdcn_anim_tmp.prepear_counter += 1
if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
curr_frame is None:
# Remove RAFT from memory
RAFT_clear_memory()
sdcn_anim_tmp.frames_prepared = True
else:
# process frame
sdcn_anim_tmp.frames_prepared = False
cn = sdcn_anim_tmp.process_counter % 10
curr_frame = sdcn_anim_tmp.prepared_frames[cn+1]
prev_frame = sdcn_anim_tmp.prepared_frames[cn]
next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]
cn = sdcn_anim_tmp.process_counter % 10
curr_frame = sdcn_anim_tmp.prepared_frames[cn+1]
prev_frame = sdcn_anim_tmp.prepared_frames[cn]
next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]
# process current frame
args_dict['init_img'] = Image.fromarray(curr_frame)
args_dict['seed'] = -1
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
### STEP 1
alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
warped_styled_frame_ = warped_styled_frame.copy()
if sdcn_anim_tmp.process_counter > 0:
alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
# alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
alpha_mask = np.clip(alpha_mask, 0, 1)
fl_w, fl_h = prev_flow.shape[:2]
prev_flow_n = prev_flow / np.array([fl_h,fl_w])
flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None], 0, 1)
# fix warped styled frame from duplicated that occures on the places where flow is zero, but only because there is no place to get the color from
warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)
# This clipping at lower side required to fix small trailing issues that for some reason left outside of the bright part of the mask,
# and at the higher part it making parts changed strongly to do it with less flickering.
occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
# process current frame
args_dict['mode'] = 4
init_img = warped_styled_frame * 0.95 + curr_frame * 0.05
args_dict['init_img'] = Image.fromarray(np.clip(init_img, 0, 255).astype(np.uint8))
args_dict['mask_img'] = Image.fromarray(occlusion_mask)
args_dict['seed'] = -1
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
# normalizing the colors
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
#processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
#processed_frame = processed_frame * 0.94 + curr_frame * 0.06
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
### STEP 2
args_dict['mode'] = 0
args_dict['init_img'] = Image.fromarray(processed_frame)
args_dict['mask_img'] = None
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
args_dict['seed'] = 8888
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)
alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
warped_styled_frame_ = warped_styled_frame.copy()
# Write the frame to the output video
frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
sdcn_anim_tmp.output_video.write(frame_out)
if sdcn_anim_tmp.process_counter > 0:
alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
# alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
alpha_mask = np.clip(alpha_mask, 0, 1)
sdcn_anim_tmp.process_counter += 1
if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
sdcn_anim_tmp.input_video.release()
sdcn_anim_tmp.output_video.release()
sdcn_anim_tmp.prev_frame = None
fl_w, fl_h = prev_flow.shape[:2]
prev_flow_n = prev_flow / np.array([fl_h,fl_w])
flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None], 0, 1)
# fix warped styled frame from duplicated that occures on the places where flow is zero, but only because there is no place to get the color from
warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)
# This clipping at lower side required to fix small trailing issues that for some reason left outside of the bright part of the mask,
# and at the higher part it making parts changed strongly to do it with less flickering.
occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
# normalizing the colors
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, multichannel=False, channel_axis=-1)
processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
processed_frame = processed_frame * 0.9 + curr_frame * 0.1
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
args_dict['init_img'] = Image.fromarray(processed_frame)
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
args_dict['seed'] = 8888
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)
# Write the frame to the output video
frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
sdcn_anim_tmp.output_video.write(frame_out)
sdcn_anim_tmp.process_counter += 1
if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
sdcn_anim_tmp.input_video.release()
sdcn_anim_tmp.output_video.release()
sdcn_anim_tmp.prev_frame = None
#print(f'\nEND OF STEP {step}, {sdcn_anim_tmp.prepear_counter}, {sdcn_anim_tmp.process_counter}')
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
except:
pass
stat = get_cur_stat() + utils.get_time_left(step+2, loop_iterations+1, processing_start_time)
yield stat, curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
RAFT_clear_memory()
@ -249,4 +256,4 @@ def start_process(*args):
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)