added
parent
0cb020b157
commit
4cb70384f2
127
compute_flow.py
127
compute_flow.py
|
|
@ -1,75 +1,92 @@
|
|||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import cv2
|
||||
|
||||
from flow_utils import RAFT_estimate_flow
|
||||
import h5py
|
||||
#RAFT dependencies
|
||||
import sys
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder
|
||||
|
||||
def main(args):
|
||||
W, H = args.width, args.height
|
||||
# Open the input video file
|
||||
input_video = cv2.VideoCapture(args.input_video)
|
||||
RAFT_model = None
|
||||
def RAFT_estimate_flow(frame1, frame2, frame1_bg_removed, frame2_bg_removed, device='cuda', subtract_background=True):
|
||||
global RAFT_model
|
||||
if RAFT_model is None:
|
||||
args = argparse.Namespace(**{
|
||||
'model': 'RAFT/models/raft-things.pth',
|
||||
'mixed_precision': True,
|
||||
'small': False,
|
||||
'alternate_corr': False,
|
||||
'path': ""
|
||||
})
|
||||
|
||||
# Get useful info from the source video
|
||||
fps = int(input_video.get(cv2.CAP_PROP_FPS))
|
||||
total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
RAFT_model = torch.nn.DataParallel(RAFT(args))
|
||||
RAFT_model.load_state_dict(torch.load(args.model))
|
||||
|
||||
prev_frame = None
|
||||
RAFT_model = RAFT_model.module
|
||||
RAFT_model.to(device)
|
||||
RAFT_model.eval()
|
||||
|
||||
# create an empty HDF5 file
|
||||
with h5py.File(args.output_file, 'w') as f: pass
|
||||
with torch.no_grad():
|
||||
if subtract_background:
|
||||
frame1_torch = torch.from_numpy(frame1_bg_removed).permute(2, 0, 1).float()[None].to(device)
|
||||
frame2_torch = torch.from_numpy(frame2_bg_removed).permute(2, 0, 1).float()[None].to(device)
|
||||
else:
|
||||
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
|
||||
frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)
|
||||
|
||||
# open the file for writing a flow maps into it
|
||||
with h5py.File(args.output_file, 'a') as f:
|
||||
flow_maps = f.create_dataset('flow_maps', shape=(0, 2, H, W, 2), maxshape=(None, 2, H, W, 2), dtype=np.float16)
|
||||
|
||||
for ind in tqdm(range(total_frames)):
|
||||
# Read the next frame from the input video
|
||||
if not input_video.isOpened(): break
|
||||
ret, cur_frame = input_video.read()
|
||||
if not ret: break
|
||||
padder = InputPadder(frame1_torch.shape)
|
||||
image1, image2 = padder.pad(frame1_torch, frame2_torch)
|
||||
|
||||
cur_frame = cv2.resize(cur_frame, (W, H))
|
||||
# estimate optical flow
|
||||
_, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
|
||||
_, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)
|
||||
|
||||
if prev_frame is not None:
|
||||
next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, cur_frame, subtract_background=args.remove_background)
|
||||
next_flow = next_flow[0].permute(1,2,0).cpu().numpy()
|
||||
prev_flow = prev_flow[0].permute(1,2,0).cpu().numpy()
|
||||
|
||||
# write data into a file
|
||||
flow_maps.resize(ind, axis=0)
|
||||
flow_maps[ind-1, 0] = next_flow
|
||||
flow_maps[ind-1, 1] = prev_flow
|
||||
fb_flow = next_flow + prev_flow
|
||||
fb_norm = np.linalg.norm(fb_flow, axis=2)
|
||||
|
||||
occlusion_mask = np.clip(occlusion_mask * 0.2 * 255, 0, 255).astype(np.uint8)
|
||||
occlusion_mask = fb_norm[..., None].repeat(3, axis = -1)
|
||||
|
||||
if args.visualize:
|
||||
# show the last written frame - useful to catch any issue with the process
|
||||
if args.remove_background:
|
||||
img_show = cv2.hconcat([cur_frame, frame2_bg_removed, occlusion_mask])
|
||||
else:
|
||||
img_show = cv2.hconcat([cur_frame, occlusion_mask])
|
||||
cv2.imshow('Out img', img_show)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing
|
||||
return next_flow, prev_flow, occlusion_mask
|
||||
|
||||
prev_frame = cur_frame.copy()
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled, sigma=5):
|
||||
h, w = cur_frame.shape[:2]
|
||||
|
||||
# Release the input and output video files
|
||||
input_video.release()
|
||||
next_flow = cv2.resize(next_flow, (w, h))
|
||||
prev_flow = cv2.resize(prev_flow, (w, h))
|
||||
|
||||
# Close all windows
|
||||
if args.visualize: cv2.destroyAllWindows()
|
||||
flow_map = -next_flow.copy()
|
||||
flow_map[:,:,0] += np.arange(w)
|
||||
flow_map[:,:,1] += np.arange(h)[:,np.newaxis]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
|
||||
parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
|
||||
parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
|
||||
parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
|
||||
parser.add_argument('-v', '--visualize', action='store_true', help='Show proceed images and occlusion maps')
|
||||
parser.add_argument('-rb', '--remove_background', action='store_true', help='Remove background of the image')
|
||||
args = parser.parse_args()
|
||||
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST)
|
||||
warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST)
|
||||
|
||||
main(args)
|
||||
# compute occlusion mask
|
||||
fb_flow = next_flow + prev_flow
|
||||
fb_norm = np.linalg.norm(fb_flow, axis=2)
|
||||
|
||||
occlusion_mask = fb_norm[..., None]
|
||||
|
||||
diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
|
||||
diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)
|
||||
|
||||
diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
|
||||
diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
|
||||
|
||||
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4, diff_mask_stl * 2)
|
||||
alpha_mask = alpha_mask.repeat(3, axis = -1)
|
||||
|
||||
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
|
||||
alpha_mask = cv2.GaussianBlur(alpha_mask, (51, 51), sigma, cv2.BORDER_DEFAULT)
|
||||
|
||||
alpha_mask = np.clip(alpha_mask, 0, 1)
|
||||
|
||||
return alpha_mask, warped_frame_styled
|
||||
|
|
|
|||
Loading…
Reference in New Issue