"""Optical-flow utilities built on RAFT.

Estimates forward/backward optical flow between two video frames with a
lazily loaded RAFT model, and derives a blend (alpha) mask that combines
an occlusion estimate with per-pixel frame differences.  Used to warp a
stylized previous frame onto the current frame.
"""

import sys

# RAFT dependencies — the RAFT repository is expected to be checked out
# next to this file.
sys.path.append('RAFT/core')

import argparse

import cv2
import numpy as np
import torch

from raft import RAFT
from utils.utils import InputPadder

# Lazily initialized RAFT model, shared across calls to RAFT_estimate_flow.
RAFT_model = None


def RAFT_estimate_flow(frame1, frame2, frame1_bg_removed, frame2_bg_removed,
                       device='cuda', subtract_background=True):
    """Estimate forward and backward optical flow between two frames.

    Parameters
    ----------
    frame1, frame2 : np.ndarray
        Consecutive HxWx3 uint8 frames.
    frame1_bg_removed, frame2_bg_removed : np.ndarray
        Background-subtracted variants of the same frames; used instead of
        the originals when ``subtract_background`` is True.
    device : str
        Torch device for inference ('cuda' or 'cpu').
    subtract_background : bool
        Select which frame pair is fed to RAFT.

    Returns
    -------
    (next_flow, prev_flow, occlusion_mask)
        ``next_flow``/``prev_flow`` are HxWx2 float arrays (1->2 and 2->1
        flow); ``occlusion_mask`` is an HxWx3 float array holding the
        forward-backward consistency norm replicated over 3 channels.
    """
    global RAFT_model
    if RAFT_model is None:
        # RAFT's demo code reads its config from an argparse Namespace.
        args = argparse.Namespace(
            model='RAFT/models/raft-things.pth',
            mixed_precision=True,
            small=False,
            alternate_corr=False,
            path='',
        )

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        # map_location lets the CUDA-trained checkpoint load on CPU too.
        RAFT_model.load_state_dict(torch.load(args.model, map_location=device))

        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    with torch.no_grad():
        # Pick the frame pair once instead of duplicating the conversion.
        if subtract_background:
            src1, src2 = frame1_bg_removed, frame2_bg_removed
        else:
            src1, src2 = frame1, frame2

        frame1_torch = torch.from_numpy(src1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(src2).permute(2, 0, 1).float()[None].to(device)

        # RAFT requires dimensions divisible by 8; pad both images alike.
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # Forward (1->2) and backward (2->1) flow.
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # Forward-backward consistency: for a correctly tracked pixel the
        # two flows cancel out, so a large residual norm marks occlusion.
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

    return next_flow, prev_flow, occlusion_mask


def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame,
                     prev_frame_styled, sigma=5):
    """Warp the previous (styled) frame onto the current one and build a mask.

    Parameters
    ----------
    next_flow, prev_flow : np.ndarray
        Forward and backward flow fields (any resolution; resized here).
    prev_frame, cur_frame : np.ndarray
        Original HxWx3 uint8 frames.
    prev_frame_styled : np.ndarray
        Stylized version of ``prev_frame`` to be warped forward.
    sigma : float
        Gaussian sigma used to feather the resulting mask.

    Returns
    -------
    (alpha_mask, warped_frame_styled)
        ``alpha_mask`` is an HxWx3 float array in [0, 1];
        ``warped_frame_styled`` is the styled frame warped by the flow.
    """
    h, w = cur_frame.shape[:2]

    # Match the flow fields to the frame resolution.
    next_flow = cv2.resize(next_flow, (w, h))
    prev_flow = cv2.resize(prev_flow, (w, h))

    # Build an absolute sampling map: destination pixel (x, y) samples the
    # previous frame at (x, y) - flow.
    flow_map = -next_flow.copy()
    flow_map[:, :, 0] += np.arange(w)
    flow_map[:, :, 1] += np.arange(h)[:, np.newaxis]
    # cv2.remap requires a float32 map (RAFT output is float32; this guards
    # against float64 inputs from other callers).
    flow_map = flow_map.astype(np.float32)

    warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST)
    warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST)

    # Occlusion estimate from forward-backward flow consistency.
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)
    occlusion_mask = fb_norm[..., None]

    # Per-pixel difference between the warped originals and the current
    # frame, collapsed to the strongest channel.
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis=-1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis=-1, keepdims=True)

    # BUG FIX: np.maximum(a, b, c) treats the third positional argument as
    # the `out` buffer, so `diff_mask_stl * 2` never took part in the
    # maximum and was silently overwritten.  Nest the calls to take the
    # element-wise maximum of all three terms.
    alpha_mask = np.maximum(
        np.maximum(occlusion_mask * 0.3, diff_mask_org * 4),
        diff_mask_stl * 2,
    )
    alpha_mask = alpha_mask.repeat(3, axis=-1)

    # Feather the mask so blended regions transition smoothly.
    alpha_mask = cv2.GaussianBlur(alpha_mask, (51, 51), sigma, cv2.BORDER_DEFAULT)
    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled