parent
c987f645d6
commit
14534d3174
|
|
@ -18,43 +18,43 @@ class FloweR(nn.Module):
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 384 x 384 x 128
|
) # 384 x 384 x 128
|
||||||
|
|
||||||
self.conv_block_2 = nn.Sequential(
|
self.conv_block_2 = nn.Sequential( # x128
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 192 x 192 x 128
|
) # 192 x 192 x 128
|
||||||
|
|
||||||
self.conv_block_3 = nn.Sequential(
|
self.conv_block_3 = nn.Sequential( # x64
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 96 x 96 x 128
|
) # 96 x 96 x 128
|
||||||
|
|
||||||
self.conv_block_4 = nn.Sequential(
|
self.conv_block_4 = nn.Sequential( # x32
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 48 x 48 x 128
|
) # 48 x 48 x 128
|
||||||
|
|
||||||
self.conv_block_5 = nn.Sequential(
|
self.conv_block_5 = nn.Sequential( # x16
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 24 x 24 x 128
|
) # 24 x 24 x 128
|
||||||
|
|
||||||
self.conv_block_6 = nn.Sequential(
|
self.conv_block_6 = nn.Sequential( # x8
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 12 x 12 x 128
|
) # 12 x 12 x 128
|
||||||
|
|
||||||
self.conv_block_7 = nn.Sequential(
|
self.conv_block_7 = nn.Sequential( # x4
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
) # 6 x 6 x 128
|
) # 6 x 6 x 128
|
||||||
|
|
||||||
self.conv_block_8 = nn.Sequential(
|
self.conv_block_8 = nn.Sequential( # x2
|
||||||
nn.AvgPool2d(2),
|
nn.AvgPool2d(2),
|
||||||
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
|
|
||||||
14
readme.md
14
readme.md
|
|
@ -54,8 +54,12 @@ All examples you can see here are originally generated at 512x512 resolution usi
|
||||||
## Installing the extension
|
## Installing the extension
|
||||||
To install the extension, go to the 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to the 'Install from URL' tab. In the 'URL for extension's git repository' field, enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave the 'Local directory name' field empty. Then just press the 'Install' button. Restart the web-ui; a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
|
To install the extension, go to the 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to the 'Install from URL' tab. In the 'URL for extension's git repository' field, enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave the 'Local directory name' field empty. Then just press the 'Install' button. Restart the web-ui; a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
|
||||||
|
|
||||||
## Last version changes: v0.7
|
## Last version changes: v0.8
|
||||||
* Text to Video mode added to the extension
|
* Better error handling. Fixes an issue where errors might not appear in the console.
|
||||||
* 'Generate' button is now automatically disabled while the video is generated
|
* Fixed an issue with deprecated variables. This should resolve problems with running the extension on other webui forks.
|
||||||
* Added an 'Interrupt' button that allows stopping the video generation process
|
* Slight improvements in vid2vid processing pipeline.
|
||||||
* Now all necessary models are automatically downloaded. No need for manual preparation.
|
* Video preview added to the UI. It will become available at the end of the processing.
|
||||||
|
* Time elapsed/left indication added.
|
||||||
|
* Fixed an issue with color drifting on some models.
|
||||||
|
* Sampler type and sampling steps settings added to text2video mode.
|
||||||
|
* Added automatic resizing before processing with RAFT and FloweR models.
|
||||||
|
|
@ -28,6 +28,7 @@ from modules.sd_samplers import samplers_for_img2img
|
||||||
from modules.ui import setup_progressbar, create_sampler_and_steps_selection, ordered_ui_categories, create_output_panel
|
from modules.ui import setup_progressbar, create_sampler_and_steps_selection, ordered_ui_categories, create_output_panel
|
||||||
|
|
||||||
from core import vid2vid, txt2vid, utils
|
from core import vid2vid, txt2vid, utils
|
||||||
|
import traceback
|
||||||
|
|
||||||
def V2VArgs():
|
def V2VArgs():
|
||||||
seed = -1
|
seed = -1
|
||||||
|
|
@ -62,15 +63,17 @@ def setup_common_values(mode, d):
|
||||||
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
|
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
|
||||||
n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
|
n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
#steps = gr.Slider(label='Steps', minimum=1, maximum=100, step=1, value=d.steps, interactive=True)
|
|
||||||
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
|
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed = gr.Number(label='Seed (this parameter controls how the first frame looks like and the color distribution of the consecutive frames as they are dependent on the first one)', value = d.seed, Interactive = True, precision=0)
|
seed = gr.Number(label='Seed (this parameter controls how the first frame looks like and the color distribution of the consecutive frames as they are dependent on the first one)', value = d.seed, Interactive = True, precision=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
processing_strength = gr.Slider(label="Processing strength", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
processing_strength = gr.Slider(label="Processing strength", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
||||||
fix_frame_strength = gr.Slider(label="Fix frame strength", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
fix_frame_strength = gr.Slider(label="Fix frame strength", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
||||||
|
with gr.Row():
|
||||||
|
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{mode}_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index", interactive=True)
|
||||||
|
steps = gr.Slider(label="Sampling steps", minimum=1, maximum=150, step=1, elem_id=f"{mode}_steps", value=d.steps, interactive=True)
|
||||||
|
|
||||||
return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength
|
return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength, sampler_index, steps
|
||||||
|
|
||||||
def inputs_ui():
|
def inputs_ui():
|
||||||
v2v_args = SimpleNamespace(**V2VArgs())
|
v2v_args = SimpleNamespace(**V2VArgs())
|
||||||
|
|
@ -83,29 +86,17 @@ def inputs_ui():
|
||||||
gr.HTML('Put your video here')
|
gr.HTML('Put your video here')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
|
v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
|
||||||
#init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", image_mode="RGBA")
|
|
||||||
#with gr.Row():
|
|
||||||
# gr.HTML('Alternative: enter the relative (to the webui) path to the file')
|
|
||||||
#with gr.Row():
|
|
||||||
# vid2vid_frames_path = gr.Textbox(label="Input video path", interactive=True, elem_id="vid_to_vid_chosen_path", placeholder='Enter your video path here, or upload in the box above ^')
|
|
||||||
|
|
||||||
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength = setup_common_values('vid2vid', v2v_args)
|
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength, v2v_sampler_index, v2v_steps = setup_common_values('vid2vid', v2v_args)
|
||||||
|
|
||||||
with FormRow(elem_id=f"sampler_selection_v2v"):
|
|
||||||
v2v_sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"v2v_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
|
||||||
v2v_steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"v2v_steps", label="Sampling steps", value=15)
|
|
||||||
|
|
||||||
with FormRow(elem_id="vid2vid_override_settings_row") as row:
|
with FormRow(elem_id="vid2vid_override_settings_row") as row:
|
||||||
v2v_override_settings = create_override_settings_dropdown("vid2vid", row)
|
v2v_override_settings = create_override_settings_dropdown("vid2vid", row)
|
||||||
|
|
||||||
with FormGroup(elem_id=f"script_container"):
|
with FormGroup(elem_id=f"script_container"):
|
||||||
v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
|
v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
|
||||||
#with gr.Row():
|
|
||||||
# strength = gr.Slider(label="denoising strength", value=d.strength, minimum=0, maximum=1, step=0.05, interactive=True)
|
|
||||||
# vid2vid_startFrame=gr.Number(label='vid2vid start frame',value=d.vid2vid_startFrame)
|
|
||||||
|
|
||||||
with gr.Tab('txt2vid') as tab_txt2vid:
|
with gr.Tab('txt2vid') as tab_txt2vid:
|
||||||
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength = setup_common_values('txt2vid', t2v_args)
|
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
|
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
|
||||||
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
|
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
|
||||||
|
|
@ -117,12 +108,22 @@ def inputs_ui():
|
||||||
return locals()
|
return locals()
|
||||||
|
|
||||||
def process(*args):
|
def process(*args):
|
||||||
if args[0] == 'vid2vid':
|
msg = 'Done'
|
||||||
yield from vid2vid.start_process(*args)
|
try:
|
||||||
elif args[0] == 'txt2vid':
|
if args[0] == 'vid2vid':
|
||||||
yield from txt2vid.start_process(*args)
|
yield from vid2vid.start_process(*args)
|
||||||
else:
|
elif args[0] == 'txt2vid':
|
||||||
raise Exception(f"Unsupported processing mode: '{args[0]}'")
|
yield from txt2vid.start_process(*args)
|
||||||
|
else:
|
||||||
|
msg = f"Unsupported processing mode: '{args[0]}'"
|
||||||
|
raise Exception(msg)
|
||||||
|
except Exception as error:
|
||||||
|
# handle the exception
|
||||||
|
msg = f"An exception occurred while trying to process the frame: {error}"
|
||||||
|
print(msg)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
yield msg, gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Video.update(), gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||||
|
|
||||||
def stop_process(*args):
|
def stop_process(*args):
|
||||||
utils.shared.is_interrupted = True
|
utils.shared.is_interrupted = True
|
||||||
|
|
@ -140,18 +141,6 @@ def on_ui_tabs():
|
||||||
with gr.Column(scale=1, variant='panel'):
|
with gr.Column(scale=1, variant='panel'):
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
components = inputs_ui()
|
components = inputs_ui()
|
||||||
|
|
||||||
#for category in ordered_ui_categories():
|
|
||||||
# if category == "sampler":
|
|
||||||
# steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "vid2vid")
|
|
||||||
|
|
||||||
# elif category == "override_settings":
|
|
||||||
# with FormRow(elem_id="vid2vid_override_settings_row") as row:
|
|
||||||
# override_settings = create_override_settings_dropdown("vid2vid", row)
|
|
||||||
|
|
||||||
# elif category == "scripts":
|
|
||||||
# with FormGroup(elem_id=f"script_container"):
|
|
||||||
# custom_inputs = scripts.scripts_img2img.setup_ui()
|
|
||||||
|
|
||||||
with gr.Column(scale=1, variant='compact'):
|
with gr.Column(scale=1, variant='compact'):
|
||||||
with gr.Row(variant='compact'):
|
with gr.Row(variant='compact'):
|
||||||
|
|
@ -172,7 +161,8 @@ def on_ui_tabs():
|
||||||
img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil').style(height=240)
|
img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil').style(height=240)
|
||||||
img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil').style(height=240)
|
img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil').style(height=240)
|
||||||
|
|
||||||
html_log = gr.HTML(elem_id=f'html_log_vid2vid')
|
# html_log = gr.HTML(elem_id=f'html_log_vid2vid')
|
||||||
|
video_preview = gr.Video(interactive=False)
|
||||||
|
|
||||||
with gr.Row(variant='compact'):
|
with gr.Row(variant='compact'):
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
@ -186,9 +176,9 @@ def on_ui_tabs():
|
||||||
img_preview_curr_occl,
|
img_preview_curr_occl,
|
||||||
img_preview_prev_warp,
|
img_preview_prev_warp,
|
||||||
img_preview_processed,
|
img_preview_processed,
|
||||||
html_log,
|
video_preview,
|
||||||
run_button,
|
run_button,
|
||||||
stop_button
|
stop_button,
|
||||||
]
|
]
|
||||||
|
|
||||||
run_button.click(
|
run_button.click(
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,11 @@ def RAFT_clear_memory():
|
||||||
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
|
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
|
||||||
global RAFT_model
|
global RAFT_model
|
||||||
|
|
||||||
|
org_size = frame1.shape[1], frame1.shape[0]
|
||||||
|
size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
|
||||||
|
frame1 = cv2.resize(frame1, size)
|
||||||
|
frame2 = cv2.resize(frame2, size)
|
||||||
|
|
||||||
model_path = ph.models_path + '/RAFT/raft-things.pth'
|
model_path = ph.models_path + '/RAFT/raft-things.pth'
|
||||||
remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'
|
remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'
|
||||||
|
|
||||||
|
|
@ -67,9 +72,9 @@ def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
|
||||||
RAFT_model.to(device)
|
RAFT_model.to(device)
|
||||||
RAFT_model.eval()
|
RAFT_model.eval()
|
||||||
|
|
||||||
if subtract_background:
|
#if subtract_background:
|
||||||
frame1 = background_subtractor(frame1, fgbg)
|
# frame1 = background_subtractor(frame1, fgbg)
|
||||||
frame2 = background_subtractor(frame2, fgbg)
|
# frame2 = background_subtractor(frame2, fgbg)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
|
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
|
||||||
|
|
@ -90,10 +95,10 @@ def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
|
||||||
|
|
||||||
occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)
|
occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)
|
||||||
|
|
||||||
return next_flow, prev_flow, occlusion_mask, frame1, frame2
|
next_flow = cv2.resize(next_flow, org_size)
|
||||||
|
prev_flow = cv2.resize(prev_flow, org_size)
|
||||||
# ... rest of the file ...
|
|
||||||
|
|
||||||
|
return next_flow, prev_flow, occlusion_mask #, frame1, frame2
|
||||||
|
|
||||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
|
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
|
||||||
h, w = cur_frame.shape[:2]
|
h, w = cur_frame.shape[:2]
|
||||||
|
|
@ -144,7 +149,7 @@ def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_sty
|
||||||
#diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
|
#diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
|
||||||
#diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
|
#diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
|
||||||
|
|
||||||
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 3) #, diff_mask_stl * 2
|
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4) #, diff_mask_stl * 2
|
||||||
alpha_mask = alpha_mask.repeat(3, axis = -1)
|
alpha_mask = alpha_mask.repeat(3, axis = -1)
|
||||||
|
|
||||||
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
|
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,10 @@ import skimage
|
||||||
import datetime
|
import datetime
|
||||||
import cv2
|
import cv2
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import time
|
||||||
|
|
||||||
FloweR_model = None
|
FloweR_model = None
|
||||||
DEVICE = 'cuda'
|
DEVICE = 'cpu'
|
||||||
def FloweR_clear_memory():
|
def FloweR_clear_memory():
|
||||||
global FloweR_model
|
global FloweR_model
|
||||||
del FloweR_model
|
del FloweR_model
|
||||||
|
|
@ -39,6 +40,8 @@ def FloweR_clear_memory():
|
||||||
|
|
||||||
def FloweR_load_model(w, h):
|
def FloweR_load_model(w, h):
|
||||||
global DEVICE, FloweR_model
|
global DEVICE, FloweR_model
|
||||||
|
DEVICE = devices.get_optimal_device()
|
||||||
|
|
||||||
model_path = ph.models_path + '/FloweR/FloweR_0.1.1.pth'
|
model_path = ph.models_path + '/FloweR/FloweR_0.1.1.pth'
|
||||||
remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv'
|
remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv'
|
||||||
|
|
||||||
|
|
@ -47,13 +50,16 @@ def FloweR_load_model(w, h):
|
||||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||||
load_file_from_url(remote_model_path, file_name=model_path)
|
load_file_from_url(remote_model_path, file_name=model_path)
|
||||||
|
|
||||||
|
|
||||||
FloweR_model = FloweR(input_size = (h, w))
|
FloweR_model = FloweR(input_size = (h, w))
|
||||||
FloweR_model.load_state_dict(torch.load(model_path))
|
FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
||||||
# Move the model to the device
|
# Move the model to the device
|
||||||
DEVICE = devices.get_optimal_device()
|
|
||||||
FloweR_model = FloweR_model.to(DEVICE)
|
FloweR_model = FloweR_model.to(DEVICE)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def start_process(*args):
|
def start_process(*args):
|
||||||
|
processing_start_time = time.time()
|
||||||
args_dict = utils.args_to_dict(*args)
|
args_dict = utils.args_to_dict(*args)
|
||||||
args_dict = utils.get_mode_args('t2v', args_dict)
|
args_dict = utils.get_mode_args('t2v', args_dict)
|
||||||
|
|
||||||
|
|
@ -69,75 +75,79 @@ def start_process(*args):
|
||||||
output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
|
output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
|
||||||
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
||||||
|
|
||||||
stat = f"Frame: 1 / {args_dict['length']}"
|
stat = f"Frame: 1 / {args_dict['length']}; " + utils.get_time_left(1, args_dict['length'], processing_start_time)
|
||||||
utils.shared.is_interrupted = False
|
utils.shared.is_interrupted = False
|
||||||
yield stat, init_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
yield stat, init_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||||
|
|
||||||
FloweR_load_model(args_dict['width'], args_dict['height'])
|
org_size = args_dict['width'], args_dict['height']
|
||||||
|
size = args_dict['width'] // 128 * 128, args_dict['height'] // 128 * 128
|
||||||
|
FloweR_load_model(size[0], size[1])
|
||||||
|
|
||||||
clip_frames = np.zeros((4, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
|
clip_frames = np.zeros((4, size[1], size[0], 3), dtype=np.uint8)
|
||||||
|
|
||||||
prev_frame = init_frame
|
prev_frame = init_frame
|
||||||
try:
|
|
||||||
for ind in range(args_dict['length']):
|
|
||||||
if utils.shared.is_interrupted: break
|
|
||||||
|
|
||||||
args_dict = utils.args_to_dict(*args)
|
for ind in range(args_dict['length'] - 1):
|
||||||
args_dict = utils.get_mode_args('t2v', args_dict)
|
if utils.shared.is_interrupted: break
|
||||||
|
|
||||||
clip_frames = np.roll(clip_frames, -1, axis=0)
|
args_dict = utils.args_to_dict(*args)
|
||||||
clip_frames[-1] = prev_frame
|
args_dict = utils.get_mode_args('t2v', args_dict)
|
||||||
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
|
|
||||||
|
|
||||||
with torch.no_grad():
|
clip_frames = np.roll(clip_frames, -1, axis=0)
|
||||||
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
|
clip_frames[-1] = cv2.resize(prev_frame, size)
|
||||||
|
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
|
||||||
|
|
||||||
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
|
with torch.no_grad():
|
||||||
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
|
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
|
||||||
|
|
||||||
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
|
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
|
||||||
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
|
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
|
||||||
|
|
||||||
|
pred_flow = cv2.resize(pred_flow, org_size)
|
||||||
|
pred_occl = cv2.resize(pred_occl, org_size)
|
||||||
|
|
||||||
|
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
|
||||||
|
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
|
||||||
|
|
||||||
|
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
|
||||||
|
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
|
||||||
|
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
flow_map = pred_flow.copy()
|
||||||
|
flow_map[:,:,0] += np.arange(args_dict['width'])
|
||||||
|
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
|
||||||
|
|
||||||
|
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)
|
||||||
|
|
||||||
|
curr_frame = warped_frame.copy()
|
||||||
|
|
||||||
|
args_dict['mode'] = 4
|
||||||
|
args_dict['init_img'] = Image.fromarray(curr_frame)
|
||||||
|
args_dict['mask_img'] = Image.fromarray(pred_occl)
|
||||||
|
|
||||||
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
|
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||||
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
|
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||||
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
|
processed_frame = np.array(processed_frames[0])
|
||||||
|
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=None)
|
||||||
|
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
flow_map = pred_flow.copy()
|
args_dict['mode'] = 0
|
||||||
flow_map[:,:,0] += np.arange(args_dict['width'])
|
args_dict['init_img'] = Image.fromarray(processed_frame)
|
||||||
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
|
args_dict['mask_img'] = None
|
||||||
|
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
|
||||||
|
|
||||||
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)
|
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||||
|
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||||
|
processed_frame = np.array(processed_frames[0])
|
||||||
|
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=None)
|
||||||
|
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
curr_frame = warped_frame.copy()
|
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
||||||
|
prev_frame = processed_frame.copy()
|
||||||
|
|
||||||
args_dict['init_img'] = Image.fromarray(curr_frame)
|
|
||||||
args_dict['mask_img'] = Image.fromarray(pred_occl)
|
stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
|
||||||
|
yield stat, curr_frame, pred_occl, warped_frame, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||||
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
|
||||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
|
||||||
processed_frame = np.array(processed_frames[0])
|
|
||||||
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
|
|
||||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
args_dict['mode'] = 0
|
|
||||||
args_dict['init_img'] = Image.fromarray(processed_frame)
|
|
||||||
args_dict['mask_img'] = None
|
|
||||||
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
|
|
||||||
|
|
||||||
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
|
||||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
|
||||||
processed_frame = np.array(processed_frames[0])
|
|
||||||
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, multichannel=False, channel_axis=-1)
|
|
||||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
|
||||||
prev_frame = processed_frame.copy()
|
|
||||||
|
|
||||||
stat = f"Frame: {ind + 2} / {args_dict['length']}"
|
|
||||||
yield stat, curr_frame, pred_occl, warped_frame, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
|
||||||
except: pass
|
|
||||||
|
|
||||||
output_video.release()
|
output_video.release()
|
||||||
FloweR_clear_memory()
|
FloweR_clear_memory()
|
||||||
|
|
@ -147,4 +157,4 @@ def start_process(*args):
|
||||||
warped_styled_frame_ = gr.Image.update()
|
warped_styled_frame_ = gr.Image.update()
|
||||||
processed_frame = gr.Image.update()
|
processed_frame = gr.Image.update()
|
||||||
|
|
||||||
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||||
|
|
@ -7,7 +7,7 @@ def get_component_names():
|
||||||
'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
|
'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
|
||||||
'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
|
'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
|
||||||
't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
|
't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
|
||||||
't2v_length', 't2v_fps'
|
'v2v_sampler_index', 'v2v_steps', 't2v_length', 't2v_fps'
|
||||||
]
|
]
|
||||||
|
|
||||||
return components_list
|
return components_list
|
||||||
|
|
@ -114,6 +114,15 @@ def set_CNs_input_image(args_dict, image):
|
||||||
if type(script_input).__name__ == 'UiControlNetUnit':
|
if type(script_input).__name__ == 'UiControlNetUnit':
|
||||||
script_input.batch_images = [image]
|
script_input.batch_images = [image]
|
||||||
|
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def get_time_left(ind, length, processing_start_time):
|
||||||
|
s_passed = int(time.time() - processing_start_time)
|
||||||
|
time_passed = datetime.timedelta(seconds=s_passed)
|
||||||
|
s_left = int(s_passed / ind * (length - ind))
|
||||||
|
time_left = datetime.timedelta(seconds=s_left)
|
||||||
|
return f"Time elapsed: {time_passed}; Time left: {time_left};"
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,7 @@ def clear_memory_from_sd():
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
def start_process(*args):
|
def start_process(*args):
|
||||||
|
processing_start_time = time.time()
|
||||||
args_dict = utils.args_to_dict(*args)
|
args_dict = utils.args_to_dict(*args)
|
||||||
args_dict = utils.get_mode_args('v2v', args_dict)
|
args_dict = utils.get_mode_args('v2v', args_dict)
|
||||||
|
|
||||||
|
|
@ -95,6 +96,7 @@ def start_process(*args):
|
||||||
# Get useful info from the source video
|
# Get useful info from the source video
|
||||||
sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
|
sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
|
||||||
sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
|
sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
loop_iterations = (sdcn_anim_tmp.total_frames-1) * 2
|
||||||
|
|
||||||
# Create an output video file with the same fps, width, and height as the input video
|
# Create an output video file with the same fps, width, and height as the input video
|
||||||
output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
|
output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
|
||||||
|
|
@ -112,7 +114,7 @@ def start_process(*args):
|
||||||
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||||
processed_frame = np.array(processed_frames[0])
|
processed_frame = np.array(processed_frames[0])
|
||||||
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, multichannel=False, channel_axis=-1)
|
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
|
||||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
#print('Processed frame ', 0)
|
#print('Processed frame ', 0)
|
||||||
|
|
||||||
|
|
@ -120,124 +122,129 @@ def start_process(*args):
|
||||||
sdcn_anim_tmp.prev_frame = curr_frame.copy()
|
sdcn_anim_tmp.prev_frame = curr_frame.copy()
|
||||||
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
|
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
|
||||||
utils.shared.is_interrupted = False
|
utils.shared.is_interrupted = False
|
||||||
yield get_cur_stat(), sdcn_anim_tmp.curr_frame, None, None, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
stat = get_cur_stat() + utils.get_time_left(1, loop_iterations, processing_start_time)
|
||||||
|
yield stat, sdcn_anim_tmp.curr_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||||
|
|
||||||
try:
|
for step in range(loop_iterations):
|
||||||
for step in range((sdcn_anim_tmp.total_frames-1) * 2):
|
if utils.shared.is_interrupted: break
|
||||||
if utils.shared.is_interrupted: break
|
|
||||||
|
args_dict = utils.args_to_dict(*args)
|
||||||
args_dict = utils.args_to_dict(*args)
|
args_dict = utils.get_mode_args('v2v', args_dict)
|
||||||
args_dict = utils.get_mode_args('v2v', args_dict)
|
|
||||||
|
|
||||||
occlusion_mask = None
|
occlusion_mask = None
|
||||||
prev_frame = None
|
prev_frame = None
|
||||||
curr_frame = sdcn_anim_tmp.curr_frame
|
curr_frame = sdcn_anim_tmp.curr_frame
|
||||||
warped_styled_frame_ = gr.Image.update()
|
warped_styled_frame_ = gr.Image.update()
|
||||||
processed_frame = gr.Image.update()
|
processed_frame = gr.Image.update()
|
||||||
|
|
||||||
prepare_steps = 10
|
prepare_steps = 10
|
||||||
if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
|
if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
|
||||||
#clear_memory_from_sd()
|
#clear_memory_from_sd()
|
||||||
device = devices.get_optimal_device()
|
device = devices.get_optimal_device()
|
||||||
|
|
||||||
curr_frame = read_frame_from_video()
|
curr_frame = read_frame_from_video()
|
||||||
if curr_frame is not None:
|
if curr_frame is not None:
|
||||||
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
|
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
|
||||||
prev_frame = sdcn_anim_tmp.prev_frame.copy()
|
prev_frame = sdcn_anim_tmp.prev_frame.copy()
|
||||||
|
|
||||||
next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
|
next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, curr_frame, subtract_background=False, device=device)
|
||||||
occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)
|
occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
cn = sdcn_anim_tmp.prepear_counter % 10
|
cn = sdcn_anim_tmp.prepear_counter % 10
|
||||||
if sdcn_anim_tmp.prepear_counter % 10 == 0:
|
if sdcn_anim_tmp.prepear_counter % 10 == 0:
|
||||||
sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
|
sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
|
||||||
sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
|
sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
|
||||||
sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
|
sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
|
||||||
sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
|
sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
|
||||||
#print('Prepared frame ', cn+1)
|
#print('Prepared frame ', cn+1)
|
||||||
|
|
||||||
sdcn_anim_tmp.prev_frame = curr_frame.copy()
|
sdcn_anim_tmp.prev_frame = curr_frame.copy()
|
||||||
|
|
||||||
sdcn_anim_tmp.prepear_counter += 1
|
sdcn_anim_tmp.prepear_counter += 1
|
||||||
if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
|
if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
|
||||||
sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
|
sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
|
||||||
curr_frame is None:
|
curr_frame is None:
|
||||||
# Remove RAFT from memory
|
# Remove RAFT from memory
|
||||||
RAFT_clear_memory()
|
RAFT_clear_memory()
|
||||||
sdcn_anim_tmp.frames_prepared = True
|
sdcn_anim_tmp.frames_prepared = True
|
||||||
else:
|
else:
|
||||||
# process frame
|
# process frame
|
||||||
sdcn_anim_tmp.frames_prepared = False
|
sdcn_anim_tmp.frames_prepared = False
|
||||||
|
|
||||||
cn = sdcn_anim_tmp.process_counter % 10
|
cn = sdcn_anim_tmp.process_counter % 10
|
||||||
curr_frame = sdcn_anim_tmp.prepared_frames[cn+1]
|
curr_frame = sdcn_anim_tmp.prepared_frames[cn+1]
|
||||||
prev_frame = sdcn_anim_tmp.prepared_frames[cn]
|
prev_frame = sdcn_anim_tmp.prepared_frames[cn]
|
||||||
next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
|
next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
|
||||||
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]
|
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]
|
||||||
|
|
||||||
# process current frame
|
### STEP 1
|
||||||
args_dict['init_img'] = Image.fromarray(curr_frame)
|
alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
|
||||||
args_dict['seed'] = -1
|
warped_styled_frame_ = warped_styled_frame.copy()
|
||||||
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
|
||||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
|
||||||
processed_frame = np.array(processed_frames[0])
|
|
||||||
|
|
||||||
|
if sdcn_anim_tmp.process_counter > 0:
|
||||||
|
alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
|
||||||
|
sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
|
||||||
|
# alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
|
||||||
|
alpha_mask = np.clip(alpha_mask, 0, 1)
|
||||||
|
|
||||||
alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
|
fl_w, fl_h = prev_flow.shape[:2]
|
||||||
warped_styled_frame_ = warped_styled_frame.copy()
|
prev_flow_n = prev_flow / np.array([fl_h,fl_w])
|
||||||
|
flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None], 0, 1)
|
||||||
|
|
||||||
if sdcn_anim_tmp.process_counter > 0:
|
# Fix duplicated regions in the warped styled frame that occur where the flow is zero only because there is no place to get the color from
|
||||||
alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
|
warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)
|
||||||
sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
|
|
||||||
# alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
|
# This clipping at lower side required to fix small trailing issues that for some reason left outside of the bright part of the mask,
|
||||||
alpha_mask = np.clip(alpha_mask, 0, 1)
|
# and at the higher part it making parts changed strongly to do it with less flickering.
|
||||||
|
|
||||||
|
occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
fl_w, fl_h = prev_flow.shape[:2]
|
# process current frame
|
||||||
prev_flow_n = prev_flow / np.array([fl_h,fl_w])
|
args_dict['mode'] = 4
|
||||||
flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None], 0, 1)
|
init_img = warped_styled_frame * 0.95 + curr_frame * 0.05
|
||||||
|
args_dict['init_img'] = Image.fromarray(np.clip(init_img, 0, 255).astype(np.uint8))
|
||||||
|
args_dict['mask_img'] = Image.fromarray(occlusion_mask)
|
||||||
|
args_dict['seed'] = -1
|
||||||
|
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||||
|
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||||
|
processed_frame = np.array(processed_frames[0])
|
||||||
|
|
||||||
# Fix duplicated regions in the warped styled frame that occur where the flow is zero only because there is no place to get the color from
|
# normalizing the colors
|
||||||
warped_styled_frame = curr_frame.astype(float) * alpha_mask * flow_mask + warped_styled_frame.astype(float) * (1 - alpha_mask * flow_mask)
|
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
|
||||||
|
#processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
|
||||||
# This clipping at lower side required to fix small trailing issues that for some reason left outside of the bright part of the mask,
|
|
||||||
# and at the higher part it making parts changed strongly to do it with less flickering.
|
#processed_frame = processed_frame * 0.94 + curr_frame * 0.06
|
||||||
|
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
|
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
|
||||||
|
|
||||||
# normalizing the colors
|
### STEP 2
|
||||||
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, multichannel=False, channel_axis=-1)
|
args_dict['mode'] = 0
|
||||||
processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
|
args_dict['init_img'] = Image.fromarray(processed_frame)
|
||||||
|
args_dict['mask_img'] = None
|
||||||
processed_frame = processed_frame * 0.9 + curr_frame * 0.1
|
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
|
||||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
args_dict['seed'] = 8888
|
||||||
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
|
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||||
|
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||||
|
processed_frame = np.array(processed_frames[0])
|
||||||
|
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
|
||||||
|
|
||||||
args_dict['init_img'] = Image.fromarray(processed_frame)
|
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
|
warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)
|
||||||
args_dict['seed'] = 8888
|
|
||||||
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
|
||||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
|
||||||
processed_frame = np.array(processed_frames[0])
|
|
||||||
|
|
||||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
# Write the frame to the output video
|
||||||
warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)
|
frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||||
|
frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
|
||||||
|
sdcn_anim_tmp.output_video.write(frame_out)
|
||||||
|
|
||||||
# Write the frame to the output video
|
sdcn_anim_tmp.process_counter += 1
|
||||||
frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
|
||||||
frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
|
sdcn_anim_tmp.input_video.release()
|
||||||
sdcn_anim_tmp.output_video.write(frame_out)
|
sdcn_anim_tmp.output_video.release()
|
||||||
|
sdcn_anim_tmp.prev_frame = None
|
||||||
|
|
||||||
sdcn_anim_tmp.process_counter += 1
|
stat = get_cur_stat() + utils.get_time_left(step+2, loop_iterations+1, processing_start_time)
|
||||||
if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
|
yield stat, curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||||
sdcn_anim_tmp.input_video.release()
|
|
||||||
sdcn_anim_tmp.output_video.release()
|
|
||||||
sdcn_anim_tmp.prev_frame = None
|
|
||||||
|
|
||||||
#print(f'\nEND OF STEP {step}, {sdcn_anim_tmp.prepear_counter}, {sdcn_anim_tmp.process_counter}')
|
|
||||||
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
RAFT_clear_memory()
|
RAFT_clear_memory()
|
||||||
|
|
||||||
|
|
@ -249,4 +256,4 @@ def start_process(*args):
|
||||||
warped_styled_frame_ = gr.Image.update()
|
warped_styled_frame_ = gr.Image.update()
|
||||||
processed_frame = gr.Image.update()
|
processed_frame = gr.Image.update()
|
||||||
|
|
||||||
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, '', gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||||
Loading…
Reference in New Issue