Compare commits
99 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
05e265ff3f | |
|
|
f2709f3990 | |
|
|
2e5e09f3c3 | |
|
|
052fdec082 | |
|
|
4f07b5e6d2 | |
|
|
ab1f1583b7 | |
|
|
48d8348de4 | |
|
|
3ccf7373ad | |
|
|
de1ff473f7 | |
|
|
6a04d70241 | |
|
|
2e257bbfc3 | |
|
|
111711fc7b | |
|
|
8bc0954ac8 | |
|
|
c3c0972c0a | |
|
|
89ba89d949 | |
|
|
ba3c17ef7e | |
|
|
b26628a5be | |
|
|
3c36b8e7c5 | |
|
|
c3e4b42d98 | |
|
|
d97ead6ab8 | |
|
|
00fbf01831 | |
|
|
9e08b4c7d3 | |
|
|
8d63dd5471 | |
|
|
e16f728512 | |
|
|
a35f446b69 | |
|
|
9849e6389e | |
|
|
cd400ea7b1 | |
|
|
5d1b65ee48 | |
|
|
9ab15587d8 | |
|
|
e67bcb9264 | |
|
|
f1d47c954a | |
|
|
4ef0097751 | |
|
|
98fe91ceae | |
|
|
33897b7e2f | |
|
|
027faf1612 | |
|
|
63efd3e0e3 | |
|
|
b9d080ff4e | |
|
|
be9f056657 | |
|
|
eddf1a4c8f | |
|
|
ca5fdb1151 | |
|
|
9c5611355d | |
|
|
85cd721e83 | |
|
|
19ac530bb0 | |
|
|
ea0b5e19fc | |
|
|
46ae16e4cb | |
|
|
14534d3174 | |
|
|
c987f645d6 | |
|
|
f33a508d3f | |
|
|
4adcaf026b | |
|
|
a553de32db | |
|
|
31a8dc71d8 | |
|
|
ef5dca6d99 | |
|
|
eff3046d81 | |
|
|
115dbeacb1 | |
|
|
9b88a8f04e | |
|
|
7e55f4d781 | |
|
|
82d0679a51 | |
|
|
bcfd6f994d | |
|
|
50724e8056 | |
|
|
bb4353a264 | |
|
|
c34cfe2976 | |
|
|
dc2be7ba28 | |
|
|
0cb020b157 | |
|
|
3c69efe7a8 | |
|
|
f46b73b6f4 | |
|
|
f073b25440 | |
|
|
10430c5d0f | |
|
|
3e0de3b84d | |
|
|
12dcbef475 | |
|
|
b112d6b2fc | |
|
|
bccf317b03 | |
|
|
0b7fb7c252 | |
|
|
a8c298b7e4 | |
|
|
6a6ca687a5 | |
|
|
05260bce59 | |
|
|
fac580c3b9 | |
|
|
ae605b4299 | |
|
|
621e18ea56 | |
|
|
7e051bbe13 | |
|
|
cb737361d0 | |
|
|
81c425a429 | |
|
|
22bcebb64e | |
|
|
22d909ac59 | |
|
|
5d2f58151d | |
|
|
13a39359ad | |
|
|
0e8039a6d2 | |
|
|
c78c9ae907 | |
|
|
729909e533 | |
|
|
c0e876f4b5 | |
|
|
453bf74188 | |
|
|
e5fd243585 | |
|
|
085f865658 | |
|
|
87834fc192 | |
|
|
3ce461efcc | |
|
|
8287c45e31 | |
|
|
6dda2aaf9d | |
|
|
e90e7e3690 | |
|
|
b65ad41836 | |
|
|
24f77ed2fc |
|
|
@ -1,3 +1,6 @@
|
|||
__pycache__/
|
||||
out/
|
||||
result.mp4
|
||||
videos/
|
||||
FP_Res/
|
||||
result.mp4
|
||||
*.pth
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
import torch
import torch.nn as nn
# BUGFIX: the functional API lives at torch.nn.functional, not torch.functional.
import torch.nn.functional as F
|
||||
|
||||
# Define the model
|
||||
class FloweR(nn.Module):
    """U-Net-style network that looks at a window of past frames and predicts,
    for the next frame: optical flow (2 ch), an occlusion mask (1 ch) and a
    raw RGB estimate (3 ch), then composites a flow-warped previous frame
    with the raw estimate.

    Input : (batch, window_size, H, W, 3) frames — assumed in [-1, 1]; TODO confirm.
    Output: (batch, H, W, 6) with channels
            [flow_x / 255, flow_y / 255, occlusion * 2 - 1, next-frame RGB].
    """

    def __init__(self, input_size=(384, 384), window_size=4):
        super(FloweR, self).__init__()

        self.input_size = input_size
        self.window_size = window_size

        # 2 channels for optical flow
        # 1 channel for occlusion mask
        # 3 channels for next frame prediction
        self.out_channels = 6

        # INPUT: H x W x (3 * window_size)

        ### DOWNSCALE ###
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(3 * self.window_size, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )  # 384 x 384 x 128

        # Each downscale block halves the resolution; layer ordering inside the
        # Sequential (pool, conv, relu) is kept so state_dict keys are unchanged.
        self.conv_block_2 = self._down_block()  # 192 x 192 x 128
        self.conv_block_3 = self._down_block()  # 96 x 96 x 128
        self.conv_block_4 = self._down_block()  # 48 x 48 x 128
        self.conv_block_5 = self._down_block()  # 24 x 24 x 128
        self.conv_block_6 = self._down_block()  # 12 x 12 x 128
        self.conv_block_7 = self._down_block()  # 6 x 6 x 128
        self.conv_block_8 = self._down_block()  # 3 x 3 x 128 - 9 input tokens

        ### Transformer part ###
        # To be done

        ### UPSCALE ###
        self.conv_block_9 = self._up_block()   # 6 x 6 x 128
        self.conv_block_10 = self._up_block()  # 12 x 12 x 128
        self.conv_block_11 = self._up_block()  # 24 x 24 x 128
        self.conv_block_12 = self._up_block()  # 48 x 48 x 128
        self.conv_block_13 = self._up_block()  # 96 x 96 x 128
        self.conv_block_14 = self._up_block()  # 192 x 192 x 128
        self.conv_block_15 = self._up_block()  # 384 x 384 x 128

        self.conv_block_16 = nn.Conv2d(128, self.out_channels, kernel_size=3, stride=1, padding='same')

    @staticmethod
    def _down_block():
        """2x average-pool downscale followed by a 3x3 conv + ReLU."""
        return nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )

    @staticmethod
    def _up_block():
        """2x upsample followed by a 3x3 conv + ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )

    def forward(self, input_frames):
        """Run the predictor on a (batch, frames, H, W, 3) window.

        Returns a (batch, H, W, 6) tensor; see the class docstring for the
        channel layout.

        Raises:
            Exception: if dim 1 of the input does not equal window_size.
        """

        if input_frames.size(1) != self.window_size:
            raise Exception(f'Shape of the input is not compatible. There should be exactly {self.window_size} frames in an input video.')

        h, w = self.input_size
        # batch, frames, height, width, colors -> batch, frames, colors, height, width
        input_frames_permuted = input_frames.permute((0, 1, 4, 2, 3))

        # Stack the window's frames along the channel axis for the 2D convs.
        in_x = input_frames_permuted.reshape(-1, self.window_size * 3, h, w)

        ### DOWNSCALE ###
        block_1_out = self.conv_block_1(in_x)         # 384 x 384 x 128
        block_2_out = self.conv_block_2(block_1_out)  # 192 x 192 x 128
        block_3_out = self.conv_block_3(block_2_out)  # 96 x 96 x 128
        block_4_out = self.conv_block_4(block_3_out)  # 48 x 48 x 128
        block_5_out = self.conv_block_5(block_4_out)  # 24 x 24 x 128
        block_6_out = self.conv_block_6(block_5_out)  # 12 x 12 x 128
        block_7_out = self.conv_block_7(block_6_out)  # 6 x 6 x 128
        block_8_out = self.conv_block_8(block_7_out)  # 3 x 3 x 128

        ### UPSCALE (with additive skip connections) ###
        block_9_out = block_7_out + self.conv_block_9(block_8_out)     # 6 x 6 x 128
        block_10_out = block_6_out + self.conv_block_10(block_9_out)   # 12 x 12 x 128
        block_11_out = block_5_out + self.conv_block_11(block_10_out)  # 24 x 24 x 128
        block_12_out = block_4_out + self.conv_block_12(block_11_out)  # 48 x 48 x 128
        block_13_out = block_3_out + self.conv_block_13(block_12_out)  # 96 x 96 x 128
        block_14_out = block_2_out + self.conv_block_14(block_13_out)  # 192 x 192 x 128
        block_15_out = block_1_out + self.conv_block_15(block_14_out)  # 384 x 384 x 128

        block_16_out = self.conv_block_16(block_15_out)  # 384 x 384 x (2 + 1 + 3)
        out = block_16_out.reshape(-1, self.out_channels, h, w)

        ### for future model training ###
        # BUGFIX: Tensor.get_device() returns -1 on CPU tensors, which breaks
        # the .to(device=...) call below; the .device attribute is always valid.
        device = out.device

        pred_flow = out[:, :2, :, :] * 255       # (-255, 255)
        pred_occl = (out[:, 2:3, :, :] + 1) / 2  # [0, 1]
        pred_next = out[:, 3:6, :, :]

        # Build an absolute pixel-coordinate grid and displace it by the flow.
        grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
        flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
        flow_grid = flow_grid.unsqueeze(0).to(device=device)
        flow_grid = flow_grid + pred_flow

        # Normalize to [-1, 1] as expected by grid_sample. BUGFIX: done
        # out-of-place — the original in-place slice writes modified a tensor
        # that autograd needs for the grid_sample backward pass.
        norm_x = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
        norm_y = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
        flow_grid = torch.stack((norm_x, norm_y), dim=1)

        # batch, flow_channels, height, width -> batch, height, width, flow_channels
        flow_grid = flow_grid.permute(0, 2, 3, 1)

        previous_frame = input_frames_permuted[:, -1, :, :, :]
        sampling_mode = "bilinear" if self.training else "nearest"
        warped_frame = torch.nn.functional.grid_sample(
            previous_frame, flow_grid, mode=sampling_mode,
            padding_mode="reflection", align_corners=False)

        # Blend warped previous frame with the raw prediction; the 0.04 cap
        # keeps the network's own RGB prediction as a small correction only.
        alpha_mask = torch.clip(pred_occl * 10, 0, 1) * 0.04
        pred_next = torch.clip(pred_next, -1, 1)
        warped_frame = torch.clip(warped_frame, -1, 1)
        next_frame = pred_next * alpha_mask + warped_frame * (1 - alpha_mask)

        res = torch.cat((pred_flow / 255, pred_occl * 2 - 1, next_frame), dim=1)

        # batch, channels, height, width -> batch, height, width, channels
        return res.permute((0, 2, 3, 1))
|
||||
6
LICENSE
|
|
@ -1,4 +1,4 @@
|
|||
License
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Alexey Borsky
|
||||
|
||||
|
|
@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
This repository can only be used for personal/research/non-commercial purposes.
|
||||
However, for commercial requests, please contact us directly at
|
||||
borsky.alexey@gmail.com
|
||||
|
|
|
|||
1
RAFT
|
|
@ -1 +0,0 @@
|
|||
Subproject commit aac9dd54726caf2cf81d8661b07663e220c5586d
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2020, princeton-vl
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from RAFT.utils.utils import bilinear_sampler, coords_grid
|
||||
|
||||
try:
|
||||
import alt_cuda_corr
|
||||
except:
|
||||
# alt_cuda_corr is not compiled
|
||||
pass
|
||||
|
||||
|
||||
class CorrBlock:
    """Multi-scale all-pairs correlation volume (RAFT).

    Builds a pyramid of correlation volumes between two feature maps and,
    when called with per-pixel coordinates, samples a local (2r+1)^2 window
    of correlations at every pyramid level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        # Flatten source pixels into the batch axis so pooling acts only on
        # the target (h2, w2) dimensions.
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        # Level 0 is full resolution; each further level halves the target
        # resolution by 2x average pooling.
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation windows around *coords* (batch, 2, h1, w1);
        returns (batch, num_levels*(2r+1)^2, h1, w1)."""
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # (2r+1)x(2r+1) offset window around each lookup coordinate.
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Coordinates shrink by 2 per level to match the pooled volume.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            # bilinear_sampler is a project helper (RAFT.utils.utils).
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense all-pairs dot-product correlation between two (batch, dim,
        ht, wd) feature maps, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class AlternateCorrBlock:
    """Memory-efficient correlation lookup using the optional alt_cuda_corr
    CUDA extension: keeps a feature pyramid and computes local correlations
    on demand instead of materializing the full all-pairs volume.

    NOTE(review): usable only when alt_cuda_corr compiled successfully at
    import time; otherwise __call__ raises NameError.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Pyramid of (fmap1, fmap2) pairs; each level halves the resolution.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Sample local correlations around *coords* (B, 2, H, W); returns
        (B, num_levels*window, H, W) scaled by 1/sqrt(dim)."""
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 stays at full resolution; fmap2 is taken from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            # Lookup coordinates shrink by 2 per pyramid level.
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    The first convolution may downsample via *stride*; in that case the
    shortcut is a strided 1x1 projection followed by its own norm layer.
    Unrecognized *norm_fn* values leave the norm attributes unset (matching
    the original behavior: construction succeeds, forward fails).
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # One fresh normalization layer of the requested flavour.
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(planes)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(planes)
            return nn.Sequential()  # norm_fn == 'none'

        if norm_fn in ('group', 'batch', 'instance', 'none'):
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            if stride != 1:
                # Extra norm for the projected shortcut only.
                self.norm3 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        """Residual forward pass; projects the shortcut when downsampling."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
|
||||
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 (optionally strided) ->
    1x1 expand, with a projected shortcut (plus norm4) when downsampling.

    NOTE(review): an unrecognized *norm_fn* leaves all norm attributes unset,
    so construction succeeds but forward raises AttributeError.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        # Bottleneck works at planes//4 channels, expanding back at conv3.
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # norm1/norm2 act on the reduced planes//4 channels, norm3 on the
        # expanded output; norm4 exists only for the strided shortcut.
        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        """Bottleneck forward pass; projects the shortcut when downsampling."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
||||
|
||||
class BasicEncoder(nn.Module):
    """Feature encoder for the full RAFT model: a strided 7x7 stem followed
    by three stages of ResidualBlocks, reducing resolution by 8 overall."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; a norm_fn outside these four leaves norm1 unset.
        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates self.in_planes between stages.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two ResidualBlocks; the first may downsample, the second never
        does. Updates self.in_planes as a side effect."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode an image batch, or a list/tuple of equally-sized batches
        (concatenated, encoded together, then split back into two)."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
||||
|
||||
class SmallEncoder(nn.Module):
    """Compact feature encoder (RAFT "small" model): a strided 7x7 stem plus
    three stages of BottleneckBlocks, reducing resolution by 8 overall."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; a norm_fn outside these four leaves norm1 unset.
        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates self.in_planes between stages.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Output projection to the requested feature dimension.
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two BottleneckBlocks; the first may downsample, the second never
        does. Updates self.in_planes as a side effect."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode an image batch, or a list/tuple of equally-sized batches
        (concatenated, encoded together, then split back into two)."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from RAFT.update import BasicUpdateBlock, SmallUpdateBlock
|
||||
from RAFT.extractor import BasicEncoder, SmallEncoder
|
||||
from RAFT.corr import CorrBlock, AlternateCorrBlock
|
||||
from RAFT.utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||
|
||||
# Use the real mixed-precision context manager when available; otherwise fall
# back to a no-op stand-in with the same constructor signature.
try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass
|
||||
|
||||
|
||||
class RAFT(nn.Module):
    """RAFT optical-flow estimator: feature + context encoders, a multi-scale
    correlation volume, and an iterative GRU refinement of the flow field.

    NOTE(review): __init__ mutates the passed-in *args* object (sets
    corr_levels / corr_radius and defaults for dropout / alternate_corr),
    so callers share those changes.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # 'in' works for argparse.Namespace (__contains__) and dict-like args.
        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every BatchNorm layer into eval mode (for fine-tuning)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        # Grids live at 1/8 of the input resolution.
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # 9 softmax weights per pixel of each 8x8 upsampled patch.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow magnitudes are scaled by 8 for the resolution change.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        # Map [0, 255] images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # Split context features into GRU hidden state and input features.
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Gradients do not flow through previous iterations' coordinates.
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            # Return (low-res flow, upsampled flow) instead of the full list.
            return coords1 - coords0, flow_up

        return flow_predictions
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a hidden state to a 2-channel
    flow update."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> ReLU -> conv; returns (batch, 2, h, w)."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step: gate previous hidden state *h* with input *x*."""
        combined = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(combined))
        reset_gate = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))

        return (1 - update_gate) * h + update_gate * candidate
|
||||
|
||||
class SepConvGRU(nn.Module):
    """ConvGRU with separable gate convolutions: one GRU update using 1x5
    (horizontal) kernels followed by one using 5x1 (vertical) kernels."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    @staticmethod
    def _step(h, x, convz, convr, convq):
        """Standard GRU update using the supplied gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Two chained GRU updates: horizontal pass, then vertical pass."""
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||
|
||||
class SmallMotionEncoder(nn.Module):
    """Fuses correlation features and the current flow into motion features;
    the raw 2-channel flow is re-appended, giving 80 + 2 output channels."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # Correlation input size: one (2r+1)^2 window per pyramid level.
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """Encode (flow, corr) into (batch, 82, h, w) motion features."""
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([merged, flow], dim=1)
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Motion encoder for the full RAFT model: correlation features
    (convc1/convc2) and flow features (convf1/convf2) are fused to 126
    channels, then the raw 2-channel flow is re-appended for 128 total."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # Correlation input size: one (2r+1)^2 window per pyramid level.
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 128-2 so that concatenating the flow below restores 128 channels.
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Encode (flow, corr) into (batch, 128, h, w) motion features."""
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)
|
||||
|
||||
class SmallUpdateBlock(nn.Module):
    """Small recurrent update block: motion encoder + ConvGRU + flow head.
    Returns None in place of an upsampling mask — the caller upsamples the flow."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        # GRU input: 82 motion-feature channels + 64 context channels
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """One iterative update: returns (new hidden state, None, flow delta)."""
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """Full-size recurrent update block: motion encoder + SepConvGRU + flow
    head, plus a mask head used for learned convex upsampling of the flow."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # predicts 9 combination weights for each of the 8x8=64 upsampled positions
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        """One iterative update: returns (new hidden state, upsampling mask,
        flow delta). `upsample` is unused here; kept for API compatibility."""
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
import numpy as np
|
||||
import random
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowAugmentor:
    """Photometric + spatial augmentation for dense optical-flow training pairs.

    Call with (img1, img2, flow) uint8 images and a (H, W, 2) flow field;
    returns color-jittered, randomly erased, rescaled/flipped/cropped copies.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric: jitter each image independently
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric: jitter both images with the same sampled parameters
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=(50, 100)):
        """ Occlusion augmentation: erase 1-2 random rectangles in img2.

        bounds is the (min, max) rectangle side length in pixels.
        FIX: the default was a mutable list ([50, 100]); a tuple avoids the
        shared-mutable-default pitfall while accepting the same values.
        """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            # fill erased regions with the mean color of img2
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        """Random rescale (with optional x/y stretch), flips, and crop."""
        # randomly sample scale; lower bound keeps the rescaled image at
        # least crop_size + 8 pixels in each dimension
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            # independent x/y stretch
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]  # flow vectors scale with the image

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]  # horizontal component flips sign

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]  # vertical component flips sign

        # random crop to crop_size
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        """Apply the full augmentation pipeline; returns contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        # flips/crops produce views; make contiguous for downstream tensor use
        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
|
||||
|
||||
class SparseFlowAugmentor:
    """Augmentor for datasets with sparse ground-truth flow (a per-pixel
    validity mask is carried through every transform)."""

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        # symmetric photometric jitter only: same parameters for both frames
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        # occlusion augmentation: erase 1-2 random rectangles in img2,
        # filled with img2's mean color
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        # Resize a sparse flow field by scattering the valid flow vectors
        # onto the rescaled pixel grid (nearest-pixel splat; on collision the
        # last write wins). Returns the resized flow and a fresh validity map.
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        # keep only pixels with valid ground truth
        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        # scale both the positions and the flow vectors
        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        # NOTE(review): strict > 0 also discards row/column 0; possibly
        # intended to be >= 0 — confirm against the reference implementation.
        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        # lower bound keeps the rescaled image at least crop_size + 1 pixels
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]  # horizontal component flips sign
                valid = valid[:, ::-1]

        # sample the crop origin with extra margin, then clamp, so crops can
        # hug the image borders
        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        """Apply the full augmentation pipeline; returns contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        # flips/crops produce views; make contiguous for downstream tensor use
        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
||||
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2018 Tom Runia
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# Author: Tom Runia
|
||||
# Date Created: 2018-08-03
|
||||
|
||||
import numpy as np
|
||||
|
||||
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3]
    """
    # each segment: (row count, channel held at 255, ramped channel, ramp down?)
    segments = [
        (15, 0, 1, False),  # RY: red fixed, green ramps up
        (6,  1, 0, True),   # YG: green fixed, red ramps down
        (4,  1, 2, False),  # GC: green fixed, blue ramps up
        (11, 2, 1, True),   # CB: blue fixed, green ramps down
        (13, 2, 0, False),  # BM: blue fixed, red ramps up
        (6,  0, 2, True),   # MR: red fixed, blue ramps down
    ]

    ncols = sum(count for count, _, _, _ in segments)
    colorwheel = np.zeros((ncols, 3))

    row = 0
    for count, fixed_ch, ramp_ch, descending in segments:
        ramp = np.floor(255 * np.arange(count) / count)
        colorwheel[row:row+count, fixed_ch] = 255
        colorwheel[row:row+count, ramp_ch] = 255 - ramp if descending else ramp
        row += count

    return colorwheel
|
||||
|
||||
|
||||
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))  # flow magnitude per pixel
    a = np.arctan2(-v, -u)/np.pi  # flow direction mapped to [-1, 1]
    # fractional index into the wheel: ((a+1)/2) * (ncols-1)
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0  # interpolation weight between adjacent wheel entries
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        # in-range flow: desaturate toward white as magnitude shrinks
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75 # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image
|
||||
|
||||
|
||||
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip flow components to
            [-clip_flow, clip_flow]. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # FIX: clip symmetrically — the previous lower bound of 0 silently
        # zeroed every negative (left/up) flow component.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    # normalize by the maximum magnitude so colors span the full wheel
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from os.path import *
|
||||
import re
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
TAG_CHAR = np.array([202021.25], np.float32)
|
||||
|
||||
def readFlow(fn):
    """Read a Middlebury-format .flo optical flow file.

    Returns a (H, W, 2) float32 array, or None (after printing a warning)
    when the magic-number check fails.
    WARNING: this will work on little-endian architectures (eg Intel x86) only!
    Code adapted from:
    http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    """
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if magic != 202021.25:
            print('Magic number incorrect. Invalid .flo file')
            return None
        w = np.fromfile(f, np.int32, count=1)
        h = np.fromfile(f, np.int32, count=1)
        data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
        # stored row-major with interleaved (u, v) per pixel -> (h, w, 2)
        return np.resize(data, (int(h), int(w), 2))
|
||||
|
||||
def readPFM(file):
    """Read a PFM image file.

    Returns a (H, W, 3) array for color ('PF') or (H, W) for grayscale
    ('Pf'), flipped vertically (PFM stores rows bottom-to-top).

    Raises:
        Exception: if the header or dimension line is malformed.
    """
    # FIX: use a context manager so the file handle is closed even when a
    # parse error raises (the previous version never closed the file).
    with open(file, 'rb') as f:
        header = f.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # the sign of the scale line encodes endianness
        scale = float(f.readline().rstrip())
        if scale < 0: # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>' # big-endian

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data
|
||||
|
||||
def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape
    # FIX: context manager guarantees the handle is closed (and buffers are
    # flushed) even if a write raises; the old open()/close() pair leaked on error.
    with open(filename,'wb') as f:
        # write the header
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        # arrange into matrix form: interleave u and v columns per pixel
        tmp = np.zeros((height, width*nBands))
        tmp[:,np.arange(width)*2] = u
        tmp[:,np.arange(width)*2 + 1] = v
        tmp.astype(np.float32).tofile(f)
|
||||
|
||||
|
||||
def readFlowKITTI(filename):
    # KITTI stores flow as a 16-bit PNG whose channels decode to (u, v, valid)
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)  # BGR -> RGB channel order
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    # decode: stored value = 64 * flow + 2^15 (per the KITTI devkit convention)
    flow = (flow - 2**15) / 64.0
    return flow, valid
|
||||
|
||||
def readDispKITTI(filename):
    # KITTI disparity PNGs store disparity * 256 in 16 bits; 0 marks invalid
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    # represent disparity as a flow field: horizontal component -disp, vertical 0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid
|
||||
|
||||
|
||||
def writeFlowKITTI(filename, uv):
    """Save a flow field in the KITTI 16-bit PNG encoding (value = 64*flow + 2^15),
    with an all-ones validity channel."""
    h, w = uv.shape[0], uv.shape[1]
    encoded = 64.0 * uv + 2**15
    validity = np.ones([h, w, 1])
    png = np.concatenate([encoded, validity], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, png[..., ::-1])
|
||||
|
||||
|
||||
def read_gen(file_name, pil=False):
    """Generic reader dispatching on file extension.

    Images (.png/.jpeg/.ppm/.jpg) come back as PIL Images; .bin/.raw via
    np.load; .flo via readFlow; .pfm via readPFM (last channel dropped for
    3-channel data). Unknown extensions return an empty list.
    """
    ext = splitext(file_name)[-1]
    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)
    if ext in ('.bin', '.raw'):
        return np.load(file_name)
    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    if ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        return flow[:, :, :-1]
    return []
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # amount needed to reach the next multiple of 8 (0 if already aligned)
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8
        if mode == 'sintel':
            # split padding evenly between both sides
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            # pad only the bottom vertically
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every input tensor to the aligned size."""
        return [F.pad(tensor, self._pad, mode='replicate') for tensor in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||
|
||||
def forward_interpolate(flow):
    """Forward-splat a (2, H, W) flow tensor onto the regular grid.

    Each source pixel is moved by its flow vector; the flow values at the
    original grid are then filled by nearest-neighbor interpolation over the
    in-bounds target positions. Returns a (2, H, W) float tensor.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # target positions after applying the flow
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # drop points that land outside the image
    in_bounds = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[in_bounds], y1[in_bounds]
    dx, dy = dx[in_bounds], dy[in_bounds]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates to grid_sample's [-1, 1] range
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if not mask:
        return sampled

    # validity mask: 1.0 where the sample point fell strictly inside the image
    in_bounds = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, in_bounds.float()
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd, device):
    """Build a (batch, 2, ht, wd) grid of (x, y) pixel coordinates."""
    ys, xs = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
    grid = torch.stack([xs, ys], dim=0).float()  # channel 0 = x, channel 1 = y
    return grid[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially and scale its values by 8."""
    ht, wd = flow.shape[2], flow.shape[3]
    upsampled = F.interpolate(flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True)
    return 8 * upsampled
|
||||
|
After Width: | Height: | Size: 902 KiB |
|
After Width: | Height: | Size: 89 KiB |
|
After Width: | Height: | Size: 451 KiB |
|
After Width: | Height: | Size: 1.6 MiB |
|
After Width: | Height: | Size: 3.0 MiB |
|
After Width: | Height: | Size: 3.2 MiB |
|
After Width: | Height: | Size: 3.2 MiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 1.1 MiB |
|
After Width: | Height: | Size: 1.5 MiB |
|
After Width: | Height: | Size: 865 KiB |
|
|
@ -1,87 +0,0 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
|
||||
#RAFT dependencies
|
||||
import sys
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder
|
||||
|
||||
RAFT_model = None
|
||||
def RAFT_estimate_flow(frame1, frame2, device = 'cuda'):
    """Estimate forward and backward optical flow between two frames with RAFT.

    Lazily loads the RAFT 'things' checkpoint into the module-level RAFT_model
    singleton on first call. frame1/frame2 are HxWx3 numpy arrays.
    Returns (next_flow, prev_flow, occlusion_mask) as numpy arrays.
    """
    global RAFT_model
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': 'RAFT/models/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        # unwrap DataParallel so a plain module runs on the chosen device
        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        # pad both frames so dimensions are divisible by 8
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1,2,0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1,2,0).cpu().numpy()

        # forward-backward consistency: large |next + prev| indicates occlusion
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        # replicate the norm into 3 channels for image-like use
        occlusion_mask = fb_norm[..., None].repeat(3, axis = -1)

    return next_flow, prev_flow, occlusion_mask
|
||||
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
    """Warp the previous (raw and styled) frames along the flow and build an
    alpha mask of pixels that need to be regenerated.

    Returns (alpha_mask, warped_frame_styled); alpha_mask is (H, W, 3) in [0, 1].
    """
    h, w = cur_frame.shape[:2]

    next_flow = cv2.resize(next_flow, (w, h))
    prev_flow = cv2.resize(prev_flow, (w, h))

    # build an absolute sampling map from the negated forward flow
    flow_map = -next_flow.copy()
    flow_map[:,:,0] += np.arange(w)
    flow_map[:,:,1] += np.arange(h)[:,np.newaxis]

    warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST)
    warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST)

    # compute occlusion mask via forward-backward flow consistency
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)

    occlusion_mask = fb_norm[..., None]

    # per-pixel difference between the warp prediction and the actual frame
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

    # FIX: np.maximum(a, b, c) treats the third positional argument as the
    # `out` buffer, so diff_mask_stl never participated in the maximum.
    # Take a true element-wise maximum of all three terms instead.
    alpha_mask = np.maximum(np.maximum(occlusion_mask * 0.3, diff_mask_org * 4), diff_mask_stl * 2)
    alpha_mask = alpha_mask.repeat(3, axis = -1)

    #alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
    alpha_mask = cv2.GaussianBlur(alpha_mask, (51,51), 5, cv2.BORDER_DEFAULT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Install/upgrade the packages listed in requirements.txt via the WebUI's
# `launch` helper (this script is executed by the host application at
# extension load time).
import launch
import os
import pkg_resources

# requirements.txt sits next to this script
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")

with open(req_file) as file:
    for package in file:
        try:
            package = package.strip()
            if '==' in package:
                # pinned requirement: reinstall when the installed version differs
                package_name, package_version = package.split('==')
                installed_version = pkg_resources.get_distribution(package_name).version
                if installed_version != package_version:
                    launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: changing {package_name} version from {installed_version} to {package_version}")
            elif not launch.is_installed(package):
                launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: {package}")
        except Exception as e:
            # best-effort: one failed requirement should not abort the rest
            print(e)
            print(f'Warning: Failed to install {package}.')
|
||||
|
|
@ -36,7 +36,7 @@ def main(args):
|
|||
cur_frame = cv2.resize(cur_frame, (W, H))
|
||||
|
||||
if prev_frame is not None:
|
||||
next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, cur_frame)
|
||||
next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, cur_frame, subtract_background=args.remove_background)
|
||||
|
||||
# write data into a file
|
||||
flow_maps.resize(ind, axis=0)
|
||||
|
|
@ -47,7 +47,10 @@ def main(args):
|
|||
|
||||
if args.visualize:
|
||||
# show the last written frame - useful to catch any issue with the process
|
||||
img_show = cv2.hconcat([cur_frame, occlusion_mask])
|
||||
if args.remove_background:
|
||||
img_show = cv2.hconcat([cur_frame, frame2_bg_removed, occlusion_mask])
|
||||
else:
|
||||
img_show = cv2.hconcat([cur_frame, occlusion_mask])
|
||||
cv2.imshow('Out img', img_show)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing
|
||||
|
||||
|
|
@ -60,12 +63,13 @@ def main(args):
|
|||
if args.visualize: cv2.destroyAllWindows()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
|
||||
parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
|
||||
parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
|
||||
parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
|
||||
parser.add_argument('-v', '--visualize', action='store_true', help='Show proceed images and occlusion maps')
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
|
||||
parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
|
||||
parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
|
||||
parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
|
||||
parser.add_argument('-v', '--visualize', action='store_true', help='Show proceed images and occlusion maps')
|
||||
parser.add_argument('-rb', '--remove_background', action='store_true', help='Remove background of the image')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
|
||||
# RAFT dependencies
|
||||
import sys
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder
|
||||
|
||||
RAFT_model = None
|
||||
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
|
||||
|
||||
def background_subtractor(frame, fgbg):
    """Mask out the static background of a frame using the given subtractor."""
    foreground_mask = fgbg.apply(frame)
    return cv2.bitwise_and(frame, frame, mask=foreground_mask)
|
||||
|
||||
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
    """Estimate forward and backward optical flow between two frames with RAFT.

    Lazily loads the RAFT 'things' checkpoint into the module-level RAFT_model
    singleton on first call. When subtract_background is True, both frames are
    first masked by the module-level MOG2 background subtractor (fgbg).

    Returns (next_flow, prev_flow, occlusion_mask, frame1, frame2), where the
    returned frames are the (possibly background-subtracted) inputs.
    """
    global RAFT_model
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': 'RAFT/models/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        # unwrap DataParallel so a plain module runs on the chosen device
        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    if subtract_background:
        frame1 = background_subtractor(frame1, fgbg)
        frame2 = background_subtractor(frame2, fgbg)

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        # pad both frames so dimensions are divisible by 8
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # forward-backward consistency: large |next + prev| indicates occlusion
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        # replicate the norm into 3 channels for image-like use
        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

    return next_flow, prev_flow, occlusion_mask, frame1, frame2
|
||||
|
||||
# ... rest of the file ...
|
||||
|
||||
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
    """Warp the previous (styled) frame along the backward flow and build an
    alpha mask marking regions that need to be regenerated.

    Args:
        next_flow: Forward flow (prev -> cur), possibly at a different
            resolution than the frames, shape (fh, fw, 2).
        prev_flow: Backward flow (cur -> prev), same shape as ``next_flow``.
        prev_frame: Previous original frame, H x W x 3.
        cur_frame: Current original frame, H x W x 3.
        prev_frame_styled: Previous stylized frame, H x W x 3.

    Returns:
        Tuple ``(alpha_mask, warped_frame_styled)``: a blending mask in
        [0, 1] (H x W x 3 float) and the styled previous frame warped onto
        the current frame.
    """
    h, w = cur_frame.shape[:2]

    # Flow maps may come at their own resolution.  NOTE: shape[:2] is
    # (height, width), so despite the names fl_w holds the height and fl_h
    # the width; the np.array([fl_h, fl_w]) below is therefore (width, height).
    fl_w, fl_h = next_flow.shape[:2]

    # Normalize flow to resolution-independent units: x-component by width,
    # y-component by height.
    next_flow = next_flow / np.array([fl_h, fl_w])
    prev_flow = prev_flow / np.array([fl_h, fl_w])

    # remove low value noise (@alexfredo suggestion)
    next_flow[np.abs(next_flow) < 0.05] = 0
    prev_flow[np.abs(prev_flow) < 0.05] = 0

    # Resize flow to the frame resolution and convert back to pixel units.
    # BUGFIX: the scale factor used to be np.array([h, w]), which multiplied
    # the x-component by the height and the y-component by the width —
    # inconsistent with the normalization above.  [w, h] restores symmetry.
    next_flow = cv2.resize(next_flow, (w, h))
    next_flow = (next_flow * np.array([w, h])).astype(np.float32)
    prev_flow = cv2.resize(prev_flow, (w, h))
    prev_flow = (prev_flow * np.array([w, h])).astype(np.float32)

    # Generate sampling grids normalized to [-1, 1] as grid_sample expects.
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
    flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
    flow_grid = flow_grid.unsqueeze(0)
    flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
    flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
    flow_grid = flow_grid.permute(0, 2, 3, 1)

    prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W
    prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W

    # Backward-warp both previous frames onto the current frame.
    warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()
    warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()

    # Occlusion estimate from forward-backward flow consistency.
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)

    occlusion_mask = fb_norm[..., None]

    # Per-pixel max-channel difference between each warped frame and the
    # current frame, scaled to [0, 1].
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis=-1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis=-1, keepdims=True)

    # BUGFIX: np.maximum takes only two input arrays — a third positional
    # argument is the ``out`` buffer, so the styled-frame diff previously
    # never took part in the maximum.  Nest the calls for a true 3-way max.
    alpha_mask = np.maximum(np.maximum(occlusion_mask * 0.3, diff_mask_org * 4), diff_mask_stl * 2)
    alpha_mask = alpha_mask.repeat(3, axis=-1)

    # Soften mask edges before blending.
    alpha_mask = cv2.GaussianBlur(alpha_mask, (51, 51), 5, cv2.BORDER_REFLECT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
|
||||
|
||||
def frames_norm(occl):
    """Scale raw frame pixel values from [0, 255] down to [-1, 1]."""
    scaled = occl / 127.5
    return scaled - 1
|
||||
|
||||
def flow_norm(flow):
    """Scale raw flow values from [0, 255] down to [0, 1]."""
    normalized = flow / 255
    return normalized
|
||||
|
||||
def occl_norm(occl):
    """Scale occlusion-mask pixel values from [0, 255] down to [-1, 1]."""
    scaled = occl / 127.5
    return scaled - 1
|
||||
|
||||
def flow_renorm(flow):
    """Map normalized flow values from [0, 1] back up to [0, 255]."""
    rescaled = flow * 255
    return rescaled
|
||||
|
||||
def occl_renorm(occl):
    """Map a normalized occlusion mask from [-1, 1] back up to [0, 255]."""
    shifted = occl + 1
    return shifted * 127.5
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
# SD-CN-Animation
|
||||
This project allows you to automate video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text at any resolution and length in contrast to other current text2video methods using any Stable Diffusion model as a backbone, including custom ones. It uses '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and create an inpainting mask that is used to generate the next frame. In text to video mode it relies on 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
|
||||
|
||||
|
||||
### Video to Video Examples:
|
||||
<!--
|
||||
[](https://youtu.be/j-0niEMm6DU)
|
||||
This script can also be using to swap the person in the video like in this example: https://youtube.com/shorts/be93_dIeZWU
|
||||
-->
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/girl_org.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_jc.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_wc.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">Original video</td>
|
||||
<td width=33% align="center">"Jessica Chastain"</td>
|
||||
<td width=33% align="center">"Watercolor painting"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropped, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
|
||||
|
||||
### Text to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/flower_1.gif" raw=true></td>
|
||||
<td><img src="examples/bonfire_1.gif" raw=true></td>
|
||||
<td><img src="examples/diamond_4.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of a flower"</td>
|
||||
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
|
||||
<td width=33% align="center">"close up of a diamond laying on the table"</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="examples/macaroni_1.gif" raw=true></td>
|
||||
<td><img src="examples/gold_1.gif" raw=true></td>
|
||||
<td><img src="examples/tree_2.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of macaroni on the plate"</td>
|
||||
<td width=33% align="center">"close up of golden sphere"</td>
|
||||
<td width=33% align="center">"a tree standing in the winter forest"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
All examples you can see here are originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. Actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", only the 'subject' part is described in the table above.
|
||||
|
||||
|
||||
|
||||
## Dependencies
|
||||
To install all the necessary dependencies, run this command:
|
||||
```
|
||||
pip install opencv-python opencv-contrib-python numpy tqdm h5py scikit-image
|
||||
```
|
||||
You have to set up the RAFT repository as it described here: https://github.com/princeton-vl/RAFT . Basically it just comes down to running "./download_models.sh" in RAFT folder to download the models.
|
||||
|
||||
|
||||
## Running the scripts
|
||||
This script works on top of the [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via its API. To run this script you have to set it up first. You should also have the [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed. You need to have the control_hed-fp16 model installed. If you have web-ui with ControlNet working correctly, you also have to allow the API to work with ControlNet. To do so, go to the web-ui settings -> ControlNet tab -> Set "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more than 2 -> press "Apply settings".
|
||||
|
||||
|
||||
### Video To Video
|
||||
#### Step 1.
|
||||
To process the video, first of all you would need to precompute optical flow data before running web-ui with this command:
|
||||
```
|
||||
python3 compute_flow.py -i "path to your video" -o "path to output file with *.h5 format" -v -W width_of_the_flow_map -H height_of_the_flow_map
|
||||
```
|
||||
The main reason to do this step separately is to save precious GPU memory that will be useful to generate better quality images. Choose the W and H parameters as high as your GPU can handle, with respect to the proportions of the original video resolution. Do not worry if they are higher or lower than the processing resolution; flow maps will be scaled accordingly at the processing stage. This will generate quite a large file that may take up to several gigabytes on the drive even for a minute-long video. If you want to process a long video, consider splitting it into several parts beforehand.
|
||||
|
||||
|
||||
#### Step 2.
|
||||
Run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```
|
||||
bash webui.sh --xformers --api
|
||||
```
|
||||
|
||||
|
||||
#### Step 3.
|
||||
Go to the **vid2vid.py** file and change main parameters (INPUT_VIDEO, FLOW_MAPS, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. FLOW_MAPS parameter should contain a path to the flow file that you generated at the first step. The script is pretty simple so you may change other parameters as well, although I would recommend to leave them as is for the first time. Finally run the script with the command:
|
||||
```
|
||||
python3 vid2vid.py
|
||||
```
|
||||
|
||||
|
||||
### Text To Video
|
||||
This method is still in development and works on top of 'Stable Diffusion' and 'FloweR' — an optical flow reconstruction method that is also at an early development stage. Do not expect much from it, as it is more of a proof of concept than a complete solution.
|
||||
|
||||
#### Step 1.
|
||||
Download 'FloweR_0.1.pth' model from here: [Google drive link](https://drive.google.com/file/d/1WhzoVIw6Kdg4EjfK9LaTLqFm5dF-IJ7F/view?usp=share_link) and place it in the 'FloweR' folder.
|
||||
|
||||
#### Step 2.
|
||||
Same as with vid2vid case, run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```
|
||||
bash webui.sh --xformers --api
|
||||
```
|
||||
|
||||
#### Step 3.
|
||||
Go to the **txt2vid.py** file and change main parameters (OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. Again, the script is simple so you may change other parameters if you want to. Finally run the script with the command:
|
||||
```
|
||||
python3 txt2vid.py
|
||||
```
|
||||
|
||||
## Last version changes: v0.5
|
||||
* Fixed an issue with the wrong direction of an optical flow applied to an image.
|
||||
* Added text to video mode within txt2vid.py script. Make sure to update new dependencies for this script to work!
|
||||
* Added a threshold for an optical flow before processing the frame to remove white noise that might appear, as it was suggested by [@alexfredo](https://github.com/alexfredo).
|
||||
* Background removal at flow computation stage implemented by [@CaptnSeraph](https://github.com/CaptnSeraph), it should reduce ghosting effect in most of the videos processed with vid2vid script.
|
||||
|
||||
<!--
|
||||
## Last version changes: v0.6
|
||||
* Added separate flag '-rb' for background removal process at the flow computation stage in the compute_flow.py script.
|
||||
* Added flow normalization before rescaling it, so the magnitude of the flow computed correctly at the different resolution.
|
||||
* Less ghosting and color change in vid2vid mode
|
||||
-->
|
||||
|
||||
<!--
|
||||
## Potential improvements
|
||||
There are several ways overall quality of animation may be improved:
|
||||
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
|
||||
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
|
||||
* The quality of flow estimation might be greatly improved with a proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
|
||||
-->
|
||||
## Licence
|
||||
This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact me directly at borsky.alexey@gmail.com
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
import requests
|
||||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
import sys
|
||||
sys.path.append('FloweR/')
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
import torch
|
||||
from model import FloweR
|
||||
from utils import flow_viz
|
||||
|
||||
from flow_utils import *
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
|
||||
# Output path stamped with the current date/time so consecutive runs do not
# overwrite each other.  NOTE(review): "%H:%M:%S" puts ':' characters into
# the file name, which is invalid on Windows — consider '-' separators.
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'

# Positive and negative prompts sent with every generation request.
PROMPT = "people looking at flying robots. Future. People looking to the sky. Stars in the background. Dramatic light, Cinematic light. Soft lighting, high quality, film grain."
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
w,h = 768, 512 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.

SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations

# Denoising strength of the main (inpainting) pass and of the secondary
# blur-fix pass respectively.
PROCESSING_STRENGTH = 0.85
FIX_STRENGTH = 0.35

# Classifier-free guidance scale passed to the web-ui API.
CFG_SCALE = 5.5

# Flags that attach optional ControlNet units to every request
# (see controlnetRequest.__init__).
APPLY_TEMPORALNET = False
APPLY_COLOR = False

VISUALIZE = True  # show a live preview window while generating
DEVICE = 'cuda'  # torch device used for the FloweR model
|
||||
|
||||
def to_b64(img):
    """Encode an image array as a base64 PNG string for the web-ui API."""
    safe_img = np.clip(img, 0, 255).astype(np.uint8)
    _, png_buffer = cv2.imencode('.png', safe_img)
    return base64.b64encode(png_buffer).decode("utf-8")
|
||||
|
||||
class controlnetRequest():
    """Builder and sender for a single web-ui API request.

    The full request body is assembled in ``__init__`` from the arguments and
    the module-level configuration (PROMPT, N_PROMPT, CFG_SCALE, APPLY_*);
    ``sendRequest`` POSTs it and decodes the first returned image.
    """

    def __init__(self, b64_init_img = None, b64_prev_img = None, b64_color_img = None, ds = 0.35, w=w, h=h, mask = None, seed=-1, mode='img2img'):
        # Build the endpoint from the mode ('img2img' or 'txt2img').
        # NOTE: the w=w / h=h defaults are bound to the module-level globals
        # at class-definition time.
        self.url = f"http://localhost:7860/sdapi/v1/{mode}"
        self.body = {
            "init_images": [b64_init_img],
            "mask": mask,
            "mask_blur": 0,
            "inpainting_fill": 1,
            "inpainting_mask_invert": 0,
            "prompt": PROMPT,
            "negative_prompt": N_PROMPT,
            "seed": seed,
            "subseed": -1,
            "subseed_strength": 0,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 15,
            "cfg_scale": CFG_SCALE,
            "denoising_strength": ds,
            "width": w,
            "height": h,
            "restore_faces": False,
            "eta": 0,
            "sampler_index": "DPM++ 2S a",
            "control_net_enabled": True,
            "alwayson_scripts": {
                "ControlNet":{"args": []}
            },
        }

        # Optionally attach a TemporalNet ControlNet unit fed with the
        # previous frame.
        if APPLY_TEMPORALNET:
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_prev_img,
                "module": "none",
                "model": "diff_control_sd15_temporalnet_fp16 [adc6bd97]",
                "weight": 0.65,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.65,
                "guessmode": False
            })

        # Optionally attach a color-adapter unit.  NOTE(review): this also
        # uses b64_prev_img, not b64_color_img — confirm that is intended.
        if APPLY_COLOR:
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_prev_img,
                "module": "color",
                "model": "t2iadapter_color_sd14v1 [8522029d]",
                "weight": 0.65,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.65,
                "guessmode": False
            })


    def sendRequest(self):
        """POST the prepared body and return the first generated image
        decoded as a numpy array (cv2 BGR layout)."""
        # Request to web-ui
        data_js = requests.post(self.url, json=self.body).json()

        # Convert the byte array to a NumPy array
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)

        # Convert the NumPy array to a cv2 image
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        return out_image
|
||||
|
||||
|
||||
|
||||
# ---- txt2vid main loop ------------------------------------------------------

if VISUALIZE: cv2.namedWindow('Out img')


# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), 15, (w, h))

prev_frame = None
prev_frame_styled = None


# Instantiate the FloweR flow-prediction model and load its checkpoint.
# NOTE(review): torch.load without map_location assumes the checkpoint's
# original device is available — confirm for CPU-only runs.
model = FloweR(input_size = (h, w))
model.load_state_dict(torch.load('FloweR/FloweR_0.1.1.pth'))
# Move the model to the device
model = model.to(DEVICE)


# The first frame comes straight from txt2img; every later frame is produced
# by warping and re-processing the previous one.
init_frame = controlnetRequest(mode='txt2img', ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()

output_video.write(init_frame)
prev_frame = init_frame

# Rolling window of the 4 most recent frames fed to FloweR.
clip_frames = np.zeros((4, h, w, 3), dtype=np.uint8)

# NOTE(review): color_shift / color_scale are never used below — leftovers?
color_shift = np.zeros((0, 3))
color_scale = np.zeros((0, 3))
for ind in tqdm(range(450)):
    # Shift the clip window left and append the newest frame.
    clip_frames = np.roll(clip_frames, -1, axis=0)
    clip_frames[-1] = prev_frame

    clip_frames_torch = frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))

    with torch.no_grad():
        pred_data = model(clip_frames_torch.unsqueeze(0))[0]

    # Channels 0-1 of the prediction are the flow, channel 2 the occlusion.
    pred_flow = flow_renorm(pred_data[...,:2]).cpu().numpy()
    pred_occl = occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)

    # Damp large flow magnitudes and smooth the field.
    pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
    pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)

    # Smooth, gamma-shape and amplify the occlusion mask, then clamp to uint8.
    pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
    pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
    pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)

    # Convert the relative flow into an absolute sampling map for cv2.remap.
    flow_map = pred_flow.copy()
    flow_map[:,:,0] += np.arange(w)
    flow_map[:,:,1] += np.arange(h)[:,np.newaxis]

    warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)

    out_image = warped_frame.copy()

    # Pass 1: strong inpainting pass restricted to the occlusion mask.
    out_image = controlnetRequest(
        b64_init_img = to_b64(out_image),
        b64_prev_img = to_b64(prev_frame),
        b64_color_img = to_b64(warped_frame),
        mask = to_b64(pred_occl),
        ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()

    # Pass 2: light full-frame pass to fix blur accumulated by warping.
    out_image = controlnetRequest(
        b64_init_img = to_b64(out_image),
        b64_prev_img = to_b64(prev_frame),
        b64_color_img = to_b64(warped_frame),
        mask = None,
        ds=FIX_STRENGTH, w=w, h=h).sendRequest()

    # This step is necessary to reduce color drift of the image that some models may cause
    out_image = skimage.exposure.match_histograms(out_image, init_frame, multichannel=True, channel_axis=-1)

    output_video.write(out_image)
    if SAVE_FRAMES:
        if not os.path.isdir('out'): os.makedirs('out')
        cv2.imwrite(f'out/{ind+1:05d}.png', out_image)

    # Preview: predicted flow, occlusion, warped and final frames.
    # NOTE(review): this block runs even when VISUALIZE is False — only the
    # window creation at the top is guarded.
    pred_flow_img = flow_viz.flow_to_image(pred_flow)
    frames_img = cv2.hconcat(list(clip_frames))
    data_img = cv2.hconcat([pred_flow_img, pred_occl, warped_frame, out_image])

    cv2.imshow('Out img', cv2.vconcat([frames_img, data_img]))
    if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing

    prev_frame = out_image.copy()

# Release the input and output video files
output_video.release()

# Close all windows
if VISUALIZE: cv2.destroyAllWindows()
|
||||
|
|
@ -8,22 +8,25 @@ import os
|
|||
import h5py
|
||||
from flow_utils import compute_diff_map
|
||||
|
||||
INPUT_VIDEO = "input.mp4"
|
||||
FLOW_MAPS = "flow.h5"
|
||||
OUTPUT_VIDEO = "result.mp4"
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
PROMPT = "marble statue"
|
||||
INPUT_VIDEO = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.mp4"
|
||||
FLOW_MAPS = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.h5"
|
||||
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'
|
||||
|
||||
PROMPT = "Underwater shot Peter Gabriel with closed eyes in Peter Gabriel's music video. 80's music video. VHS style. Dramatic light, Cinematic light. RAW photo, 8k uhd, dslr, soft lighting, high quality, film grain."
|
||||
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
||||
w,h = 1152, 640 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.
|
||||
w,h = 1088, 576 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.
|
||||
|
||||
START_FROM_IND = 0 # index of a frame to start a processing from. Might be helpful with long animations where you need to restart the script multiple times
|
||||
SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations
|
||||
|
||||
PROCESSING_STRENGTH = 0.85
|
||||
PROCESSING_STRENGTH = 0.95
|
||||
BLUR_FIX_STRENGTH = 0.15
|
||||
|
||||
APPLY_HED = False
|
||||
APPLY_CANNY = True
|
||||
APPLY_HED = True
|
||||
APPLY_CANNY = False
|
||||
APPLY_DEPTH = False
|
||||
GUESSMODE = False
|
||||
|
||||
|
|
@ -72,12 +75,12 @@ class controlnetRequest():
|
|||
"input_image": b64_hed_img,
|
||||
"module": "hed",
|
||||
"model": "control_hed-fp16 [13fee50b]",
|
||||
"weight": 0.85,
|
||||
"weight": 0.65,
|
||||
"resize_mode": "Just Resize",
|
||||
"lowvram": False,
|
||||
"processor_res": 512,
|
||||
"guidance_start": 0,
|
||||
"guidance_end": 0.85,
|
||||
"guidance_end": 0.65,
|
||||
"guessmode": GUESSMODE
|
||||
})
|
||||
|
||||
|
|
@ -136,10 +139,11 @@ fps = int(input_video.get(cv2.CAP_PROP_FPS))
|
|||
total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
# Create an output video file with the same fps, width, and height as the input video
|
||||
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'MP4V'), fps, (w, h))
|
||||
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
|
||||
prev_frame = None
|
||||
prev_frame_styled = None
|
||||
#init_image = None
|
||||
|
||||
# reading flow maps in a stream manner
|
||||
with h5py.File(FLOW_MAPS, 'r') as f:
|
||||
|
|
@ -172,6 +176,8 @@ with h5py.File(FLOW_MAPS, 'r') as f:
|
|||
|
||||
alpha_img = out_image.copy()
|
||||
out_image_ = out_image.copy()
|
||||
warped_styled = out_image.copy()
|
||||
#init_image = out_image.copy()
|
||||
else:
|
||||
# Resize the frame to proper resolution
|
||||
frame = cv2.resize(cur_frame, (w, h))
|
||||
|
|
@ -188,13 +194,18 @@ with h5py.File(FLOW_MAPS, 'r') as f:
|
|||
alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
|
||||
alpha_img = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
# normalizing the colors
|
||||
out_image = skimage.exposure.match_histograms(out_image, frame, multichannel=False, channel_axis=-1)
|
||||
|
||||
out_image = out_image.astype(float) * alpha_mask + warped_styled.astype(float) * (1 - alpha_mask)
|
||||
|
||||
out_image_ = (out_image * 0.65 + warped_styled * 0.35)
|
||||
#out_image = skimage.exposure.match_histograms(out_image, prev_frame, multichannel=True, channel_axis=-1)
|
||||
#out_image_ = (out_image * 0.65 + warped_styled * 0.35)
|
||||
|
||||
|
||||
# Bluring issue fix via additional processing
|
||||
out_image_fixed = controlnetRequest(to_b64(out_image_), to_b64(frame), BLUR_FIX_STRENGTH, w, h, mask = None, seed=8888).sendRequest()
|
||||
out_image_fixed = controlnetRequest(to_b64(out_image), to_b64(frame), BLUR_FIX_STRENGTH, w, h, mask = None, seed=8888).sendRequest()
|
||||
|
||||
|
||||
# Write the frame to the output video
|
||||
frame_out = np.clip(out_image_fixed, 0, 255).astype(np.uint8)
|
||||
|
|
@ -202,8 +213,11 @@ with h5py.File(FLOW_MAPS, 'r') as f:
|
|||
|
||||
if VISUALIZE:
|
||||
# show the last written frame - useful to catch any issue with the process
|
||||
img_show = cv2.hconcat([frame_out, alpha_img])
|
||||
cv2.imshow('Out img', img_show)
|
||||
warped_styled = np.clip(warped_styled, 0, 255).astype(np.uint8)
|
||||
|
||||
img_show_top = cv2.hconcat([frame, warped_styled])
|
||||
img_show_bot = cv2.hconcat([frame_out, alpha_img])
|
||||
cv2.imshow('Out img', cv2.vconcat([img_show_top, img_show_bot]))
|
||||
cv2.setWindowTitle("Out img", str(ind+1))
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing
|
||||
|
||||
125
readme.md
|
|
@ -1,49 +1,92 @@
|
|||
# SD-CN-Animation
|
||||
This script allows you to automate video stylization task using StableDiffusion and ControlNet. It uses a simple optical flow estimation algorithm to keep the animation stable and create an inpainting mask that is used to generate the next frame. Here is an example of a video made with this script:
|
||||
> [!Warning]
|
||||
> This repository is no longer maintained. If you are looking for more modern ComfyUI version of this tool please check out this repository, this might be what you are actually looking for: https://github.com/pxl-pshr/ComfyUI-SD-CN-Animation
|
||||
|
||||
[](https://youtu.be/j-0niEMm6DU)
|
||||
This project allows you to automate video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text at any resolution and length in contrast to other current text2video methods using any Stable Diffusion model as a backbone, including custom ones. It uses '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and create an occlusion mask that is used to generate the next frame. In text to video mode it relies on 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
|
||||
|
||||
This script can also be using to swap the person in the video like in this example: https://youtube.com/shorts/be93_dIeZWU
|
||||

|
||||
sd-cn-animation ui preview
|
||||
|
||||
## Dependencies
|
||||
To install all the necessary dependencies, run this command:
|
||||
**In vid2vid mode do not forget to activate ControlNet model to achieve better results. Without it the resulting video might be quite choppy. Do not put any images in CN as the frames would pass automatically from the video.**
|
||||
Here are CN parameters that seem to give the best results so far:
|
||||

|
||||
|
||||
|
||||
### Video to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/girl_org.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_jc.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_wc.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">Original video</td>
|
||||
<td width=33% align="center">"Jessica Chastain"</td>
|
||||
<td width=33% align="center">"Watercolor painting"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropt, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
|
||||
|
||||
### Text to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/flower_1.gif" raw=true></td>
|
||||
<td><img src="examples/bonfire_1.gif" raw=true></td>
|
||||
<td><img src="examples/diamond_4.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of a flower"</td>
|
||||
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
|
||||
<td width=33% align="center">"close up of a diamond laying on the table"</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="examples/macaroni_1.gif" raw=true></td>
|
||||
<td><img src="examples/gold_1.gif" raw=true></td>
|
||||
<td><img src="examples/tree_2.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of macaroni on the plate"</td>
|
||||
<td width=33% align="center">"close up of golden sphere"</td>
|
||||
<td width=33% align="center">"a tree standing in the winter forest"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
All examples you can see here are originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. Actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", only the 'subject' part is described in the table above.
|
||||
|
||||
## Installing the extension
|
||||
To install the extension go to 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to 'Install from URL' tab. In 'URL for extension's git repository' field inter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave 'Local directory name' field empty. Then just press 'Install' button. Restart web-ui, new 'SD-CN-Animation' tab should appear. All generated video will be saved into 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
|
||||
|
||||
## Known issues
|
||||
* If you see an error like ```IndexError: list index out of range``` try to restart webui; it should fix it. If the issue is still prevalent, try to uninstall scikit-image and reinstall scikit-image==0.19.2 with the --no-cache-dir flag like this.
|
||||
```
|
||||
pip install opencv-python opencv-contrib-python numpy tqdm h5py
|
||||
pip uninstall scikit-image
|
||||
pip install scikit-image==0.19.2 --no-cache-dir
|
||||
```
|
||||
You have to set up the RAFT repository as it described here: https://github.com/princeton-vl/RAFT . Basically it just comes down to running "./download_models.sh" in RAFT folder to download the models.
|
||||
* The extension might work incorrectly if 'Apply color correction to img2img results to match original colors.' option is enabled. Make sure to disable it in 'Settings' tab -> 'Stable Diffusion' section.
|
||||
* If you have an error like 'Need to enable queue to use generators.', please update webui to the latest version. Beware that only [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is fully supported.
|
||||
* The extension is not compatible with Macs. If you have a case that extension is working for you or do you know how to make it compatible, please open a new discussion.
|
||||
|
||||
## Running the script
|
||||
This script works on top of [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via API. To run this script you have to set it up first. You should also have[sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed. You need to have control_hed-fp16 model installed. If you have web-ui with ControlNet working correctly, you have to also allow API to work with controlNet. To do so, go to the web-ui settings -> ControlNet tab -> Set "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more then 2 -> press "Apply settings".
|
||||
## Last version changes: v0.9
|
||||
* Fixed issues #69, #76, #91, #92.
|
||||
* Fixed an issue in vid2vid mode when an occlusion mask computed from the optical flow may include unnecessary parts (where flow is non-zero).
|
||||
* Added 'Extra params' in vid2vid mode for more fine-grain controls of the processing pipeline.
|
||||
* Better default parameters set for vid2vid pipeline.
|
||||
* In txt2vid mode after the first frame is generated the seed is now automatically set to -1 to prevent blurring issues.
|
||||
* Added an option to save resulting frames into a folder alongside the video.
|
||||
* Added ability to export current parameters in a human readable form as a json.
|
||||
* Interpolation mode in the flow-applying stage is set to ‘nearest’ to reduce overtime image blurring.
|
||||
* Added ControlNet to txt2vid mode as well as fixing #86 issue, thanks to [@mariaWitch](https://github.com/mariaWitch)
|
||||
* Fixed a major issue when ControlNet used the wrong input images. Because of this, vid2vid results were way worse than they should have been.
|
||||
* Text to video mode now supports video as a guidance for ControlNet. It allows to create much stronger video stylizations.
|
||||
|
||||
### Step 1.
|
||||
To process the video, first of all you would need to precompute optical flow data before running web-ui with this command:
|
||||
```
|
||||
python3 compute_flow.py -i {path to your video} -o {path to output *.h5} -v -W {width of the flow map} -H {height of the flow map}
|
||||
```
|
||||
The main reason to do this step separately is to save precious GPU memory that will be useful to generate better quality images. Choose the W and H parameters as high as your GPU can handle, in proportion to the original video resolution. Do not worry if it is higher or lower than the processing resolution; flow maps will be scaled accordingly at the processing stage. This will generate quite a large file that may take up to several gigabytes on the drive even for a minute-long video. If you want to process a long video, consider splitting it into several parts beforehand.
|
||||
|
||||
### Step 2.
|
||||
Run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```
|
||||
bash webui.sh --xformers --api
|
||||
```
|
||||
|
||||
### Step 3.
|
||||
Go to the **vid2vid.py** file and change main parameters (INPUT_VIDEO, FLOW_MAPS, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. FLOW_MAPS parameter should contain a path to the flow file that you generated at the first step. The script is pretty simple so you may change other parameters as well, although I would recommend to leave them as is for the first time. Finally run the script with the command:
|
||||
```
|
||||
python3 vid2vid.py
|
||||
```
|
||||
|
||||
## Last version changes: v0.4
|
||||
* Fixed issue with extreme blur accumulating at the static parts of the video.
|
||||
* The order of processing was changed to achieve the best quality at different domains.
|
||||
* Optical flow computation isolated into a separate script for better GPU memory management. Check out the instruction for a new processing pipeline.
|
||||
|
||||
## Potential improvements
|
||||
There are several ways overall quality of animation may be improved:
|
||||
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
|
||||
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
|
||||
* The quality of flow estimation might be greatly improved with proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
|
||||
|
||||
## Licence
|
||||
This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact me directly at borsky.alexey@gmail.com
|
||||
<!--
|
||||
* ControlNet with preprocessers like "reference_only", "reference_adain", "reference_adain+attn" are not reseted with video frames to have an ability to control style of the video.
|
||||
* Fixed an issue because of which the 'processing_strength' UI parameter did not actually affect the denoising strength at the first processing step.
|
||||
* Fixed issue #112. It will not try to reinstall requirements at every start of webui.
|
||||
* Some improvements in text 2 video method.
|
||||
* Parameters used to generated a video now automatically saved in video's folder.
|
||||
* Added ability to control what frame will be send to CN in text to video mode.
|
||||
-->
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
scikit-image
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
import sys, os
|
||||
|
||||
import gradio as gr
|
||||
import modules
|
||||
from types import SimpleNamespace
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
|
||||
from modules.ui_components import ToolButton, FormRow, FormGroup
|
||||
from modules.ui import create_override_settings_dropdown
|
||||
import modules.scripts as scripts
|
||||
|
||||
from modules.sd_samplers import samplers_for_img2img
|
||||
|
||||
from scripts.core import vid2vid, txt2vid, utils
|
||||
import traceback
|
||||
|
||||
def V2VArgs():
    """Default settings for the vid2vid pipeline.

    Returns a dict of parameter-name -> default value, in the same key
    order the UI consumes them.
    """
    return {
        'seed': -1,
        'width': 1024,
        'height': 576,
        'cfg_scale': 5.5,
        'steps': 15,
        'prompt': "",
        'n_prompt': "text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
        'processing_strength': 0.85,
        'fix_frame_strength': 0.15,
    }
|
||||
|
||||
def T2VArgs():
    """Default settings for the txt2vid pipeline.

    Returns a dict of parameter-name -> default value, in the same key
    order the UI consumes them.
    """
    return {
        'seed': -1,
        'width': 768,
        'height': 512,
        'cfg_scale': 5.5,
        'steps': 15,
        'prompt': "",
        'n_prompt': "((blur, blurr, blurred, blurry, fuzzy, unclear, unfocus, bocca effect)), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
        'processing_strength': 0.75,
        'fix_frame_strength': 0.35,
    }
|
||||
|
||||
def setup_common_values(mode, d):
    """Create the input widgets shared by the vid2vid and txt2vid tabs.

    mode: tab name used to build unique elem_ids ('vid2vid' or 'txt2vid').
    d:    SimpleNamespace of defaults (from V2VArgs()/T2VArgs()).

    Returns the created components in a fixed order matching the tuple
    unpacking done by inputs_ui().
    """
    with gr.Row():
        width = gr.Slider(label='Width', minimum=64, maximum=2048, step=64, value=d.width, interactive=True)
        height = gr.Slider(label='Height', minimum=64, maximum=2048, step=64, value=d.height, interactive=True)
    with gr.Row(elem_id=f'{mode}_prompt_toprow'):
        prompt = gr.Textbox(label='Prompt', lines=3, interactive=True, elem_id=f"{mode}_prompt", placeholder="Enter your prompt here...")
    with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
        n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
    with gr.Row():
        cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
    with gr.Row():
        # BUGFIX: this was 'Interactive = True' (capital I), an unknown
        # keyword that gradio does not interpret as 'interactive'.
        seed = gr.Number(label='Seed (this parameter controls how the first frame looks like and the color distribution of the consecutive frames as they are dependent on the first one)', value=d.seed, interactive=True, precision=0)
    with gr.Row():
        processing_strength = gr.Slider(label="Processing strength (Step 1)", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
        fix_frame_strength = gr.Slider(label="Fix frame strength (Step 2)", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
    with gr.Row():
        sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{mode}_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index", interactive=True)
        steps = gr.Slider(label="Sampling steps", minimum=1, maximum=150, step=1, elem_id=f"{mode}_steps", value=d.steps, interactive=True)

    return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength, sampler_index, steps
|
||||
|
||||
def inputs_ui():
    """Build the vid2vid / txt2vid input tabs.

    Returns every created component (plus the mode state) as a dict via
    locals(); on_ui_tabs() reads components out of it by name.
    """
    v2v_args = SimpleNamespace(**V2VArgs())
    t2v_args = SimpleNamespace(**T2VArgs())
    with gr.Tabs():
        # Tracks which tab is active; process() dispatches on this value.
        glo_sdcn_process_mode = gr.State(value='vid2vid')

        with gr.Tab('vid2vid') as tab_vid2vid:
            with gr.Row():
                gr.HTML('Input video (each frame will be used as initial image for SD and as input image to CN): *REQUIRED')
            with gr.Row():
                v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")

            v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength, v2v_sampler_index, v2v_steps = setup_common_values('vid2vid', v2v_args)

            with gr.Accordion("Extra settings", open=False):
                gr.HTML('# Occlusion mask params:')
                with gr.Row():
                    with gr.Column(scale=1, variant='compact'):
                        v2v_occlusion_mask_blur = gr.Slider(label='Occlusion blur strength', minimum=0, maximum=10, step=0.1, value=3, interactive=True)
                        gr.HTML('')
                        v2v_occlusion_mask_trailing = gr.Checkbox(label="Occlusion trailing", info="Reduce ghosting but adds more flickering to the video", value=True, interactive=True)
                    with gr.Column(scale=1, variant='compact'):
                        v2v_occlusion_mask_flow_multiplier = gr.Slider(label='Occlusion flow multiplier', minimum=0, maximum=10, step=0.1, value=5, interactive=True)
                        v2v_occlusion_mask_difo_multiplier = gr.Slider(label='Occlusion diff origin multiplier', minimum=0, maximum=10, step=0.1, value=2, interactive=True)
                        v2v_occlusion_mask_difs_multiplier = gr.Slider(label='Occlusion diff styled multiplier', minimum=0, maximum=10, step=0.1, value=0, interactive=True)

                with gr.Row():
                    with gr.Column(scale=1, variant='compact'):
                        gr.HTML('# Step 1 params:')
                        # BUGFIX: 'Interactive = True' (capital I) was an unknown
                        # kwarg; the intended keyword is 'interactive'.
                        v2v_step_1_seed = gr.Number(label='Seed', value=-1, interactive=True, precision=0)
                        gr.HTML('<br>')
                        v2v_step_1_blend_alpha = gr.Slider(label='Warped prev frame vs Current frame blend alpha', minimum=0, maximum=1, step=0.1, value=1, interactive=True)
                        v2v_step_1_processing_mode = gr.Radio(["Process full image then blend in occlusions", "Inpaint occlusions"], type="index", \
                            label="Processing mode", value="Process full image then blend in occlusions", interactive=True)

                    with gr.Column(scale=1, variant='compact'):
                        gr.HTML('# Step 2 params:')
                        # BUGFIX: same 'Interactive' -> 'interactive' typo.
                        v2v_step_2_seed = gr.Number(label='Seed', value=8888, interactive=True, precision=0)

            with FormRow(elem_id="vid2vid_override_settings_row") as row:
                v2v_override_settings = create_override_settings_dropdown("vid2vid", row)

            with FormGroup(elem_id="script_container"):
                v2v_custom_inputs = scripts.scripts_img2img.setup_ui()

        with gr.Tab('txt2vid') as tab_txt2vid:
            with gr.Row():
                gr.HTML('Control video (each frame will be used as input image to CN): *NOT REQUIRED')
            with gr.Row():
                t2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="tex_to_vid_chosen_file")
                t2v_init_image = gr.Image(label="Input image", interactive=True, file_count="single", file_types=["image"], elem_id="tex_to_vid_init_image")

            t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)

            with gr.Row():
                t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
                t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)

            gr.HTML('<br>')
            t2v_cn_frame_send = gr.Radio(["None", "Current generated frame", "Previous generated frame", "Current reference video frame"], type="index", \
                label="What frame should be send to CN?", value="None", interactive=True)

            with FormRow(elem_id="txt2vid_override_settings_row") as row:
                t2v_override_settings = create_override_settings_dropdown("txt2vid", row)

            with FormGroup(elem_id="script_container"):
                t2v_custom_inputs = scripts.scripts_txt2img.setup_ui()

        # Keep the mode state in sync with the selected tab.
        tab_vid2vid.select(fn=lambda: 'vid2vid', inputs=[], outputs=[glo_sdcn_process_mode])
        tab_txt2vid.select(fn=lambda: 'txt2vid', inputs=[], outputs=[glo_sdcn_process_mode])

    return locals()
|
||||
|
||||
def process(*args):
    """Dispatch a generation run to vid2vid or txt2vid based on args[0].

    Acts as a generator: intermediate progress tuples are yielded straight
    through from the selected pipeline. Any exception is caught, printed,
    and reported in the final status message. The last yield always restores
    the Generate/Interrupt button states.
    """
    status = 'Done'
    mode = args[0]
    try:
        if mode == 'vid2vid':
            yield from vid2vid.start_process(*args)
        elif mode == 'txt2vid':
            yield from txt2vid.start_process(*args)
        else:
            status = f"Unsupported processing mode: '{mode}'"
            raise Exception(status)
    except Exception as error:
        # Report the failure in the UI instead of crashing the queue.
        status = f"An exception occurred while trying to process the frame: {error}"
        print(status)
        traceback.print_exc()

    yield status, gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Video.update(), gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||
|
||||
def stop_process(*args):
    """Ask the running generation loop to stop at the next frame boundary,
    and grey out the Interrupt button while it winds down."""
    utils.shared.is_interrupted = True
    return gr.Button.update(interactive=False)
|
||||
|
||||
|
||||
|
||||
def on_ui_tabs():
    """Build the SD-CN-Animation tab and wire its buttons.

    Returns the (interface, title, elem_id) triple expected by
    script_callbacks.on_ui_tabs.
    """
    # The script runner must be pointed at img2img so custom scripts
    # (e.g. ControlNet) can attach their UI inside our tabs.
    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)

    with gr.Blocks(analytics_enabled=False) as sdcnanim_interface:
        components = {}

        #dv = SimpleNamespace(**T2VOutputArgs())
        with gr.Row(elem_id='sdcn-core', equal_height=False, variant='compact'):
            # Left column: all input widgets, collected by name.
            with gr.Column(scale=1, variant='panel'):
                #with gr.Tabs():
                components = inputs_ui()

                with gr.Accordion("Export settings", open=False):
                    export_settings_button = gr.Button('Export', elem_id=f"sdcn_export_settings_button")
                    export_setting_json = gr.Code(value='')

            # Right column: run controls, live previews, resulting video.
            with gr.Column(scale=1, variant='compact'):
                with gr.Row(variant='compact'):
                    run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
                    stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False)

                save_frames_check = gr.Checkbox(label="Save frames into a folder nearby a video (check it before running the generation if you also want to save frames separately)", value=True, interactive=True)
                gr.HTML('<br>')

                with gr.Column(variant="panel"):
                    sp_progress = gr.HTML(elem_id="sp_progress", value="")

                    with gr.Row(variant='compact'):
                        img_preview_curr_frame = gr.Image(label='Current frame', elem_id=f"img_preview_curr_frame", type='pil', height=240)
                        img_preview_curr_occl = gr.Image(label='Current occlusion', elem_id=f"img_preview_curr_occl", type='pil', height=240)
                    with gr.Row(variant='compact'):
                        # NOTE(review): elem_id duplicates the 'Current frame'
                        # preview above — looks like a copy-paste slip; confirm.
                        img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil', height=240)
                        img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil', height=240)

                    video_preview = gr.Video(interactive=False)

                with gr.Row(variant='compact'):
                    dummy_component = gr.Label(visible=False)

        components['glo_save_frames_check'] = save_frames_check

        # Define parameters for the action methods.
        # Record how many custom-script inputs each mode contributed so the
        # processing code can slice them back out of the flat args tuple.
        utils.shared.v2v_custom_inputs_size = len(components['v2v_custom_inputs'])
        utils.shared.t2v_custom_inputs_size = len(components['t2v_custom_inputs'])
        #print('v2v_custom_inputs', len(components['v2v_custom_inputs']), components['v2v_custom_inputs'])
        #print('t2v_custom_inputs', len(components['t2v_custom_inputs']), components['t2v_custom_inputs'])
        # Flat input list: named components first, then both scripts' inputs.
        method_inputs = [components[name] for name in utils.get_component_names()] + components['v2v_custom_inputs'] + components['t2v_custom_inputs']

        # Order must match the tuples yielded by process().
        method_outputs = [
            sp_progress,
            img_preview_curr_frame,
            img_preview_curr_occl,
            img_preview_prev_warp,
            img_preview_processed,
            video_preview,
            run_button,
            stop_button,
        ]

        run_button.click(
            fn=process, #wrap_gradio_gpu_call(start_process, extra_outputs=[None, '', '']),
            inputs=method_inputs,
            outputs=method_outputs,
            show_progress=True,
        )

        stop_button.click(
            fn=stop_process,
            outputs=[stop_button],
            show_progress=False
        )

        export_settings_button.click(
            fn=utils.export_settings,
            inputs=method_inputs,
            outputs=[export_setting_json],
            show_progress=False
        )

    modules.scripts.scripts_current = None

    # define queue - required for generators
    sdcnanim_interface.queue()

    return [(sdcnanim_interface, "SD-CN-Animation", "sd_cn_animation_interface")]
|
||||
|
||||
|
||||
# Register the SD-CN-Animation tab with the webui at startup.
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
import sys, os
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from RAFT.raft import RAFT
|
||||
from RAFT.utils.utils import InputPadder
|
||||
|
||||
import modules.paths as ph
|
||||
import gc
|
||||
|
||||
# Lazily-loaded RAFT model, cached across calls (see RAFT_estimate_flow).
RAFT_model = None
# MOG2 background subtractor shared by background_subtractor().
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
|
||||
|
||||
def background_subtractor(frame, fgbg):
    """Keep only the moving foreground of `frame`.

    fgbg is a cv2 background-subtractor instance; its mask is applied to
    the frame, zeroing out pixels classified as background.
    """
    foreground_mask = fgbg.apply(frame)
    return cv2.bitwise_and(frame, frame, mask=foreground_mask)
|
||||
|
||||
def RAFT_clear_memory():
    """Drop the cached RAFT model and reclaim CPU/GPU memory."""
    global RAFT_model
    del RAFT_model
    gc.collect()
    # Frees cached CUDA allocations; no-op when CUDA was never initialized.
    torch.cuda.empty_cache()
    RAFT_model = None
|
||||
|
||||
def RAFT_estimate_flow(frame1, frame2, device='cuda'):
    """Estimate dense optical flow between two RGB frames with RAFT.

    Downloads the raft-things checkpoint on first use and caches the model
    in the module-level RAFT_model. Returns (next_flow, prev_flow,
    occlusion_mask): the flows are resized back to the original frame size,
    while the occlusion mask stays at the 16-aligned working resolution.
    """
    global RAFT_model

    # Snap frame dimensions down to a multiple of 16 for the network;
    # flows are resized back to the original size before returning.
    org_size = frame1.shape[1], frame1.shape[0]
    size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
    frame1 = cv2.resize(frame1, size)
    frame2 = cv2.resize(frame2, size)

    model_path = ph.models_path + '/RAFT/raft-things.pth'
    remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'

    # Fetch the checkpoint once into the webui models folder.
    if not os.path.isfile(model_path):
        from basicsr.utils.download_util import load_file_from_url
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        load_file_from_url(remote_model_path, file_name=model_path)

    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': ph.models_path + '/RAFT/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        # The checkpoint was saved from a DataParallel wrapper, so it is
        # loaded through one and then unwrapped to the bare module.
        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    with torch.no_grad():
        # HWC uint8 frames -> NCHW float tensors on the target device.
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow in both directions
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # Forward-backward consistency: where forward + backward flow is far
        # from zero, the pixel is likely occluded in one of the frames.
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

        # NOTE(review): flow vectors are resized spatially but their values
        # are not rescaled here; compute_diff_map renormalizes — confirm.
        next_flow = cv2.resize(next_flow, org_size)
        prev_flow = cv2.resize(prev_flow, org_size)

    return next_flow, prev_flow, occlusion_mask
|
||||
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled, args_dict):
    """Warp the previous (styled) frame along prev_flow and build an
    occlusion/difference alpha mask used to blend newly generated content.

    Returns (alpha_mask, warped_frame_styled); alpha_mask is HxWx3 in [0, 1].
    """
    h, w = cur_frame.shape[:2]
    # NOTE(review): shape[:2] is (rows, cols), so fl_w actually holds the
    # flow height and fl_h the width. The division below uses [fl_h, fl_w]
    # = [cols, rows], which matches the flow's (dx, dy) channel order — the
    # names look swapped but the math appears consistent; confirm.
    fl_w, fl_h = next_flow.shape[:2]

    # normalize flow to be resolution-independent
    next_flow = next_flow / np.array([fl_h,fl_w])
    prev_flow = prev_flow / np.array([fl_h,fl_w])

    # compute occlusion mask from forward-backward flow consistency
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow , axis=2)

    # Suppress the consistency signal where the flow is near zero, so that
    # static regions are not flagged as occluded.
    zero_flow_mask = np.clip(1 - np.linalg.norm(prev_flow, axis=-1)[...,None] * 20, 0, 1)
    diff_mask_flow = fb_norm[..., None] * zero_flow_mask

    # resize flow to the processing resolution and rescale to pixel units
    next_flow = cv2.resize(next_flow, (w, h))
    next_flow = (next_flow * np.array([h,w])).astype(np.float32)
    prev_flow = cv2.resize(prev_flow, (w, h))
    prev_flow = (prev_flow * np.array([h,w])).astype(np.float32)

    # Generate sampling grids, normalized to [-1, 1] for grid_sample.
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
    flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
    flow_grid = flow_grid.unsqueeze(0)
    flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
    flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
    flow_grid = flow_grid.permute(0, 2, 3, 1)

    prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
    prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W

    # 'nearest' interpolation is used to avoid blur accumulating over frames.
    warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
    warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()

    #warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
    #warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)

    # Per-pixel max-channel difference of the warped original frame vs the
    # current frame...
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

    # ...and of the warped styled frame vs the current frame.
    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

    # Combine the three occlusion signals, each weighted by its UI multiplier.
    alpha_mask = np.maximum.reduce([diff_mask_flow * args_dict['occlusion_mask_flow_multiplier'] * 10, \
                                    diff_mask_org * args_dict['occlusion_mask_difo_multiplier'], \
                                    diff_mask_stl * args_dict['occlusion_mask_difs_multiplier']]) #
    alpha_mask = alpha_mask.repeat(3, axis = -1)

    #alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
    if args_dict['occlusion_mask_blur'] > 0:
        # Kernel ~1/15 of the smaller side; `| 1` forces it to be odd as
        # GaussianBlur requires.
        blur_filter_size = min(w,h) // 15 | 1
        alpha_mask = cv2.GaussianBlur(alpha_mask, (blur_filter_size, blur_filter_size) , args_dict['occlusion_mask_blur'], cv2.BORDER_REFLECT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
|
||||
|
||||
def frames_norm(frame):
    """Map uint8 frame values [0, 255] to [-1, 1]."""
    return frame / 127.5 - 1


def flow_norm(flow):
    """Scale raw flow values down by 255."""
    return flow / 255


def occl_norm(occl):
    """Map uint8 occlusion values [0, 255] to [-1, 1]."""
    return occl / 127.5 - 1


def frames_renorm(frame):
    """Inverse of frames_norm: [-1, 1] back to [0, 255]."""
    return (frame + 1) * 127.5


def flow_renorm(flow):
    """Inverse of flow_norm."""
    return flow * 255


def occl_renorm(occl):
    """Inverse of occl_norm: [-1, 1] back to [0, 255]."""
    return (occl + 1) * 127.5
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
import sys, os
|
||||
|
||||
import torch
|
||||
import gc
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import modules.paths as ph
|
||||
from modules import devices
|
||||
|
||||
from scripts.core import utils, flow_utils
|
||||
from FloweR.model import FloweR
|
||||
|
||||
import skimage
|
||||
import datetime
|
||||
import cv2
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
# Lazily-loaded FloweR model (see FloweR_load_model) and the torch device
# it runs on; both are (re)assigned at runtime.
FloweR_model = None
DEVICE = 'cpu'
|
||||
def FloweR_clear_memory():
    """Release the cached FloweR model and reclaim CPU/GPU memory."""
    global FloweR_model
    del FloweR_model
    gc.collect()
    # Frees cached CUDA allocations; no-op when CUDA was never initialized.
    torch.cuda.empty_cache()
    FloweR_model = None
|
||||
|
||||
def FloweR_load_model(w, h):
    """Load the FloweR flow-prediction model for a (w, h) input size.

    Downloads the checkpoint on first use, then stores the ready (eval-mode,
    device-placed) model in the module-level FloweR_model.
    """
    global DEVICE, FloweR_model
    DEVICE = devices.get_optimal_device()

    model_path = ph.models_path + '/FloweR/FloweR_0.1.2.pth'
    #remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv' #FloweR_0.1.1.pth
    remote_model_path = 'https://drive.google.com/uc?id=1-UYsTXkdUkHLgtPK1Y5_7kKzCgzL_Z6o' #FloweR_0.1.2.pth

    # Fetch the checkpoint once into the webui models folder.
    if not os.path.isfile(model_path):
        from basicsr.utils.download_util import load_file_from_url
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        load_file_from_url(remote_model_path, file_name=model_path)

    # Note the (h, w) order expected by the model's input_size.
    FloweR_model = FloweR(input_size = (h, w))
    FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    # Move the model to the device
    FloweR_model = FloweR_model.to(DEVICE)
    FloweR_model.eval()
|
||||
|
||||
def read_frame_from_video(input_video):
    """Read and return the next frame of `input_video` as an RGB ndarray.

    Returns None when `input_video` is None, is not opened, or has no
    frames left. Releases the capture handle once the stream is exhausted.
    """
    if input_video is None:
        return None

    # BUGFIX: cur_frame was previously only assigned inside the isOpened()
    # branch, so a closed capture raised UnboundLocalError on return.
    cur_frame = None

    # Reading video file
    if input_video.isOpened():
        ret, cur_frame = input_video.read()
        if cur_frame is not None:
            # OpenCV decodes BGR; the pipeline works in RGB.
            cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
        else:
            # End of stream: release the capture handle. (Rebinding the
            # local name has no effect on the caller's reference.)
            input_video.release()
            input_video = None

    return cur_frame
|
||||
|
||||
def start_process(*args):
|
||||
processing_start_time = time.time()
|
||||
args_dict = utils.args_to_dict(*args)
|
||||
args_dict = utils.get_mode_args('t2v', args_dict)
|
||||
|
||||
# Open the input video file
|
||||
input_video = None
|
||||
if args_dict['file'] is not None:
|
||||
input_video = cv2.VideoCapture(args_dict['file'].name)
|
||||
|
||||
# Create an output video file with the same fps, width, and height as the input video
|
||||
output_video_name = f'outputs/sd-cn-animation/txt2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
|
||||
output_video_folder = os.path.splitext(output_video_name)[0]
|
||||
os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
|
||||
|
||||
#if args_dict['save_frames_check']:
|
||||
os.makedirs(output_video_folder, exist_ok=True)
|
||||
|
||||
# Writing to current params to params.json
|
||||
setts_json = utils.export_settings(*args)
|
||||
with open(os.path.join(output_video_folder, "params.json"), "w") as outfile:
|
||||
outfile.write(setts_json)
|
||||
|
||||
curr_frame = None
|
||||
prev_frame = None
|
||||
|
||||
def save_result_to_image(image, ind):
|
||||
if args_dict['save_frames_check']:
|
||||
cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
|
||||
|
||||
def set_cn_frame_input():
|
||||
if args_dict['cn_frame_send'] == 0: # Current generated frame"
|
||||
pass
|
||||
elif args_dict['cn_frame_send'] == 1: # Current generated frame"
|
||||
if curr_frame is not None:
|
||||
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame), set_references=True)
|
||||
elif args_dict['cn_frame_send'] == 2: # Previous generated frame
|
||||
if prev_frame is not None:
|
||||
utils.set_CNs_input_image(args_dict, Image.fromarray(prev_frame), set_references=True)
|
||||
elif args_dict['cn_frame_send'] == 3: # Current reference video frame
|
||||
if input_video is not None:
|
||||
curr_video_frame = read_frame_from_video(input_video)
|
||||
curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
|
||||
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame), set_references=True)
|
||||
else:
|
||||
raise Exception('There is no input video! Set it up first.')
|
||||
else:
|
||||
raise Exception('Incorrect cn_frame_send mode!')
|
||||
|
||||
set_cn_frame_input()
|
||||
|
||||
if args_dict['init_image'] is not None:
|
||||
#resize array to args_dict['width'], args_dict['height']
|
||||
image_array=args_dict['init_image']#this is a numpy array
|
||||
init_frame = np.array(Image.fromarray(image_array).resize((args_dict['width'], args_dict['height'])).convert('RGB'))
|
||||
processed_frame = init_frame.copy()
|
||||
else:
|
||||
processed_frames, _, _, _ = utils.txt2img(args_dict)
|
||||
processed_frame = np.array(processed_frames[0])[...,:3]
|
||||
#if input_video is not None:
|
||||
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
|
||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||
init_frame = processed_frame.copy()
|
||||
|
||||
output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
|
||||
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
||||
|
||||
stat = f"Frame: 1 / {args_dict['length']}; " + utils.get_time_left(1, args_dict['length'], processing_start_time)
|
||||
utils.shared.is_interrupted = False
|
||||
|
||||
save_result_to_image(processed_frame, 1)
|
||||
yield stat, init_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||
|
||||
org_size = args_dict['width'], args_dict['height']
|
||||
size = args_dict['width'] // 128 * 128, args_dict['height'] // 128 * 128
|
||||
FloweR_load_model(size[0], size[1])
|
||||
|
||||
clip_frames = np.zeros((4, size[1], size[0], 3), dtype=np.uint8)
|
||||
|
||||
prev_frame = init_frame
|
||||
|
||||
for ind in range(args_dict['length'] - 1):
|
||||
if utils.shared.is_interrupted: break
|
||||
|
||||
args_dict = utils.args_to_dict(*args)
|
||||
args_dict = utils.get_mode_args('t2v', args_dict)
|
||||
|
||||
clip_frames = np.roll(clip_frames, -1, axis=0)
|
||||
clip_frames[-1] = cv2.resize(prev_frame[...,:3], size)
|
||||
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
|
||||
|
||||
with torch.no_grad():
|
||||
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
|
||||
|
||||
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
|
||||
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
|
||||
pred_next = flow_utils.frames_renorm(pred_data[...,3:6]).cpu().numpy()
|
||||
|
||||
pred_occl = np.clip(pred_occl * 10, 0, 255).astype(np.uint8)
|
||||
pred_next = np.clip(pred_next, 0, 255).astype(np.uint8)
|
||||
|
||||
pred_flow = cv2.resize(pred_flow, org_size)
|
||||
pred_occl = cv2.resize(pred_occl, org_size)
|
||||
pred_next = cv2.resize(pred_next, org_size)
|
||||
|
||||
curr_frame = pred_next.copy()
|
||||
|
||||
'''
|
||||
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
|
||||
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
|
||||
|
||||
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
|
||||
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
|
||||
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
|
||||
|
||||
flow_map = pred_flow.copy()
|
||||
flow_map[:,:,0] += np.arange(args_dict['width'])
|
||||
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
|
||||
|
||||
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT_101)
|
||||
alpha_mask = pred_occl / 255.
|
||||
#alpha_mask = np.clip(alpha_mask + np.random.normal(0, 0.4, size = alpha_mask.shape), 0, 1)
|
||||
curr_frame = pred_next.astype(float) * alpha_mask + warped_frame.astype(float) * (1 - alpha_mask)
|
||||
curr_frame = np.clip(curr_frame, 0, 255).astype(np.uint8)
|
||||
#curr_frame = warped_frame.copy()
|
||||
'''
|
||||
|
||||
set_cn_frame_input()
|
||||
|
||||
args_dict['mode'] = 4
|
||||
args_dict['init_img'] = Image.fromarray(pred_next)
|
||||
args_dict['mask_img'] = Image.fromarray(pred_occl)
|
||||
args_dict['seed'] = -1
|
||||
args_dict['denoising_strength'] = args_dict['processing_strength']
|
||||
|
||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||
processed_frame = np.array(processed_frames[0])[...,:3]
|
||||
#if input_video is not None:
|
||||
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
|
||||
#else:
|
||||
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
|
||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||
|
||||
args_dict['mode'] = 0
|
||||
args_dict['init_img'] = Image.fromarray(processed_frame)
|
||||
args_dict['mask_img'] = None
|
||||
args_dict['seed'] = -1
|
||||
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
|
||||
|
||||
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
|
||||
processed_frames, _, _, _ = utils.img2img(args_dict)
|
||||
processed_frame = np.array(processed_frames[0])[...,:3]
|
||||
#if input_video is not None:
|
||||
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
|
||||
#else:
|
||||
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
|
||||
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
|
||||
|
||||
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
||||
prev_frame = processed_frame.copy()
|
||||
|
||||
save_result_to_image(processed_frame, ind + 2)
|
||||
stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
|
||||
yield stat, curr_frame, pred_occl, pred_next, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
|
||||
|
||||
if input_video is not None: input_video.release()
|
||||
output_video.release()
|
||||
FloweR_clear_memory()
|
||||
|
||||
curr_frame = gr.Image.update()
|
||||
occlusion_mask = gr.Image.update()
|
||||
warped_styled_frame_ = gr.Image.update()
|
||||
processed_frame = gr.Image.update()
|
||||
|
||||
# print('TOTAL TIME:', int(time.time() - processing_start_time))
|
||||
|
||||
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||
|
|
@ -0,0 +1,432 @@
|
|||
class shared:
    """Mutable state shared between the UI callbacks and the processing loops."""
    # Set True by the UI interrupt button; processing loops poll it to stop early.
    is_interrupted = False
    # Number of extra per-script (e.g. ControlNet) inputs appended after the
    # named components for the vid2vid tab — used to slice *args in args_to_dict.
    v2v_custom_inputs_size = 0
    # Same, for the txt2vid tab.
    t2v_custom_inputs_size = 0
|
||||
def get_component_names():
    """Return the ordered list of named UI component keys.

    The order must match the order in which the Gradio components are passed
    to the processing callbacks, because args_to_dict pairs them positionally.
    """
    v2v_keys = [
        'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt',
        'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
        'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
        'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing',
        'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier',
        'v2v_occlusion_mask_difs_multiplier',
        'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha',
        'v2v_step_1_seed', 'v2v_step_2_seed',
    ]
    t2v_keys = [
        't2v_file', 't2v_init_image', 't2v_width', 't2v_height', 't2v_prompt',
        't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength',
        't2v_fix_frame_strength', 't2v_sampler_index', 't2v_steps', 't2v_length',
        't2v_fps', 't2v_cn_frame_send',
    ]
    return ['glo_sdcn_process_mode'] + v2v_keys + t2v_keys + ['glo_save_frames_check']
|
||||
|
||||
def args_to_dict(*args):
    """Pair the flat tuple of UI component values with their component names.

    Starts from a dictionary of defaults and overwrites each entry for which
    the UI supplied a non-None value. Positional values beyond the named
    components are the custom per-script inputs; they are split between the
    vid2vid and txt2vid tabs using the sizes recorded on `shared`.
    """
    args_list = get_component_names()

    # Defaults for parameters the UI may not supply.
    args_dict = {
        # video to video params
        'v2v_mode': 0,
        'v2v_prompt': '',
        'v2v_n_prompt': '',
        'v2v_prompt_styles': [],
        'v2v_init_video': None,  # Always required

        'v2v_steps': 15,
        'v2v_sampler_index': 0,  # 'Euler a'
        'v2v_mask_blur': 0,

        'v2v_inpainting_fill': 1,  # original
        'v2v_restore_faces': False,
        'v2v_tiling': False,
        'v2v_n_iter': 1,
        'v2v_batch_size': 1,
        'v2v_cfg_scale': 5.5,
        'v2v_image_cfg_scale': 1.5,
        'v2v_denoising_strength': 0.75,
        'v2v_processing_strength': 0.85,
        'v2v_fix_frame_strength': 0.15,
        'v2v_seed': -1,
        'v2v_subseed': -1,
        'v2v_subseed_strength': 0,
        'v2v_seed_resize_from_h': 512,
        'v2v_seed_resize_from_w': 512,
        'v2v_seed_enable_extras': False,
        'v2v_height': 512,
        'v2v_width': 512,
        'v2v_resize_mode': 1,
        'v2v_inpaint_full_res': True,
        'v2v_inpaint_full_res_padding': 0,
        'v2v_inpainting_mask_invert': False,

        # text to video params
        't2v_mode': 4,
        't2v_prompt': '',
        't2v_n_prompt': '',
        't2v_prompt_styles': [],
        't2v_init_img': None,
        't2v_mask_img': None,

        't2v_steps': 15,
        't2v_sampler_index': 0,  # 'Euler a'
        't2v_mask_blur': 0,

        't2v_inpainting_fill': 1,  # original
        't2v_restore_faces': False,
        't2v_tiling': False,
        't2v_n_iter': 1,
        't2v_batch_size': 1,
        't2v_cfg_scale': 5.5,
        't2v_image_cfg_scale': 1.5,
        't2v_denoising_strength': 0.75,
        't2v_processing_strength': 0.85,
        't2v_fix_frame_strength': 0.15,
        't2v_seed': -1,
        't2v_subseed': -1,
        't2v_subseed_strength': 0,
        't2v_seed_resize_from_h': 512,
        't2v_seed_resize_from_w': 512,
        't2v_seed_enable_extras': False,
        't2v_height': 512,
        't2v_width': 512,
        't2v_resize_mode': 1,
        't2v_inpaint_full_res': True,
        't2v_inpaint_full_res_padding': 0,
        't2v_inpainting_mask_invert': False,

        't2v_override_settings': [],

        't2v_fps': 12,
    }

    # Convert to a list so the script-input slices below are lists, not tuples.
    args = list(args)

    for name, value in zip(args_list, args):
        # A None from the UI means "keep the default" — but only for keys
        # that actually have a default.
        if value is None and name in args_dict:
            continue
        args_dict[name] = value

    # Everything past the named components is custom script inputs:
    # first the vid2vid slice, then the txt2vid slice.
    split_at = len(args_list) + shared.v2v_custom_inputs_size
    args_dict['v2v_script_inputs'] = args[len(args_list):split_at]
    args_dict['t2v_script_inputs'] = args[split_at:]
    return args_dict
|
||||
|
||||
def get_mode_args(mode, args_dict):
    """Select the entries for one processing mode and strip their prefixes.

    Keeps keys whose three-character prefix is `mode` (e.g. 'v2v') or the
    global 'glo', and returns them with the "xxx_" prefix removed.
    """
    return {
        key[4:]: value
        for key, value in args_dict.items()
        if key[:3] in (mode, 'glo')
    }
|
||||
|
||||
def set_CNs_input_image(args_dict, image, set_references = False):
    """Assign `image` as the input of every ControlNet unit in the script inputs.

    Reference-type units ('reference_only', 'reference_adain',
    'reference_adain+attn') are skipped unless `set_references` is True,
    so reference images set elsewhere are not clobbered.
    """
    reference_modules = ("reference_only", "reference_adain", "reference_adain+attn")
    for unit in args_dict['script_inputs']:
        # Duck-typed check: ControlNet units are detected by class name so no
        # hard import of the ControlNet extension is needed.
        if type(unit).__name__ != 'UiControlNetUnit':
            continue
        if unit.module in reference_modules and not set_references:
            continue
        unit.image = np.array(image)
        unit.batch_images = [np.array(image)]
|
||||
|
||||
import time
|
||||
import datetime
|
||||
|
||||
def get_time_left(ind, length, processing_start_time):
    """Format elapsed and estimated remaining time after `ind` of `length` steps.

    The remaining time is a linear extrapolation: average seconds per
    completed step multiplied by the number of steps left.
    """
    elapsed_s = int(time.time() - processing_start_time)
    remaining_s = int(elapsed_s / ind * (length - ind))
    elapsed = datetime.timedelta(seconds=elapsed_s)
    remaining = datetime.timedelta(seconds=remaining_s)
    return f"Time elapsed: {elapsed}; Time left: {remaining};"
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
from types import SimpleNamespace
|
||||
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, process_images
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.images as images
|
||||
import modules.scripts
|
||||
from modules.shared import opts, state
|
||||
from modules import devices, sd_samplers, img2img
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
# TODO: Refactor all the code below
|
||||
|
||||
def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
    """Run the img2img pipeline `p` on a single input image.

    This is a stripped-down copy of webui's batch img2img: the directory
    scanning, inpaint-mask matching and per-image saving are commented out,
    and only `input_img` is processed.

    Args:
        p: configured StableDiffusionProcessingImg2Img instance.
        input_img: PIL image to process.
        output_dir: '' lets webui save samples normally; any other value
            disables sample saving.
        inpaint_mask_dir: unused (mask-batch logic is commented out).
        args: script inputs forwarded to scripts_img2img.run.

    Returns:
        List of generated PIL images (first result of each processed input).
    """
    processing.fix_seed(p)

    #images = shared.listfiles(input_dir)
    # Single-image "batch": keeps the loop structure of the upstream code.
    images = [input_img]

    is_inpaint_batch = False
    #if inpaint_mask_dir:
    #    inpaint_masks = shared.listfiles(inpaint_mask_dir)
    #    is_inpaint_batch = len(inpaint_masks) > 0
    #if is_inpaint_batch:
    #    print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

    #print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

    save_normally = output_dir == ''

    p.do_not_save_grid = True
    p.do_not_save_samples = not save_normally

    state.job_count = len(images) * p.n_iter

    generated_images = []
    for i, image in enumerate(images):
        state.job = f"{i+1} out of {len(images)}"
        if state.skipped:
            state.skipped = False

        if state.interrupted:
            break

        img = image #Image.open(image)
        # Use the EXIF orientation of photos taken by smartphones.
        img = ImageOps.exif_transpose(img)
        p.init_images = [img] * p.batch_size

        #if is_inpaint_batch:
        #    # try to find corresponding mask for an image using simple filename matching
        #    mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
        #    # if not found use first one ("same mask for all images" use-case)
        #    if not mask_image_path in inpaint_masks:
        #        mask_image_path = inpaint_masks[0]
        #    mask_image = Image.open(mask_image_path)
        #    p.image_mask = mask_image

        # Give installed scripts (e.g. ControlNet) the first chance to run;
        # fall back to plain processing when none take over.
        proc = modules.scripts.scripts_img2img.run(p, *args)
        if proc is None:
            proc = process_images(p)
        generated_images.append(proc.images[0])

        #for n, processed_image in enumerate(proc.images):
        #    filename = os.path.basename(image)

        #    if n > 0:
        #        left, right = os.path.splitext(filename)
        #        filename = f"{left}-{n}{right}"

        #    if not save_normally:
        #        os.makedirs(output_dir, exist_ok=True)
        #        if processed_image.mode == 'RGBA':
        #            processed_image = processed_image.convert("RGB")
        #        processed_image.save(os.path.join(output_dir, filename))

    return generated_images
|
||||
|
||||
def img2img(args_dict):
    """Build and run an img2img generation from a flat settings dict.

    Mirrors webui's img2img entry point but takes a plain dict instead of UI
    arguments. `args_dict['mode']` selects how image/mask are derived:
    0 plain img2img, 1 sketch, 2 inpaint, 3 inpaint sketch,
    4 inpaint with uploaded mask (init_img + mask_img), 5 batch (unused here).

    Returns:
        (generated_images, generation_info_js, info_html, comments_html)
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    is_batch = args.mode == 5

    if args.mode == 0:  # img2img
        image = args.init_img.convert("RGB")
        mask = None
    elif args.mode == 1:  # img2img sketch
        image = args.sketch.convert("RGB")
        mask = None
    elif args.mode == 2:  # inpaint
        image, mask = args.init_img_with_mask["image"], args.init_img_with_mask["mask"]
        # Combine the drawn mask with fully transparent areas of the image.
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        image = image.convert("RGB")
    elif args.mode == 3:  # inpaint sketch
        image = args.inpaint_color_sketch
        orig = args.inpaint_color_sketch_orig or args.inpaint_color_sketch
        # Mask is wherever the sketch differs from the original image.
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - args.mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(args.mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        image = image.convert("RGB")
    elif args.mode == 4:  # inpaint upload mask
        #image = args.init_img_inpaint
        #mask = args.init_mask_inpaint
        # This extension supplies the image/mask directly as PIL images.
        image = args.init_img.convert("RGB")
        mask = args.mask_img.convert("L")
    else:
        image = None
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= args.denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=args.prompt,
        negative_prompt=args.n_prompt,
        styles=args.prompt_styles,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        init_images=[image],
        mask=mask,
        mask_blur=args.mask_blur,
        inpainting_fill=args.inpainting_fill,
        resize_mode=args.resize_mode,
        denoising_strength=args.denoising_strength,
        image_cfg_scale=args.image_cfg_scale,
        inpaint_full_res=args.inpaint_full_res,
        inpaint_full_res_padding=args.inpaint_full_res_padding,
        inpainting_mask_invert=args.inpainting_mask_invert,
        override_settings=override_settings,
    )

    # Attach the script runner so extensions (ControlNet etc.) participate.
    p.scripts = modules.scripts.scripts_img2img
    p.script_args = args.script_inputs

    #if shared.cmd_opts.enable_console_prompts:
    #    print(f"\nimg2img: {args.prompt}", file=shared.progress_print_out)

    if mask:
        p.extra_generation_params["Mask blur"] = args.mask_blur

    '''
    if is_batch:
        ...
        # assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        # process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args.script_inputs)
        # processed = Processed(p, [], p.seed, "")
    else:
        processed = modules.scripts.scripts_img2img.run(p, *args.script_inputs)
        if processed is None:
            processed = process_images(p)
    '''

    # Always go through the single-image batch helper above.
    generated_images = process_img(p, image, None, '', args.script_inputs)
    processed = Processed(p, [], p.seed, "")
    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    #if opts.samples_log_stdout:
    #    print(generation_info_js)

    #if opts.do_not_show_images:
    #    processed.images = []

    #print(generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments))
    return generated_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
||||
def txt2img(args_dict):
    """Build and run a txt2img generation from a flat settings dict.

    Mirrors webui's txt2img entry point but takes a plain dict. High-res fix
    parameters are intentionally disabled (commented out).

    Returns:
        (processed.images, generation_info_js, info_html, comments_html)
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=args.prompt,
        styles=args.prompt_styles,
        negative_prompt=args.n_prompt,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        #enable_hr=args.enable_hr,
        #denoising_strength=args.denoising_strength if enable_hr else None,
        #hr_scale=hr_scale,
        #hr_upscaler=hr_upscaler,
        #hr_second_pass_steps=hr_second_pass_steps,
        #hr_resize_x=hr_resize_x,
        #hr_resize_y=hr_resize_y,
        override_settings=override_settings,
    )

    # Attach the script runner so extensions (ControlNet etc.) participate.
    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args.script_inputs

    #if cmd_opts.enable_console_prompts:
    #    print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    # Give installed scripts the first chance to run; fall back to plain
    # processing when none take over.
    processed = modules.scripts.scripts_txt2img.run(p, *args.script_inputs)

    if processed is None:
        processed = process_images(p)

    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    #if opts.samples_log_stdout:
    #    print(generation_info_js)

    #if opts.do_not_show_images:
    #    processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
||||
|
||||
import json
|
||||
def get_json(obj):
    """Round-trip `obj` through JSON to obtain a plain dict/list representation.

    Objects without native JSON support are serialized via their __dict__,
    falling back to str() for anything else.
    """
    serialized = json.dumps(obj, default=lambda o: getattr(o, '__dict__', str(o)))
    return json.loads(serialized)
|
||||
|
||||
def export_settings(*args):
    """Serialize the current UI settings of the active tab into a JSON preset.

    args[0] selects the tab ('vid2vid' or 'txt2vid'). Enabled ControlNet
    units are flattened into readable dicts; transient/unimportant values
    are stripped before dumping.
    """
    args_dict = args_to_dict(*args)

    mode_prefixes = {'vid2vid': 'v2v', 'txt2vid': 't2v'}
    process_mode = args[0]
    if process_mode not in mode_prefixes:
        msg = f"Unsupported processing mode: '{process_mode}'"
        raise Exception(msg)
    args_dict = get_mode_args(mode_prefixes[process_mode], args_dict)

    # Collapse enabled ControlNet units into plain dicts, dropping fields
    # that are not meaningful in an exported preset.
    cn_remove_list = ['low_vram', 'is_ui', 'input_mode', 'batch_images', 'output_dir', 'loopback', 'image']

    args_dict['ControlNets'] = []
    for script_input in args_dict['script_inputs']:
        if type(script_input).__name__ != 'UiControlNetUnit':
            continue
        cn_values_dict = get_json(script_input)
        if not cn_values_dict['enabled']:
            continue
        for key in cn_remove_list:
            cn_values_dict.pop(key, None)
        args_dict['ControlNets'].append(cn_values_dict)

    # Remove unimportant values.
    remove_list = ['save_frames_check', 'restore_faces', 'prompt_styles', 'mask_blur', 'inpainting_fill', 'tiling', 'n_iter', 'batch_size', 'subseed', 'subseed_strength', 'seed_resize_from_h', \
                   'seed_resize_from_w', 'seed_enable_extras', 'resize_mode', 'inpaint_full_res', 'inpaint_full_res_padding', 'inpainting_mask_invert', 'file', 'denoising_strength', \
                   'override_settings', 'script_inputs', 'init_img', 'mask_img', 'mode', 'init_video']
    for key in remove_list:
        args_dict.pop(key, None)

    return json.dumps(args_dict, indent=2, default=lambda o: getattr(o, '__dict__', str(o)))
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
import sys, os
|
||||
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from modules import devices, sd_samplers
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
try:
|
||||
from modules import lowvram
|
||||
except ImportError:
|
||||
lowvram = None
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
import gc
|
||||
import cv2
|
||||
import gradio as gr
|
||||
|
||||
import time
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
from scripts.core.flow_utils import RAFT_estimate_flow, RAFT_clear_memory, compute_diff_map
|
||||
from scripts.core import utils
|
||||
|
||||
class sdcn_anim_tmp:
    """Module-level scratch state for the vid2vid processing loop.

    Frames are first "prepared" (optical flow estimated) in batches and then
    "processed" (stylized); the two counters track each phase independently.
    """
    # Number of source frames whose flow has been precomputed.
    prepear_counter = 0
    # Number of frames stylized and written to the output video.
    process_counter = 0
    input_video = None   # cv2.VideoCapture for the source file
    output_video = None  # cv2.VideoWriter for the result
    curr_frame = None
    prev_frame = None
    # Last stylized frame; warped forward as the base for the next frame.
    prev_frame_styled = None
    # Previous occlusion alpha mask (used for mask trailing).
    prev_frame_alpha_mask = None
    fps = None
    total_frames = None
    # Buffers for one batch: frames plus the forward/backward flows between them.
    prepared_frames = None
    prepared_next_flows = None
    prepared_prev_flows = None
    # True once a batch of flows is ready and processing may proceed.
    frames_prepared = False
|
||||
|
||||
def read_frame_from_video():
    """Read the next frame from the shared input video.

    Returns:
        The frame as an RGB ndarray, or None when the stream is exhausted
        (the capture is released on the first failed read) or the capture
        is not open at all.
    """
    # Bug fix: initialize before the isOpened() check. Previously, a closed
    # capture skipped the whole branch and `return cur_frame` raised
    # UnboundLocalError instead of returning None.
    cur_frame = None

    if sdcn_anim_tmp.input_video.isOpened():
        ret, cur_frame = sdcn_anim_tmp.input_video.read()
        if cur_frame is not None:
            # OpenCV decodes BGR; the rest of the pipeline works in RGB.
            cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
        else:
            # End of stream: release the capture so later calls are no-ops.
            sdcn_anim_tmp.input_video.release()

    return cur_frame
|
||||
|
||||
def get_cur_stat():
    """Return a one-line progress summary for the UI status field."""
    total = sdcn_anim_tmp.total_frames
    prepared = sdcn_anim_tmp.prepear_counter + 1
    processed = sdcn_anim_tmp.process_counter + 1
    return (
        f'Frames prepared: {prepared} / {total}; '
        f'Frames processed: {processed} / {total}; '
    )
|
||||
|
||||
def clear_memory_from_sd():
    """Unload the Stable Diffusion model to free VRAM before running RAFT.

    Undoes the optimization hijack, moves everything to CPU when the lowvram
    module is available, drops the model reference, then forces garbage
    collection and a torch cache flush.
    """
    if shared.sd_model is not None:
        # Undo the optimization hijack before dropping the model reference.
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        if lowvram:
            try:
                lowvram.send_everything_to_cpu()
            except Exception as e:
                # Best-effort: ignore failures from the optional lowvram path.
                ...
        del shared.sd_model
        shared.sd_model = None
    gc.collect()
    devices.torch_gc()
|
||||
|
||||
def start_process(*args):
    """Vid2vid driver generator: stylize every frame of the input video.

    Alternates between two phases, each loop iteration doing one of them:
    preparing a batch of frames (RAFT optical flow + occlusion estimation)
    and processing prepared frames (warp previous stylized frame, inpaint or
    re-render occluded areas via img2img, optional second fix pass).

    Yields after every iteration a tuple for the Gradio UI:
    (status text, current frame, occlusion mask, warped styled frame,
     processed frame, output video path or None, two Button updates).
    """
    processing_start_time = time.time()
    args_dict = utils.args_to_dict(*args)
    args_dict = utils.get_mode_args('v2v', args_dict)

    sdcn_anim_tmp.process_counter = 0
    sdcn_anim_tmp.prepear_counter = 0

    # Open the input video file
    sdcn_anim_tmp.input_video = cv2.VideoCapture(args_dict['file'].name)

    # Get useful info from the source video
    sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
    sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Each frame is visited twice (once prepared, once processed).
    loop_iterations = (sdcn_anim_tmp.total_frames-1) * 2

    # Create an output video file with the same fps, width, and height as the input video
    output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    output_video_folder = os.path.splitext(output_video_name)[0]
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)

    if args_dict['save_frames_check']:
        os.makedirs(output_video_folder, exist_ok=True)

    def save_result_to_image(image, ind):
        # Optionally dump each processed frame as a numbered PNG (BGR on disk).
        if args_dict['save_frames_check']:
            cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    sdcn_anim_tmp.output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), sdcn_anim_tmp.fps, (args_dict['width'], args_dict['height']))

    # Process the very first frame on its own to seed prev_frame_styled.
    curr_frame = read_frame_from_video()
    curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
    # Batch buffers: 11 frames and the 10 flows between consecutive pairs.
    sdcn_anim_tmp.prepared_frames = np.zeros((11, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
    sdcn_anim_tmp.prepared_next_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_prev_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_frames[0] = curr_frame

    args_dict['init_img'] = Image.fromarray(curr_frame)
    utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
    processed_frames, _, _, _ = utils.img2img(args_dict)
    processed_frame = np.array(processed_frames[0])[...,:3]
    # Keep colors close to the source frame.
    processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
    processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
    #print('Processed frame ', 0)

    sdcn_anim_tmp.curr_frame = curr_frame
    sdcn_anim_tmp.prev_frame = curr_frame.copy()
    sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
    utils.shared.is_interrupted = False

    save_result_to_image(processed_frame, 1)
    stat = get_cur_stat() + utils.get_time_left(1, loop_iterations, processing_start_time)
    yield stat, sdcn_anim_tmp.curr_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    for step in range(loop_iterations):
        if utils.shared.is_interrupted: break

        # Re-read the UI values each iteration so live edits take effect.
        args_dict = utils.args_to_dict(*args)
        args_dict = utils.get_mode_args('v2v', args_dict)

        occlusion_mask = None
        prev_frame = None
        curr_frame = sdcn_anim_tmp.curr_frame
        warped_styled_frame_ = gr.Image.update()
        processed_frame = gr.Image.update()

        prepare_steps = 10
        if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
            #clear_memory_from_sd()
            device = devices.get_optimal_device()

            curr_frame = read_frame_from_video()
            if curr_frame is not None:
                curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
                prev_frame = sdcn_anim_tmp.prev_frame.copy()

                next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, curr_frame, device=device)
                occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)

                cn = sdcn_anim_tmp.prepear_counter % 10
                # First slot of a fresh batch also stores the previous frame.
                if sdcn_anim_tmp.prepear_counter % 10 == 0:
                    sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
                sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
                sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
                sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
                #print('Prepared frame ', cn+1)

                sdcn_anim_tmp.prev_frame = curr_frame.copy()

            sdcn_anim_tmp.prepear_counter += 1
            # Batch full, source exhausted, or stream ended: switch to processing.
            if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
               sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
               curr_frame is None:
                # Remove RAFT from memory
                RAFT_clear_memory()
                sdcn_anim_tmp.frames_prepared = True
        else:
            # process frame
            sdcn_anim_tmp.frames_prepared = False

            cn = sdcn_anim_tmp.process_counter % 10
            curr_frame = sdcn_anim_tmp.prepared_frames[cn+1][...,:3]
            prev_frame = sdcn_anim_tmp.prepared_frames[cn][...,:3]
            next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
            prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]

            ### STEP 1
            # Warp the previous stylized frame forward and get the occlusion alpha mask.
            alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled, args_dict)
            warped_styled_frame_ = warped_styled_frame.copy()

            #fl_w, fl_h = prev_flow.shape[:2]
            #prev_flow_n = prev_flow / np.array([fl_h,fl_w])
            #flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None] * 20, 0, 1)
            #alpha_mask = alpha_mask * flow_mask

            # Optionally let the previous mask bleed into this one (trailing).
            if sdcn_anim_tmp.process_counter > 0 and args_dict['occlusion_mask_trailing']:
                alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
            sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask

            # alpha_mask = np.round(alpha_mask * 8) / 8 #> 0.3
            alpha_mask = np.clip(alpha_mask, 0, 1)
            occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)

            # fix warped styled frame from duplicated that occures on the places where flow is zero, but only because there is no place to get the color from
            warped_styled_frame = curr_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

            # process current frame
            # TODO: convert args_dict into separate dict that stores only params necessery for img2img processing
            # NOTE(review): this aliases args_dict, so the mutations below also
            # change args_dict itself — presumably acceptable since args_dict is
            # rebuilt each iteration; confirm before refactoring.
            img2img_args_dict = args_dict #copy.deepcopy(args_dict)
            img2img_args_dict['denoising_strength'] = args_dict['processing_strength']
            if args_dict['step_1_processing_mode'] == 0: # Process full image then blend in occlusions
                img2img_args_dict['mode'] = 0
                img2img_args_dict['mask_img'] = None #Image.fromarray(occlusion_mask)
            elif args_dict['step_1_processing_mode'] == 1: # Inpaint occlusions
                img2img_args_dict['mode'] = 4
                img2img_args_dict['mask_img'] = Image.fromarray(occlusion_mask)
            else:
                raise Exception('Incorrect step 1 processing mode!')

            blend_alpha = args_dict['step_1_blend_alpha']
            init_img = warped_styled_frame * (1 - blend_alpha) + curr_frame * blend_alpha
            img2img_args_dict['init_img'] = Image.fromarray(np.clip(init_img, 0, 255).astype(np.uint8))
            img2img_args_dict['seed'] = args_dict['step_1_seed']
            utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
            processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
            processed_frame = np.array(processed_frames[0])[...,:3]

            # normalizing the colors
            processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
            # Keep the warped frame in non-occluded areas; use the fresh render in occluded ones.
            processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

            #processed_frame = processed_frame * 0.94 + curr_frame * 0.06
            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

            ### STEP 2
            # Optional low-strength full-frame pass to clean up blending seams.
            if args_dict['fix_frame_strength'] > 0:
                img2img_args_dict = args_dict #copy.deepcopy(args_dict)
                img2img_args_dict['mode'] = 0
                img2img_args_dict['init_img'] = Image.fromarray(processed_frame)
                img2img_args_dict['mask_img'] = None
                img2img_args_dict['denoising_strength'] = args_dict['fix_frame_strength']
                img2img_args_dict['seed'] = args_dict['step_2_seed']
                utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
                processed_frame = np.array(processed_frames[0])
                processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)

            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)

            # Write the frame to the output video
            frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
            frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
            sdcn_anim_tmp.output_video.write(frame_out)

            sdcn_anim_tmp.process_counter += 1
            #if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
            #    sdcn_anim_tmp.input_video.release()
            #    sdcn_anim_tmp.output_video.release()
            #    sdcn_anim_tmp.prev_frame = None

            save_result_to_image(processed_frame, sdcn_anim_tmp.process_counter + 1)

        stat = get_cur_stat() + utils.get_time_left(step+2, loop_iterations+1, processing_start_time)
        yield stat, curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    RAFT_clear_memory()

    sdcn_anim_tmp.input_video.release()
    sdcn_anim_tmp.output_video.release()

    # Final yield: leave the preview images untouched and re-enable the start button.
    curr_frame = gr.Image.update()
    occlusion_mask = gr.Image.update()
    warped_styled_frame_ = gr.Image.update()
    processed_frame = gr.Image.update()

    yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||