Compare commits
107 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
05e265ff3f | |
|
|
f2709f3990 | |
|
|
2e5e09f3c3 | |
|
|
052fdec082 | |
|
|
4f07b5e6d2 | |
|
|
ab1f1583b7 | |
|
|
48d8348de4 | |
|
|
3ccf7373ad | |
|
|
de1ff473f7 | |
|
|
6a04d70241 | |
|
|
2e257bbfc3 | |
|
|
111711fc7b | |
|
|
8bc0954ac8 | |
|
|
c3c0972c0a | |
|
|
89ba89d949 | |
|
|
ba3c17ef7e | |
|
|
b26628a5be | |
|
|
3c36b8e7c5 | |
|
|
c3e4b42d98 | |
|
|
d97ead6ab8 | |
|
|
00fbf01831 | |
|
|
9e08b4c7d3 | |
|
|
8d63dd5471 | |
|
|
e16f728512 | |
|
|
a35f446b69 | |
|
|
9849e6389e | |
|
|
cd400ea7b1 | |
|
|
5d1b65ee48 | |
|
|
9ab15587d8 | |
|
|
e67bcb9264 | |
|
|
f1d47c954a | |
|
|
4ef0097751 | |
|
|
98fe91ceae | |
|
|
33897b7e2f | |
|
|
027faf1612 | |
|
|
63efd3e0e3 | |
|
|
b9d080ff4e | |
|
|
be9f056657 | |
|
|
eddf1a4c8f | |
|
|
ca5fdb1151 | |
|
|
9c5611355d | |
|
|
85cd721e83 | |
|
|
19ac530bb0 | |
|
|
ea0b5e19fc | |
|
|
46ae16e4cb | |
|
|
14534d3174 | |
|
|
c987f645d6 | |
|
|
f33a508d3f | |
|
|
4adcaf026b | |
|
|
a553de32db | |
|
|
31a8dc71d8 | |
|
|
ef5dca6d99 | |
|
|
eff3046d81 | |
|
|
115dbeacb1 | |
|
|
9b88a8f04e | |
|
|
7e55f4d781 | |
|
|
82d0679a51 | |
|
|
bcfd6f994d | |
|
|
50724e8056 | |
|
|
bb4353a264 | |
|
|
c34cfe2976 | |
|
|
dc2be7ba28 | |
|
|
0cb020b157 | |
|
|
3c69efe7a8 | |
|
|
f46b73b6f4 | |
|
|
f073b25440 | |
|
|
10430c5d0f | |
|
|
3e0de3b84d | |
|
|
12dcbef475 | |
|
|
b112d6b2fc | |
|
|
bccf317b03 | |
|
|
0b7fb7c252 | |
|
|
a8c298b7e4 | |
|
|
6a6ca687a5 | |
|
|
05260bce59 | |
|
|
fac580c3b9 | |
|
|
ae605b4299 | |
|
|
621e18ea56 | |
|
|
7e051bbe13 | |
|
|
cb737361d0 | |
|
|
81c425a429 | |
|
|
22bcebb64e | |
|
|
22d909ac59 | |
|
|
5d2f58151d | |
|
|
13a39359ad | |
|
|
0e8039a6d2 | |
|
|
c78c9ae907 | |
|
|
729909e533 | |
|
|
c0e876f4b5 | |
|
|
453bf74188 | |
|
|
e5fd243585 | |
|
|
085f865658 | |
|
|
87834fc192 | |
|
|
3ce461efcc | |
|
|
8287c45e31 | |
|
|
6dda2aaf9d | |
|
|
e90e7e3690 | |
|
|
b65ad41836 | |
|
|
24f77ed2fc | |
|
|
7f32fc1f9b | |
|
|
02936b0127 | |
|
|
d8f4774706 | |
|
|
d272fc1be9 | |
|
|
e8a79ee126 | |
|
|
4f052e83f9 | |
|
|
49ae994d88 | |
|
|
c8cbcf2cf4 |
|
|
@ -0,0 +1,6 @@
|
|||
__pycache__/
|
||||
out/
|
||||
videos/
|
||||
FP_Res/
|
||||
result.mp4
|
||||
*.pth
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
import torch
import torch.nn as nn
# BUGFIX: the original imported `torch.functional`, an internal module that
# does not contain the nn functional API; the conventional alias target is
# `torch.nn.functional`.
import torch.nn.functional as F
|
||||
|
||||
# Define the model
class FloweR(nn.Module):
    """Predict optical flow, an occlusion mask and the next frame from a
    short window of video frames.

    Input:  (batch, window_size, H, W, 3) frames, values roughly in [-1, 1].
    Output: (batch, H, W, 6) with channels
            [flow_x / 255, flow_y / 255, occlusion * 2 - 1, next-frame RGB].
    """

    def __init__(self, input_size=(384, 384), window_size=4):
        """
        Args:
            input_size: (H, W) the model expects; H and W must be divisible
                by 128 (seven 2x average-pools in the encoder).
            window_size: number of input frames stacked along channels.
        """
        super(FloweR, self).__init__()

        self.input_size = input_size
        self.window_size = window_size

        # 2 channels for optical flow
        # 1 channel for occlusion mask
        # 3 channels for next frame prediction
        self.out_channels = 6

        # INPUT: H x W x (3 * window_size)

        def down_block():
            # Halve the spatial resolution, keep 128 feature channels.
            return nn.Sequential(
                nn.AvgPool2d(2),
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
                nn.ReLU(),
            )

        def up_block():
            # Double the spatial resolution, keep 128 feature channels.
            return nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
                nn.ReLU(),
            )

        ### DOWNSCALE ###
        # Size comments assume the default 384 x 384 input.
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(3 * self.window_size, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )                                 # 384 x 384 x 128
        self.conv_block_2 = down_block()  # 192 x 192 x 128
        self.conv_block_3 = down_block()  # 96 x 96 x 128
        self.conv_block_4 = down_block()  # 48 x 48 x 128
        self.conv_block_5 = down_block()  # 24 x 24 x 128
        self.conv_block_6 = down_block()  # 12 x 12 x 128
        self.conv_block_7 = down_block()  # 6 x 6 x 128
        self.conv_block_8 = down_block()  # 3 x 3 x 128 - 9 input tokens

        ### Transformer part ###
        # To be done

        ### UPSCALE ###
        self.conv_block_9 = up_block()    # 6 x 6 x 128
        self.conv_block_10 = up_block()   # 12 x 12 x 128
        self.conv_block_11 = up_block()   # 24 x 24 x 128
        self.conv_block_12 = up_block()   # 48 x 48 x 128
        self.conv_block_13 = up_block()   # 96 x 96 x 128
        self.conv_block_14 = up_block()   # 192 x 192 x 128
        self.conv_block_15 = up_block()   # 384 x 384 x 128

        # Final projection to the 6 output channels (no activation).
        self.conv_block_16 = nn.Conv2d(128, self.out_channels, kernel_size=3, stride=1, padding='same')

    def forward(self, input_frames):
        """Run the U-Net and compose the next frame.

        Args:
            input_frames: tensor of shape (batch, window_size, H, W, 3).

        Returns:
            Tensor of shape (batch, H, W, 6); see class docstring.

        Raises:
            ValueError: if the frame dimension is not `window_size`.
        """
        if input_frames.size(1) != self.window_size:
            raise ValueError(
                f'Shape of the input is not compatible. There should be '
                f'exactly {self.window_size} frames in an input video.')

        h, w = self.input_size

        # (batch, frames, height, width, colors) ->
        # (batch, frames, colors, height, width)
        input_frames_permuted = input_frames.permute((0, 1, 4, 2, 3))

        # Stack frames along channels: (batch, window_size * 3, H, W).
        in_x = input_frames_permuted.reshape(-1, self.window_size * 3, h, w)

        ### DOWNSCALE ###
        block_1_out = self.conv_block_1(in_x)        # 384 x 384 x 128
        block_2_out = self.conv_block_2(block_1_out) # 192 x 192 x 128
        block_3_out = self.conv_block_3(block_2_out) # 96 x 96 x 128
        block_4_out = self.conv_block_4(block_3_out) # 48 x 48 x 128
        block_5_out = self.conv_block_5(block_4_out) # 24 x 24 x 128
        block_6_out = self.conv_block_6(block_5_out) # 12 x 12 x 128
        block_7_out = self.conv_block_7(block_6_out) # 6 x 6 x 128
        block_8_out = self.conv_block_8(block_7_out) # 3 x 3 x 128

        ### UPSCALE ### (with U-Net skip connections)
        block_9_out = block_7_out + self.conv_block_9(block_8_out)    # 6 x 6 x 128
        block_10_out = block_6_out + self.conv_block_10(block_9_out)  # 12 x 12 x 128
        block_11_out = block_5_out + self.conv_block_11(block_10_out) # 24 x 24 x 128
        block_12_out = block_4_out + self.conv_block_12(block_11_out) # 48 x 48 x 128
        block_13_out = block_3_out + self.conv_block_13(block_12_out) # 96 x 96 x 128
        block_14_out = block_2_out + self.conv_block_14(block_13_out) # 192 x 192 x 128
        block_15_out = block_1_out + self.conv_block_15(block_14_out) # 384 x 384 x 128

        block_16_out = self.conv_block_16(block_15_out)  # H x W x (2 + 1 + 3)
        out = block_16_out.reshape(-1, self.out_channels, h, w)

        ### for future model training ###
        # BUGFIX: Tensor.get_device() returns -1 on CPU tensors, which then
        # breaks `.to(device=...)`; `out.device` works on CPU and GPU alike.
        device = out.device

        pred_flow = out[:, :2, :, :] * 255       # (-255, 255), pixel units
        pred_occl = (out[:, 2:3, :, :] + 1) / 2  # [0, 1]
        pred_next = out[:, 3:6, :, :]

        # Build an absolute sampling grid in pixel coordinates and offset it
        # by the predicted flow.
        grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing='ij')
        flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
        flow_grid = flow_grid.unsqueeze(0).to(device=device)
        flow_grid = flow_grid + pred_flow

        # Normalize to [-1, 1] as grid_sample expects; done out-of-place so
        # autograd never sees an in-place write on the flow path.
        norm_x = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
        norm_y = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
        # (batch, 2, H, W) -> (batch, H, W, 2), last dim ordered (x, y)
        flow_grid = torch.stack((norm_x, norm_y), dim=-1)

        previous_frame = input_frames_permuted[:, -1, :, :, :]
        # Bilinear keeps gradients smooth while training; nearest keeps
        # pixels crisp at inference time.
        sampling_mode = "bilinear" if self.training else "nearest"
        warped_frame = torch.nn.functional.grid_sample(
            previous_frame, flow_grid, mode=sampling_mode,
            padding_mode="reflection", align_corners=False)

        # Blend the predicted frame into the warped frame only where the
        # occlusion mask fires; blending strength capped at 0.04 per step.
        alpha_mask = torch.clip(pred_occl * 10, 0, 1) * 0.04
        pred_next = torch.clip(pred_next, -1, 1)
        warped_frame = torch.clip(warped_frame, -1, 1)
        next_frame = pred_next * alpha_mask + warped_frame * (1 - alpha_mask)

        # Re-normalize outputs to ~[-1, 1] and return channels-last.
        res = torch.cat((pred_flow / 255, pred_occl * 2 - 1, next_frame), dim=1)
        res = res.permute((0, 2, 3, 1))  # (batch, H, W, channels)
        return res
|
||||
1
RAFT
|
|
@ -1 +0,0 @@
|
|||
Subproject commit aac9dd54726caf2cf81d8661b07663e220c5586d
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2020, princeton-vl
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import torch
import torch.nn.functional as F
from RAFT.utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
# BUGFIX: a bare `except:` also swallows SystemExit/KeyboardInterrupt;
# only a failed import should be tolerated here.
except ImportError:
    # alt_cuda_corr is not compiled
    pass
|
||||
|
||||
|
||||
class CorrBlock:
    """Precomputed correlation-volume pyramid with windowed lookup.

    Builds the all-pairs correlation between two feature maps once, then
    average-pools it into `num_levels` levels; __call__ samples a
    (2r+1) x (2r+1) window around the query coordinates at every level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        # fmap1 / fmap2: feature maps of shape (batch, dim, ht, wd).
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        # Flatten the source grid into the batch dim so pooling only
        # shrinks the *target* grid.
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation windows around `coords` (batch, 2, h1, w1);
        returns (batch, num_levels * (2r+1)^2, h1, w1)."""
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Integer offsets covering the (2r+1)^2 lookup window.
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Query centers scaled to this pyramid level's resolution.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # All-pairs dot-product correlation, normalized by sqrt(dim);
        # result shape: (batch, ht, wd, 1, ht, wd).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class AlternateCorrBlock:
    """Memory-efficient correlation lookup backed by the optional
    `alt_cuda_corr` CUDA extension: correlations are computed on the fly
    instead of materializing the full volume. Requires the extension to
    be compiled; otherwise __call__ fails on the missing module."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Pyramid of (fmap1, fmap2) pairs, each level half the resolution.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Sample correlation features at `coords` (batch, 2, H, W);
        returns (batch, num_levels * window, H, W)."""
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 stays at full resolution; fmap2 comes from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        # Same sqrt(dim) normalization as CorrBlock.corr.
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two-conv residual block with a selectable normalization layer.

    norm_fn is one of 'group' | 'batch' | 'instance' | 'none'. When
    stride != 1 the identity path is replaced by a 1x1 strided conv
    (plus norm3) so shapes match.

    NOTE(review): an unrecognized norm_fn leaves norm1/norm2 (and norm3
    when stride != 1) undefined, failing later with AttributeError /
    NameError — confirm callers only pass the four listed values.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                # norm3 only exists for the strided shortcut below
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            # nn.Sequential() with no children is an identity mapping
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        # Residual branch: conv -> norm -> relu, twice.
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        # Match shapes on the identity path when the block is strided.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
||||
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (internal width is
    planes // 4) with a selectable normalization layer.

    norm_fn is one of 'group' | 'batch' | 'instance' | 'none'; as with
    ResidualBlock, any other value leaves the norm layers undefined.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                # norm4 only exists for the strided shortcut below
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        # Residual branch: squeeze (1x1) -> spatial (3x3) -> expand (1x1).
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        # Match shapes on the identity path when the block is strided.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
||||
|
||||
class BasicEncoder(nn.Module):
    """Feature/context encoder for the full RAFT model.

    A 7x7 strided stem plus three ResidualBlock stages; the output is
    1/8 of the input resolution. forward() accepts either one tensor or
    a list/tuple of two tensors (batched together, then split back).
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # Each stage doubles stride except the first: total downsample x8.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first one may be strided.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode one image tensor, or a pair given as list/tuple."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        # Split the stacked pair back into two tensors.
        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
||||
|
||||
class SmallEncoder(nn.Module):
    """Reduced-width encoder for the small RAFT model.

    Same layout as BasicEncoder (stem + three stages, x8 downsample) but
    with BottleneckBlock stages and narrower channel counts.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # output convolution
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first one may be strided.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode one image tensor, or a pair given as list/tuple."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        # Split the stacked pair back into two tensors.
        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from RAFT.update import BasicUpdateBlock, SmallUpdateBlock
|
||||
from RAFT.extractor import BasicEncoder, SmallEncoder
|
||||
from RAFT.corr import CorrBlock, AlternateCorrBlock
|
||||
from RAFT.utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||
|
||||
# Use torch's automatic mixed precision when available.
try:
    autocast = torch.cuda.amp.autocast
# BUGFIX: a bare `except:` also swallows SystemExit/KeyboardInterrupt;
# a missing attribute is the only failure expected here.
except AttributeError:
    # dummy autocast for PyTorch < 1.6: a no-op context manager with the
    # same constructor signature.
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass
|
||||
|
||||
|
||||
class RAFT(nn.Module):
    """RAFT optical-flow network: feature + context encoders followed by
    iterative GRU refinement of a flow field at 1/8 resolution.

    `args` must provide at least `small` and `mixed_precision`; this
    constructor fills in `corr_levels`, `corr_radius`, `dropout` and
    `alternate_corr` (mutating `args` in place).
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # NOTE(review): the `in` tests assume `args` supports membership
        # (true for argparse.Namespace) — confirm against callers.
        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every BatchNorm2d layer into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        # Softmax over the 9 neighbors gives convex combination weights.
        mask = torch.softmax(mask, dim=2)

        # 3x3 neighborhoods of the coarse flow, scaled by the x8 upsample.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        # Map [0, 255] images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # Split channels into GRU hidden state and static context input.
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Detach so gradients do not flow through earlier iterations.
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                # small model has no learned mask -> plain x8 upsampling
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            # Return (low-res flow, final upsampled flow) for inference.
            return coords1 - coords0, flow_up

        return flow_predictions
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a hidden state to a
    2-channel flow update."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> ReLU -> conv; no activation on the 2-channel output
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on (batch, C, H, W) feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step; returns the updated hidden state."""
        stacked = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(stacked))
        reset_gate = torch.sigmoid(self.convr(stacked))
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))

        # Blend old state and candidate according to the update gate.
        return (1 - update_gate) * h + update_gate * candidate
|
||||
|
||||
class SepConvGRU(nn.Module):
    """GRU cell with separable gates: a horizontal (1x5) GRU step
    followed by a vertical (5x1) one."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def _step(self, h, x, convz, convr, convq):
        # Single GRU update using the given gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Two GRU steps (horizontal then vertical); returns new hidden state."""
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||
|
||||
class SmallMotionEncoder(nn.Module):
    """Encodes the current flow estimate plus correlation features into
    motion features: 80 learned channels + the 2 raw flow channels."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # one value per pyramid level per window cell
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """Returns (batch, 82, H, W) motion features."""
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep the raw flow alongside the learned features.
        return torch.cat([merged, flow], dim=1)
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Encodes the current flow estimate plus correlation features into
    motion features: 126 learned channels + the 2 raw flow channels."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # one value per pyramid level per window cell
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Returns (batch, 128, H, W) motion features."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep the raw flow alongside the learned features.
        return torch.cat([merged, flow], dim=1)
|
||||
|
||||
class SmallUpdateBlock(nn.Module):
    """Small RAFT update block: motion encoder + ConvGRU + flow head.

    Returns (net, None, delta_flow); the small model produces no
    upsampling mask, hence the None placeholder.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion], dim=1)
        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """Full RAFT update block producing an updated hidden state, a convex
    upsampling mask and a flow update."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # predicts a 9-way convex combination mask for each 8x8 output patch
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
import numpy as np
|
||||
import random
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowAugmentor:
    """Data augmentation for densely annotated optical flow training pairs.

    Applies photometric jitter, random occlusion patches, and spatial
    rescale/stretch/flip/crop to an (img1, img2, flow) triple, keeping the
    flow field consistent with every spatial change.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
        # crop_size: (height, width) of the final crop
        # min_scale / max_scale: log2 bounds of the random rescale factor

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric: jitter each image independently
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric: stack vertically so both images receive identical jitter
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """
        # NOTE(review): the mutable default `bounds` is only read, never
        # mutated, so sharing it across calls is safe here.

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            # paint 1-2 random mean-color rectangles into img2 to simulate
            # occlusions appearing between the two frames
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        """Random rescale/stretch, flips and crop, applied consistently to
        both images and the flow field (vectors are rescaled/negated to match)."""
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # lower bound keeps the rescaled image at least 8px larger than the crop
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            # anisotropic stretch: perturb x and y scales independently
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            # flow vectors scale with the image
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                # horizontal flip negates the x component of the flow
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                # vertical flip negates the y component of the flow
                flow = flow[::-1, :] * [1.0, -1.0]

        # NOTE(review): assumes the (possibly rescaled) image is strictly
        # larger than crop_size; np.random.randint raises otherwise.
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        """Apply the full augmentation pipeline and return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
|
||||
|
||||
class SparseFlowAugmentor:
    """Augmentation for sparsely annotated flow (e.g. KITTI).

    Like FlowAugmentor, but carries a per-pixel validity mask alongside the
    flow and keeps it consistent through every spatial transform.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        # symmetric jitter only: stack so both images receive the same change
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        # paint 1-2 random mean-color rectangles into img2 to simulate occlusions
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        """Rescale a sparse flow field by splatting each valid flow vector
        into the nearest cell of the resized grid.

        Nearest splatting (instead of interpolation) avoids blending valid
        and invalid samples together.
        """
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        # keep only the annotated pixels
        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        # scale both the sample positions and the vectors themselves
        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        # drop samples that fall outside the resized grid
        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        """Random rescale, horizontal flip and crop, applied consistently to
        the images, the sparse flow and its validity mask."""
        # randomly sample scale

        ht, wd = img1.shape[:2]
        # lower bound keeps the rescaled image at least 1px larger than the crop
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                # horizontal flip negates the x component of the flow
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        # allow the crop window to extend slightly past the image, then clamp;
        # this biases crops toward the image borders
        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        """Apply the full augmentation pipeline and return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
||||
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2018 Tom Runia
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to conditions.
|
||||
#
|
||||
# Author: Tom Runia
|
||||
# Date Created: 2018-08-03
|
||||
|
||||
import numpy as np
|
||||
|
||||
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein and the
    Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3]
    """
    # (segment length, channel held at 255, channel that ramps, ramp goes up?)
    segments = [
        (15, 0, 1, True),    # red -> yellow  (RY)
        (6,  1, 0, False),   # yellow -> green (YG)
        (4,  1, 2, True),    # green -> cyan   (GC)
        (11, 2, 1, False),   # cyan -> blue    (CB)
        (13, 2, 0, True),    # blue -> magenta (BM)
        (6,  0, 2, False),   # magenta -> red  (MR)
    ]

    total = sum(n for n, _, _, _ in segments)
    colorwheel = np.zeros((total, 3))
    row = 0
    for n, hold_ch, ramp_ch, ascending in segments:
        ramp = np.floor(255 * np.arange(n) / n)
        colorwheel[row:row+n, hold_ch] = 255
        colorwheel[row:row+n, ramp_ch] = ramp if ascending else 255 - ramp
        row += n
    return colorwheel
|
||||
|
||||
|
||||
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    Follows the C++ source code of Daniel Scharstein and the Matlab source
    code of Deqing Sun.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    out = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape [55x3]
    ncols = wheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    # flow direction -> fractional position on the color wheel
    angle = np.arctan2(-v, -u) / np.pi
    fk = (angle + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    frac = fk - k0

    in_range = (magnitude <= 1)
    for ch in range(wheel.shape[1]):
        channel = wheel[:, ch]
        lo = channel[k0] / 255.0
        hi = channel[k1] / 255.0
        col = (1 - frac) * lo + frac * hi
        # saturate with magnitude inside the unit circle, dim outside it
        col[in_range] = 1 - magnitude[in_range] * (1 - col[in_range])
        col[~in_range] = col[~in_range] * 0.75  # out of range
        # Note the 2-ch => BGR instead of RGB
        target = 2 - ch if convert_to_bgr else ch
        out[:, :, target] = np.floor(255 * col)
    return out
|
||||
|
||||
|
||||
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Convert a two-channel flow field into a color visualization.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'

    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)

    u, v = flow_uv[:, :, 0], flow_uv[:, :, 1]
    # normalize by the largest magnitude so colors saturate at the max flow
    largest = np.max(np.sqrt(np.square(u) + np.square(v)))
    eps = 1e-5
    return flow_uv_to_colors(u / (largest + eps), v / (largest + eps), convert_to_bgr)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from os.path import *
|
||||
import re
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
TAG_CHAR = np.array([202021.25], np.float32)
|
||||
|
||||
def readFlow(fn):
    """ Read .flo file in Middlebury format.

    Args:
        fn: path to the .flo file.

    Returns:
        (H, W, 2) float32 flow array, or None if the magic number is wrong
        or the file is truncated.
    """
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        # BUGFIX: comparing the scalar against the array elementwise is
        # deprecated and crashes confusingly on empty/truncated files;
        # check the element count explicitly and compare a scalar.
        if magic.size != 1 or magic[0] != 202021.25:
            print('Magic number incorrect. Invalid .flo file')
            return None
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)
        # Reshape into (rows, cols, bands); original C layout is (w, h, 2)
        return np.resize(data, (h, w, 2))
|
||||
|
||||
def readPFM(file):
    """Read a PFM image file.

    Args:
        file: path to the .pfm file.

    Returns:
        (H, W, 3) array for color ('PF') files or (H, W) for grayscale
        ('Pf') files, flipped to top-to-bottom row order.

    Raises:
        Exception: if the header or dimension line is malformed.
    """
    # BUGFIX: the original opened the file without ever closing it,
    # leaking the handle; a context manager guarantees cleanup.
    with open(file, 'rb') as f:
        header = f.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # the sign of the scale line encodes the byte order
        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # PFM stores rows bottom-to-top; flip to conventional order
    data = np.flipud(data)
    return data
|
||||
|
||||
def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.

    Args:
        filename: destination .flo path.
        uv: (H, W, 2) flow array, or the (H, W) u channel when v is given.
        v: optional (H, W) vertical flow component.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape
    # BUGFIX: use a context manager so the handle is closed even if a
    # write fails partway through (original used bare open()/close()).
    with open(filename,'wb') as f:
        # write the header
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        # interleave u and v columns into a (height, width*2) matrix
        tmp = np.zeros((height, width*nBands))
        tmp[:,np.arange(width)*2] = u
        tmp[:,np.arange(width)*2 + 1] = v
        tmp.astype(np.float32).tofile(f)
|
||||
|
||||
|
||||
def readFlowKITTI(filename):
    """Read a KITTI 16-bit PNG flow file, returning (flow, valid)."""
    raw = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    # BGR -> RGB channel order before splitting off the validity channel
    raw = raw[:, :, ::-1].astype(np.float32)
    flow = raw[:, :, :2]
    valid = raw[:, :, 2]
    # undo the KITTI encoding: value = 64 * flow + 2^15
    flow = (flow - 2**15) / 64.0
    return flow, valid
|
||||
|
||||
def readDispKITTI(filename):
    """Read a KITTI disparity PNG and convert it to a flow-style array."""
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    # disparity is horizontal-only motion (negative x), zero vertical
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid
|
||||
|
||||
|
||||
def writeFlowKITTI(filename, uv):
    """Write a flow field as a KITTI-encoded 16-bit PNG (all pixels valid)."""
    # KITTI encoding: value = 64 * flow + 2^15
    encoded = 64.0 * uv + 2**15
    valid = np.ones([encoded.shape[0], encoded.shape[1], 1])
    packed = np.concatenate([encoded, valid], axis=-1).astype(np.uint16)
    # cv2 expects BGR ordering, hence the channel reversal
    cv2.imwrite(filename, packed[..., ::-1])
|
||||
|
||||
|
||||
def read_gen(file_name, pil=False):
    """Load an image or flow file, dispatching on the file extension.

    Returns a PIL image, a numpy array, or an empty list for unknown extensions.
    """
    ext = splitext(file_name)[-1]
    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)
    if ext in ('.bin', '.raw'):
        return np.load(file_name)
    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    if ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        # grayscale PFM is already a plain 2D array
        if len(flow.shape) == 2:
            return flow
        # color PFM: drop the last channel to get a 2-channel flow
        return flow[:, :, :-1]
    return []
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # amount needed to round each dimension up to a multiple of 8
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8
        if mode == 'sintel':
            # split the padding evenly between both sides
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            # KITTI-style: pad the full height on top only
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every given tensor to the divisible-by-8 size."""
        return [F.pad(tensor, self._pad, mode='replicate') for tensor in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]
|
||||
|
||||
def forward_interpolate(flow):
    """Forward-propagate a flow field by splatting each vector to its target
    location and filling the grid with nearest-neighbor interpolation.

    Args:
        flow: (2, H, W) tensor with x-flow in channel 0 and y-flow in channel 1.

    Returns:
        (2, H, W) float tensor of the forward-interpolated flow.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # positions each pixel maps to after applying the flow
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only vectors landing strictly inside the frame
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    # NOTE: `mode` is accepted for API compatibility but grid_sample is
    # invoked with its default mode here, matching the original behavior.
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates to the [-1, 1] range grid_sample expects
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if not mask:
        return sampled

    # valid wherever the normalized coordinates fall strictly inside the frame
    inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, inside.float()
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd, device):
    """Return a pixel-coordinate grid of shape (batch, 2, ht, wd), with the
    x coordinate in channel 0 and the y coordinate in channel 1."""
    ys, xs = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
    grid = torch.stack((xs, ys), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially, scaling the vectors to match."""
    target = (8 * flow.shape[2], 8 * flow.shape[3])
    upsampled = F.interpolate(flow, size=target, mode=mode, align_corners=True)
    # flow magnitudes grow with the resolution
    return 8 * upsampled
|
||||
|
After Width: | Height: | Size: 902 KiB |
|
After Width: | Height: | Size: 89 KiB |
|
After Width: | Height: | Size: 451 KiB |
|
After Width: | Height: | Size: 1.6 MiB |
|
After Width: | Height: | Size: 3.0 MiB |
|
After Width: | Height: | Size: 3.2 MiB |
|
After Width: | Height: | Size: 3.2 MiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 1.1 MiB |
|
After Width: | Height: | Size: 1.5 MiB |
|
After Width: | Height: | Size: 865 KiB |
|
|
@ -0,0 +1,20 @@
|
|||
import launch
import os
import pkg_resources

# Install or upgrade every package listed in requirements.txt through the
# webui's `launch` helper so they end up in the webui's own environment.
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")

with open(req_file) as file:
    for package in file:
        package = package.strip()
        # BUGFIX: skip blank lines and comments instead of handing them
        # to launch.is_installed / pip, which fails on empty names.
        if not package or package.startswith('#'):
            continue
        try:
            if '==' in package:
                # pinned requirement: reinstall only on version mismatch
                package_name, package_version = package.split('==', 1)
                installed_version = pkg_resources.get_distribution(package_name).version
                if installed_version != package_version:
                    launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: changing {package_name} version from {installed_version} to {package_version}")
            elif not launch.is_installed(package):
                launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: {package}")
        except Exception as e:
            # best-effort install: report and continue with the next package
            print(e)
            print(f'Warning: Failed to install {package}.')
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
from flow_utils import RAFT_estimate_flow
|
||||
import h5py
|
||||
|
||||
import argparse
|
||||
|
||||
def main(args):
    """Estimate forward/backward RAFT optical flow for every consecutive
    frame pair of the input video and stream the maps into an HDF5 file.

    Args:
        args: parsed CLI namespace with input_video, output_file, width,
              height, visualize and remove_background attributes.
    """
    W, H = args.width, args.height
    # Open the input video file
    input_video = cv2.VideoCapture(args.input_video)

    # Get useful info from the source video
    # NOTE(review): fps is currently unused; kept for reference.
    fps = int(input_video.get(cv2.CAP_PROP_FPS))
    total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))

    prev_frame = None

    # create an empty HDF5 file
    with h5py.File(args.output_file, 'w') as f: pass

    # open the file for writing a flow maps into it
    with h5py.File(args.output_file, 'a') as f:
        # dataset layout: (pair_index, direction[next/prev], H, W, uv);
        # grown one row at a time as frames are processed
        flow_maps = f.create_dataset('flow_maps', shape=(0, 2, H, W, 2), maxshape=(None, 2, H, W, 2), dtype=np.float16)

        for ind in tqdm(range(total_frames)):
            # Read the next frame from the input video
            if not input_video.isOpened(): break
            ret, cur_frame = input_video.read()
            if not ret: break

            cur_frame = cv2.resize(cur_frame, (W, H))

            # flow needs two frames, so the first iteration only stores prev_frame
            if prev_frame is not None:
                next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, cur_frame, subtract_background=args.remove_background)

                # write data into a file
                flow_maps.resize(ind, axis=0)
                flow_maps[ind-1, 0] = next_flow
                flow_maps[ind-1, 1] = prev_flow

                # scale the occlusion mask into a displayable uint8 image
                occlusion_mask = np.clip(occlusion_mask * 0.2 * 255, 0, 255).astype(np.uint8)

                if args.visualize:
                    # show the last written frame - useful to catch any issue with the process
                    if args.remove_background:
                        img_show = cv2.hconcat([cur_frame, frame2_bg_removed, occlusion_mask])
                    else:
                        img_show = cv2.hconcat([cur_frame, occlusion_mask])
                    cv2.imshow('Out img', img_show)
                    if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing

            prev_frame = cur_frame.copy()

    # Release the input and output video files
    input_video.release()

    # Close all windows
    if args.visualize: cv2.destroyAllWindows()
|
||||
|
||||
if __name__ == '__main__':
    # CLI entry point: parse arguments and run the flow extraction.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
    parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
    parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
    parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
    parser.add_argument('-v', '--visualize', action='store_true', help='Show proceed images and occlusion maps')
    parser.add_argument('-rb', '--remove_background', action='store_true', help='Remove background of the image')
    args = parser.parse_args()

    main(args)
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
|
||||
# RAFT dependencies
|
||||
import sys
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder
|
||||
|
||||
RAFT_model = None
|
||||
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
|
||||
|
||||
def background_subtractor(frame, fgbg):
    """Zero out background pixels of `frame` using the given OpenCV
    background-subtractor instance."""
    foreground_mask = fgbg.apply(frame)
    return cv2.bitwise_and(frame, frame, mask=foreground_mask)
|
||||
|
||||
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
    """Estimate forward and backward optical flow between two frames with RAFT.

    Lazily loads the RAFT model into the module-level cache on first call.

    Returns:
        (next_flow, prev_flow, occlusion_mask, frame1, frame2): flows are
        HxWx2 numpy arrays, occlusion_mask is an HxWx3 float array derived
        from forward-backward consistency, and the frames are returned as
        processed (background-subtracted when requested).
    """
    global RAFT_model
    # lazily build and cache the model on first use
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': 'RAFT/models/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        # checkpoint was saved from a DataParallel wrapper, so load through one
        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    if subtract_background:
        frame1 = background_subtractor(frame1, fgbg)
        frame2 = background_subtractor(frame2, fgbg)

    with torch.no_grad():
        # to NCHW float tensors on the target device
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        # pad to dimensions divisible by 8 as RAFT requires
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # forward-backward consistency: large norm indicates likely occlusion
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

    return next_flow, prev_flow, occlusion_mask, frame1, frame2
|
||||
|
||||
# ... rest of the file ...
|
||||
|
||||
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
    """Warp the previous (raw and styled) frames into the current frame using
    the backward flow, and build an alpha mask marking occluded or changed
    regions that need to be regenerated.

    Args:
        next_flow, prev_flow: forward/backward flow fields, HxWx2.
        prev_frame, cur_frame: consecutive uint8 frames, HxWx3.
        prev_frame_styled: stylized version of prev_frame, HxWx3.

    Returns:
        (alpha_mask, warped_frame_styled): alpha_mask is an HxWx3 float array
        in [0, 1]; warped_frame_styled is the styled frame warped into the
        current frame's geometry.
    """
    h, w = cur_frame.shape[:2]

    # NOTE(review): shape[:2] is (rows, cols), so these names look swapped;
    # the normalize/denormalize below depends on the flow channel order -- confirm.
    fl_w, fl_h = next_flow.shape[:2]

    # normalize flow so thresholding and resizing are resolution independent
    next_flow = next_flow / np.array([fl_h, fl_w])
    prev_flow = prev_flow / np.array([fl_h, fl_w])

    # remove low value noise (@alexfredo suggestion)
    next_flow[np.abs(next_flow) < 0.05] = 0
    prev_flow[np.abs(prev_flow) < 0.05] = 0

    # resize flow to the frame resolution and denormalize
    next_flow = cv2.resize(next_flow, (w, h))
    next_flow = (next_flow * np.array([h, w])).astype(np.float32)
    prev_flow = cv2.resize(prev_flow, (w, h))
    prev_flow = (prev_flow * np.array([h, w])).astype(np.float32)

    # Generate sampling grids, normalized to [-1, 1] for grid_sample
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
    flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
    flow_grid = flow_grid.unsqueeze(0)
    flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
    flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
    flow_grid = flow_grid.permute(0, 2, 3, 1)

    prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W
    prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W

    warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()
    warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()

    # compute occlusion mask from forward-backward flow consistency
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)

    occlusion_mask = fb_norm[..., None]

    # per-pixel difference between the warped previous frames and the current one
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

    # BUGFIX: np.maximum(a, b, c) treats the third positional argument as
    # `out=`, silently dropping diff_mask_stl from the maximum; reduce over
    # all three terms instead.
    alpha_mask = np.maximum.reduce([occlusion_mask * 0.3, diff_mask_org * 4, diff_mask_stl * 2])
    alpha_mask = alpha_mask.repeat(3, axis = -1)

    # feather the mask to avoid hard seams between kept and regenerated regions
    alpha_mask = cv2.GaussianBlur(alpha_mask, (51,51), 5, cv2.BORDER_REFLECT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
|
||||
|
||||
def frames_norm(occl):
    """Scale raw uint8 frame values from [0, 255] into the model's [-1, 1] range."""
    return occl / 127.5 - 1
|
||||
|
||||
def flow_norm(flow):
    """Scale raw flow values down by 255 into the model's input range."""
    return flow / 255
|
||||
|
||||
def occl_norm(occl):
    """Scale raw uint8 occlusion values from [0, 255] into the model's [-1, 1] range."""
    return occl / 127.5 - 1
|
||||
|
||||
def flow_renorm(flow):
    """Inverse of flow_norm: map model-space flow back to pixel units."""
    return flow * 255
|
||||
|
||||
def occl_renorm(occl):
    """Inverse of occl_norm: map model-space [-1, 1] occlusion back to [0, 255]."""
    return (occl + 1) * 127.5
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
# SD-CN-Animation
|
||||
This project allows you to automate video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text at any resolution and length in contrast to other current text2video methods using any Stable Diffusion model as a backbone, including custom ones. It uses '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and create an inpainting mask that is used to generate the next frame. In text to video mode it relies on 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
|
||||
|
||||
|
||||
### Video to Video Examples:
|
||||
<!--
|
||||
[](https://youtu.be/j-0niEMm6DU)
|
||||
This script can also be used to swap the person in the video, like in this example: https://youtube.com/shorts/be93_dIeZWU
|
||||
-->
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/girl_org.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_jc.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_wc.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">Original video</td>
|
||||
<td width=33% align="center">"Jessica Chastain"</td>
|
||||
<td width=33% align="center">"Watercolor painting"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropped, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
|
||||
|
||||
### Text to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/flower_1.gif" raw=true></td>
|
||||
<td><img src="examples/bonfire_1.gif" raw=true></td>
|
||||
<td><img src="examples/diamond_4.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of a flower"</td>
|
||||
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
|
||||
<td width=33% align="center">"close up of a diamond laying on the table"</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="examples/macaroni_1.gif" raw=true></td>
|
||||
<td><img src="examples/gold_1.gif" raw=true></td>
|
||||
<td><img src="examples/tree_2.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of macaroni on the plate"</td>
|
||||
<td width=33% align="center">"close up of golden sphere"</td>
|
||||
<td width=33% align="center">"a tree standing in the winter forest"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
All examples you can see here are originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. Actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", only the 'subject' part is described in the table above.
|
||||
|
||||
|
||||
|
||||
## Dependencies
|
||||
To install all the necessary dependencies, run this command:
|
||||
```
|
||||
pip install opencv-python opencv-contrib-python numpy tqdm h5py scikit-image
|
||||
```
|
||||
You have to set up the RAFT repository as described here: https://github.com/princeton-vl/RAFT . Basically it just comes down to running "./download_models.sh" in the RAFT folder to download the models.
|
||||
|
||||
|
||||
## Running the scripts
|
||||
This script works on top of [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via API. To run this script you have to set it up first. You should also have [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed. You need to have the control_hed-fp16 model installed. If you have web-ui with ControlNet working correctly, you have to also allow the API to work with ControlNet. To do so, go to the web-ui settings -> ControlNet tab -> Set "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more than 2 -> press "Apply settings".
|
||||
|
||||
|
||||
### Video To Video
|
||||
#### Step 1.
|
||||
To process the video, first of all you would need to precompute optical flow data before running web-ui with this command:
|
||||
```
|
||||
python3 compute_flow.py -i "path to your video" -o "path to output file with *.h5 format" -v -W width_of_the_flow_map -H height_of_the_flow_map
|
||||
```
|
||||
The main reason to do this step separately is to save precious GPU memory that will be useful to generate better quality images. Choose W and H parameters as high as your GPU can handle with respect to the proportion of original video resolution. Do not worry if it is higher or lower than the processing resolution; flow maps will be scaled accordingly at the processing stage. This will generate quite a large file that may take up to several gigabytes on the drive even for a minute-long video. If you want to process a long video, consider splitting it into several parts beforehand.
|
||||
|
||||
|
||||
#### Step 2.
|
||||
Run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```
|
||||
bash webui.sh --xformers --api
|
||||
```
|
||||
|
||||
|
||||
#### Step 3.
|
||||
Go to the **vid2vid.py** file and change main parameters (INPUT_VIDEO, FLOW_MAPS, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. FLOW_MAPS parameter should contain a path to the flow file that you generated at the first step. The script is pretty simple so you may change other parameters as well, although I would recommend to leave them as is for the first time. Finally run the script with the command:
|
||||
```
|
||||
python3 vid2vid.py
|
||||
```
|
||||
|
||||
|
||||
### Text To Video
|
||||
This method is still in development and works on top of 'Stable Diffusion' and 'FloweR' - an optical flow reconstruction method that is also in an early development stage. Do not expect much from it as it is more of a proof of concept rather than a complete solution.
|
||||
|
||||
#### Step 1.
|
||||
Download 'FloweR_0.1.pth' model from here: [Google drive link](https://drive.google.com/file/d/1WhzoVIw6Kdg4EjfK9LaTLqFm5dF-IJ7F/view?usp=share_link) and place it in the 'FloweR' folder.
|
||||
|
||||
#### Step 2.
|
||||
Same as with vid2vid case, run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```
|
||||
bash webui.sh --xformers --api
|
||||
```
|
||||
|
||||
#### Step 3.
|
||||
Go to the **txt2vid.py** file and change main parameters (OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. Again, the script is simple so you may change other parameters if you want to. Finally run the script with the command:
|
||||
```
|
||||
python3 txt2vid.py
|
||||
```
|
||||
|
||||
## Last version changes: v0.5
|
||||
* Fixed an issue with the wrong direction of an optical flow applied to an image.
|
||||
* Added text to video mode within txt2vid.py script. Make sure to update new dependencies for this script to work!
|
||||
* Added a threshold for an optical flow before processing the frame to remove white noise that might appear, as it was suggested by [@alexfredo](https://github.com/alexfredo).
|
||||
* Background removal at flow computation stage implemented by [@CaptnSeraph](https://github.com/CaptnSeraph), it should reduce ghosting effect in most of the videos processed with vid2vid script.
|
||||
|
||||
<!--
|
||||
## Last version changes: v0.6
|
||||
* Added separate flag '-rb' for background removal process at the flow computation stage in the compute_flow.py script.
|
||||
* Added flow normalization before rescaling it, so the magnitude of the flow computed correctly at the different resolution.
|
||||
* Less ghosting and color change in vid2vid mode
|
||||
-->
|
||||
|
||||
<!--
|
||||
## Potential improvements
|
||||
There are several ways overall quality of animation may be improved:
|
||||
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
|
||||
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
|
||||
* The quality of flow estimation might be greatly improved with a proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
|
||||
-->
|
||||
## Licence
|
||||
This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact me directly at borsky.alexey@gmail.com
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
import requests
|
||||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
import sys
|
||||
sys.path.append('FloweR/')
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
import torch
|
||||
from model import FloweR
|
||||
from utils import flow_viz
|
||||
|
||||
from flow_utils import *
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
|
||||
# Output path stamped with the current date/time.
# NOTE(review): the "%H:%M:%S" pattern puts ':' characters into the file name,
# which is invalid on Windows — presumably the script targets POSIX; confirm.
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'

# Positive prompt sent with every generation request.
PROMPT = "people looking at flying robots. Future. People looking to the sky. Stars in the background. Dramatic light, Cinematic light. Soft lighting, high quality, film grain."
# Negative prompt sent with every generation request.
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
w,h = 768, 512 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.

SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations

# Denoising strength of the main generation pass over the warped frame.
PROCESSING_STRENGTH = 0.85
# Denoising strength of the second, blur-fixing pass.
FIX_STRENGTH = 0.35

# Classifier-free guidance scale for all requests.
CFG_SCALE = 5.5

# Attach a TemporalNet ControlNet unit to each request when True.
APPLY_TEMPORALNET = False
# Attach a color T2I-adapter ControlNet unit to each request when True.
APPLY_COLOR = False

VISUALIZE = True  # show a live preview window while processing
DEVICE = 'cuda'   # torch device used for the FloweR model
|
||||
|
||||
def to_b64(img):
    """Clip an image array to valid uint8 range, PNG-encode it, and return it
    as a base64 string suitable for the web-ui JSON API."""
    safe_img = np.clip(img, 0, 255).astype(np.uint8)
    ok, png_buffer = cv2.imencode('.png', safe_img)
    return base64.b64encode(png_buffer).decode("utf-8")
|
||||
|
||||
class controlnetRequest():
    # Thin wrapper that assembles a JSON payload for the Automatic1111 web-ui
    # /sdapi/v1/{txt2img|img2img} endpoint and posts it, optionally attaching
    # TemporalNet and color-adapter ControlNet units.

    def __init__(self, b64_init_img = None, b64_prev_img = None, b64_color_img = None, ds = 0.35, w=w, h=h, mask = None, seed=-1, mode='img2img'):
        # b64_init_img:  base64 PNG used as the init image.
        # b64_prev_img:  base64 PNG of the previous frame, fed to the ControlNet units.
        # b64_color_img: accepted but never used in this class —
        #                NOTE(review): the color unit below conditions on
        #                b64_prev_img instead; confirm whether that is intended.
        # ds:            denoising strength for this request.
        # w, h:          output resolution; defaults are captured from the
        #                module-level globals at class-definition time.
        # mask:          optional base64 inpainting mask (None disables inpainting).
        # mode:          API endpoint to hit, 'img2img' or 'txt2img'.
        self.url = f"http://localhost:7860/sdapi/v1/{mode}"
        self.body = {
            "init_images": [b64_init_img],
            "mask": mask,
            "mask_blur": 0,
            "inpainting_fill": 1,
            "inpainting_mask_invert": 0,
            "prompt": PROMPT,
            "negative_prompt": N_PROMPT,
            "seed": seed,
            "subseed": -1,
            "subseed_strength": 0,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 15,
            "cfg_scale": CFG_SCALE,
            "denoising_strength": ds,
            "width": w,
            "height": h,
            "restore_faces": False,
            "eta": 0,
            "sampler_index": "DPM++ 2S a",
            "control_net_enabled": True,
            "alwayson_scripts": {
                "ControlNet":{"args": []}
            },
        }

        if APPLY_TEMPORALNET:
            # TemporalNet unit conditioned on the previous generated frame.
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_prev_img,
                "module": "none",
                "model": "diff_control_sd15_temporalnet_fp16 [adc6bd97]",
                "weight": 0.65,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.65,
                "guessmode": False
            })

        if APPLY_COLOR:
            # Color T2I-adapter unit; also conditioned on the previous frame
            # (not on b64_color_img — see NOTE(review) above).
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_prev_img,
                "module": "color",
                "model": "t2iadapter_color_sd14v1 [8522029d]",
                "weight": 0.65,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.65,
                "guessmode": False
            })


    def sendRequest(self):
        # POST the payload to the web-ui API and parse the JSON response.
        data_js = requests.post(self.url, json=self.body).json()

        # Decode the first returned image from base64 into raw PNG bytes,
        # then into a flat NumPy byte array.
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)

        # Decode the PNG bytes into a cv2 (BGR) image.
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        return out_image
|
||||
|
||||
|
||||
|
||||
if VISUALIZE: cv2.namedWindow('Out img')

# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), 15, (w, h))

prev_frame = None
prev_frame_styled = None

# Instantiate the flow-prediction model and load its weights.
model = FloweR(input_size = (h, w))
model.load_state_dict(torch.load('FloweR/FloweR_0.1.1.pth'))
# Move the model to the device
model = model.to(DEVICE)

# Generate the very first frame from text alone (txt2img mode).
init_frame = controlnetRequest(mode='txt2img', ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()

output_video.write(init_frame)
prev_frame = init_frame

# Rolling window of the last 4 generated frames, fed to FloweR each step.
clip_frames = np.zeros((4, h, w, 3), dtype=np.uint8)

# NOTE(review): color_shift / color_scale are initialized but never used below.
color_shift = np.zeros((0, 3))
color_scale = np.zeros((0, 3))
for ind in tqdm(range(450)):
    # Shift the window left and append the newest frame at the end.
    clip_frames = np.roll(clip_frames, -1, axis=0)
    clip_frames[-1] = prev_frame

    clip_frames_torch = frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))

    # Predict flow + occlusion for the next frame from the 4-frame window.
    with torch.no_grad():
        pred_data = model(clip_frames_torch.unsqueeze(0))[0]

    # First two channels are the flow field, third is the occlusion map.
    pred_flow = flow_renorm(pred_data[...,:2]).cpu().numpy()
    pred_occl = occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)

    # Dampen large flow magnitudes, then smooth the field.
    pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
    pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)

    # Smooth and sharpen the occlusion mask, then boost it to full range.
    pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
    pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
    pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)

    # Convert the relative flow into an absolute sampling map for cv2.remap.
    flow_map = pred_flow.copy()
    flow_map[:,:,0] += np.arange(w)
    flow_map[:,:,1] += np.arange(h)[:,np.newaxis]

    warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)

    out_image = warped_frame.copy()

    # Main generation pass: inpaint the occluded regions of the warped frame.
    out_image = controlnetRequest(
        b64_init_img = to_b64(out_image),
        b64_prev_img = to_b64(prev_frame),
        b64_color_img = to_b64(warped_frame),
        mask = to_b64(pred_occl),
        ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()

    # Second, low-strength pass over the whole frame to fix blur artifacts.
    out_image = controlnetRequest(
        b64_init_img = to_b64(out_image),
        b64_prev_img = to_b64(prev_frame),
        b64_color_img = to_b64(warped_frame),
        mask = None,
        ds=FIX_STRENGTH, w=w, h=h).sendRequest()

    # This step is necessary to reduce color drift of the image that some models may cause.
    # NOTE(review): skimage deprecated `multichannel` in favor of `channel_axis`;
    # passing both may fail on newer versions — confirm the pinned skimage release.
    out_image = skimage.exposure.match_histograms(out_image, init_frame, multichannel=True, channel_axis=-1)

    output_video.write(out_image)
    if SAVE_FRAMES:
        if not os.path.isdir('out'): os.makedirs('out')
        cv2.imwrite(f'out/{ind+1:05d}.png', out_image)

    # Build the debug mosaic: input window on top, flow/occlusion/warped/result below.
    # NOTE(review): this preview block is not guarded by VISUALIZE — it runs
    # (and requires a GUI) even when VISUALIZE is False; confirm intent.
    pred_flow_img = flow_viz.flow_to_image(pred_flow)
    frames_img = cv2.hconcat(list(clip_frames))
    data_img = cv2.hconcat([pred_flow_img, pred_occl, warped_frame, out_image])

    cv2.imshow('Out img', cv2.vconcat([frames_img, data_img]))
    if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing

    prev_frame = out_image.copy()

# Release the input and output video files
output_video.release()

# Close all windows
if VISUALIZE: cv2.destroyAllWindows()
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
import requests
|
||||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
import h5py
|
||||
from flow_utils import compute_diff_map
|
||||
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
# Source video and its precomputed optical-flow file (made by compute_flow.py).
INPUT_VIDEO = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.mp4"
FLOW_MAPS = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.h5"
# Output path stamped with the current date/time.
# NOTE(review): "%H:%M:%S" puts ':' in the file name — invalid on Windows.
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'

# Positive prompt sent with every generation request.
PROMPT = "Underwater shot Peter Gabriel with closed eyes in Peter Gabriel's music video. 80's music video. VHS style. Dramatic light, Cinematic light. RAW photo, 8k uhd, dslr, soft lighting, high quality, film grain."
# Negative prompt sent with every generation request.
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
w,h = 1088, 576 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.

START_FROM_IND = 0 # index of a frame to start a processing from. Might be helpful with long animations where you need to restart the script multiple times
SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations

# Denoising strength of the main generation pass.
PROCESSING_STRENGTH = 0.95
# Denoising strength of the second, blur-fixing pass.
BLUR_FIX_STRENGTH = 0.15

# Which ControlNet units to attach to each request.
APPLY_HED = True
APPLY_CANNY = False
APPLY_DEPTH = False
GUESSMODE = False

# Classifier-free guidance scale for all requests.
CFG_SCALE = 5.5

VISUALIZE = True  # show a live preview window while processing
|
||||
|
||||
def to_b64(img):
    """Return *img* as a base64-encoded PNG string for the web-ui JSON API.

    Values are clipped to [0, 255] and cast to uint8 before encoding.
    """
    clipped = np.clip(img, 0, 255).astype(np.uint8)
    success, encoded = cv2.imencode('.png', clipped)
    return base64.b64encode(encoded).decode("utf-8")
|
||||
|
||||
class controlnetRequest():
    # Thin wrapper that assembles a JSON payload for the Automatic1111 web-ui
    # /sdapi/v1/img2img endpoint and posts it, optionally attaching
    # HED / canny / depth ControlNet units (selected by module-level flags).

    def __init__(self, b64_cur_img, b64_hed_img, ds = 0.35, w=w, h=h, mask = None, seed=-1):
        # b64_cur_img: base64 PNG used as the img2img init image.
        # b64_hed_img: base64 PNG fed to every enabled ControlNet unit.
        # ds:          denoising strength for this request.
        # w, h:        output resolution; defaults are captured from the
        #              module-level globals at class-definition time.
        # mask:        optional base64 inpainting mask (None disables inpainting).
        self.url = "http://localhost:7860/sdapi/v1/img2img"
        self.body = {
            "init_images": [b64_cur_img],
            "mask": mask,
            "mask_blur": 0,
            "inpainting_fill": 1,
            "inpainting_mask_invert": 0,
            "prompt": PROMPT,
            "negative_prompt": N_PROMPT,
            "seed": seed,
            "subseed": -1,
            "subseed_strength": 0,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 15,
            "cfg_scale": CFG_SCALE,
            "denoising_strength": ds,
            "width": w,
            "height": h,
            "restore_faces": False,
            "eta": 0,
            "sampler_index": "DPM++ 2S a",
            "control_net_enabled": True,
            "alwayson_scripts": {
                "ControlNet":{"args": []}
            },
        }

        if APPLY_HED:
            # Soft-edge (HED) unit.
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_hed_img,
                "module": "hed",
                "model": "control_hed-fp16 [13fee50b]",
                "weight": 0.65,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.65,
                "guessmode": GUESSMODE
            })

        if APPLY_CANNY:
            # Canny edge unit; threshold_a/b are the Canny hysteresis thresholds.
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_hed_img,
                "module": "canny",
                "model": "control_canny-fp16 [e3fe7712]",
                "weight": 0.85,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "threshold_a": 35,
                "threshold_b": 35,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.85,
                "guessmode": GUESSMODE
            })

        if APPLY_DEPTH:
            # Depth-map unit.
            self.body["alwayson_scripts"]["ControlNet"]["args"].append({
                "input_image": b64_hed_img,
                "module": "depth",
                "model": "control_depth-fp16 [400750f6]",
                "weight": 0.85,
                "resize_mode": "Just Resize",
                "lowvram": False,
                "processor_res": 512,
                "guidance_start": 0,
                "guidance_end": 0.85,
                "guessmode": GUESSMODE
            })


    def sendRequest(self):
        # POST the payload to the web-ui API and parse the JSON response.
        data_js = requests.post(self.url, json=self.body).json()

        # Decode the first returned image from base64 into raw PNG bytes,
        # then into a flat NumPy byte array.
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)

        # Decode the PNG bytes into a cv2 (BGR) image.
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        return out_image
|
||||
|
||||
|
||||
|
||||
if VISUALIZE: cv2.namedWindow('Out img')

# Open the input video file
input_video = cv2.VideoCapture(INPUT_VIDEO)

# Get useful info from the source video
fps = int(input_video.get(cv2.CAP_PROP_FPS))
total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))

# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

prev_frame = None
prev_frame_styled = None
#init_image = None

# reading flow maps in a stream manner
with h5py.File(FLOW_MAPS, 'r') as f:
    flow_maps = f['flow_maps']

    for ind in tqdm(range(total_frames)):
        # Read the next frame from the input video
        if not input_video.isOpened(): break
        ret, cur_frame = input_video.read()
        if not ret: break

        # Skip already-processed frames when resuming a run.
        if ind+1 < START_FROM_IND: continue

        # A frame is a "keyframe" if it differs strongly from the previous one
        # (scene cut); keyframes are re-generated from scratch without warping.
        is_keyframe = True
        if prev_frame is not None:
            # Compute absolute difference between current and previous frame
            frames_diff = cv2.absdiff(cur_frame, prev_frame)
            # Compute mean of absolute difference
            mean_diff = cv2.mean(frames_diff)[0]
            # Check if mean difference is above threshold
            is_keyframe = mean_diff > 30

        # Generate coarse version of the current frame, using the previous
        # stylized frame as a reference image.
        if is_keyframe:
            # Resize the frame to proper resolution
            frame = cv2.resize(cur_frame, (w, h))

            # Processing current frame with current frame as a mask without any inpainting
            out_image = controlnetRequest(to_b64(frame), to_b64(frame), PROCESSING_STRENGTH, w, h, mask = None).sendRequest()

            alpha_img = out_image.copy()
            out_image_ = out_image.copy()
            warped_styled = out_image.copy()
            #init_image = out_image.copy()
        else:
            # Resize the frame to proper resolution
            frame = cv2.resize(cur_frame, (w, h))
            prev_frame = cv2.resize(prev_frame, (w, h))

            # Processing current frame with current frame as a mask without any inpainting
            out_image = controlnetRequest(to_b64(frame), to_b64(frame), PROCESSING_STRENGTH, w, h, mask = None).sendRequest()

            # Flow entry ind-1 maps between the previous and current frames.
            next_flow, prev_flow = flow_maps[ind-1].astype(np.float32)
            alpha_mask, warped_styled = compute_diff_map(next_flow, prev_flow, prev_frame, frame, prev_frame_styled)

            # This clipping at the lower side is required to fix small trailing issues
            # that for some reason are left outside of the bright part of the mask;
            # at the higher side it limits strongly-changed parts to reduce flickering.
            alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
            alpha_img = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)

            # normalizing the colors
            # NOTE(review): passing both multichannel=False and channel_axis=-1 looks
            # contradictory, and `multichannel` is deprecated/removed in newer
            # skimage — confirm against the pinned scikit-image version.
            out_image = skimage.exposure.match_histograms(out_image, frame, multichannel=False, channel_axis=-1)

            # Blend the fresh generation with the warped previous stylized frame.
            out_image = out_image.astype(float) * alpha_mask + warped_styled.astype(float) * (1 - alpha_mask)

            #out_image = skimage.exposure.match_histograms(out_image, prev_frame, multichannel=True, channel_axis=-1)
            #out_image_ = (out_image * 0.65 + warped_styled * 0.35)


        # Blurring issue fix via additional low-strength processing pass
        out_image_fixed = controlnetRequest(to_b64(out_image), to_b64(frame), BLUR_FIX_STRENGTH, w, h, mask = None, seed=8888).sendRequest()


        # Write the frame to the output video
        frame_out = np.clip(out_image_fixed, 0, 255).astype(np.uint8)
        output_video.write(frame_out)

        if VISUALIZE:
            # show the last written frame - useful to catch any issue with the process
            warped_styled = np.clip(warped_styled, 0, 255).astype(np.uint8)

            img_show_top = cv2.hconcat([frame, warped_styled])
            img_show_bot = cv2.hconcat([frame_out, alpha_img])
            cv2.imshow('Out img', cv2.vconcat([img_show_top, img_show_bot]))
            cv2.setWindowTitle("Out img", str(ind+1))
            if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing

        if SAVE_FRAMES:
            if not os.path.isdir('out'): os.makedirs('out')
            cv2.imwrite(f'out/{ind+1:05d}.png', frame_out)

        # Keep full-resolution source frame and the pre-fix stylized frame
        # for the next iteration's warping and keyframe check.
        prev_frame = cur_frame.copy()
        prev_frame_styled = out_image.copy()


# Release the input and output video files
input_video.release()
output_video.release()

# Close all windows
if VISUALIZE: cv2.destroyAllWindows()
|
||||
112
readme.md
|
|
@ -1,34 +1,92 @@
|
|||
# SD-CN-Animation Script
|
||||
This script allows you to automate video stylization task using StableDiffusion and ControlNet. It uses a simple optical flow estimation algorithm to keep the animation stable and create an inpainting mask that is used to generate the next frame. Here is an example of a video made with this script:
|
||||
# SD-CN-Animation
|
||||
> [!Warning]
|
||||
> This repository is no longer maintained. If you are looking for more modern ComfyUI version of this tool please check out this repository, this might be what you are actually looking for: https://github.com/pxl-pshr/ComfyUI-SD-CN-Animation
|
||||
|
||||
[](https://youtu.be/j-0niEMm6DU)
|
||||
This project allows you to automate video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text at any resolution and length in contrast to other current text2video methods using any Stable Diffusion model as a backbone, including custom ones. It uses '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and create an occlusion mask that is used to generate the next frame. In text to video mode it relies on 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
|
||||
|
||||
This script can also be used to swap the person in the video, like in this example: https://youtube.com/shorts/be93_dIeZWU
|
||||

|
||||
sd-cn-animation ui preview
|
||||
|
||||
## Dependencies
|
||||
To install all necessary dependencies run this command
|
||||
**In vid2vid mode do not forget to activate ControlNet model to achieve better results. Without it the resulting video might be quite choppy. Do not put any images in CN as the frames would pass automatically from the video.**
|
||||
Here are CN parameters that seem to give the best results so far:
|
||||

|
||||
|
||||
|
||||
### Video to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/girl_org.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_jc.gif" raw=true></td>
|
||||
<td><img src="examples/girl_to_wc.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">Original video</td>
|
||||
<td width=33% align="center">"Jessica Chastain"</td>
|
||||
<td width=33% align="center">"Watercolor painting"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropped, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
|
||||
|
||||
### Text to Video Examples:
|
||||
</table>
|
||||
<table class="center">
|
||||
<tr>
|
||||
<td><img src="examples/flower_1.gif" raw=true></td>
|
||||
<td><img src="examples/bonfire_1.gif" raw=true></td>
|
||||
<td><img src="examples/diamond_4.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of a flower"</td>
|
||||
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
|
||||
<td width=33% align="center">"close up of a diamond laying on the table"</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="examples/macaroni_1.gif" raw=true></td>
|
||||
<td><img src="examples/gold_1.gif" raw=true></td>
|
||||
<td><img src="examples/tree_2.gif" raw=true></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width=33% align="center">"close up of macaroni on the plate"</td>
|
||||
<td width=33% align="center">"close up of golden sphere"</td>
|
||||
<td width=33% align="center">"a tree standing in the winter forest"</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
All examples you can see here are originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. Actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", only the 'subject' part is described in the table above.
|
||||
|
||||
## Installing the extension
|
||||
To install the extension go to 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to 'Install from URL' tab. In 'URL for extension's git repository' field inter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave 'Local directory name' field empty. Then just press 'Install' button. Restart web-ui, new 'SD-CN-Animation' tab should appear. All generated video will be saved into 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
|
||||
|
||||
## Known issues
|
||||
* If you see an error like this ```IndexError: list index out of range``` try to restart webui, it should fix it. If the issue is still prevalent, try to uninstall and reinstall scikit-image==0.19.2 with the --no-cache-dir flag, like this:
|
||||
```
|
||||
pip install opencv-python opencv-contrib-python numpy tqdm
|
||||
pip uninstall scikit-image
|
||||
pip install scikit-image==0.19.2 --no-cache-dir
|
||||
```
|
||||
* The extension might work incorrectly if 'Apply color correction to img2img results to match original colors.' option is enabled. Make sure to disable it in 'Settings' tab -> 'Stable Diffusion' section.
|
||||
* If you have an error like 'Need to enable queue to use generators.', please update webui to the latest version. Beware that only [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is fully supported.
|
||||
* The extension is not compatible with Macs. If the extension is working for you, or you know how to make it compatible, please open a new discussion.
|
||||
|
||||
To run the algorithm alongside Stable Diffusion with ControlNet in 640x640 resolution would require about 8GB of VRAM, as [RAFT](https://github.com/princeton-vl/RAFT) (current optical flow estimation method) takes about 3,7GB of memory.
|
||||
## Last version changes: v0.9
|
||||
* Fixed issues #69, #76, #91, #92.
|
||||
* Fixed an issue in vid2vid mode when an occlusion mask computed from the optical flow may include unnecessary parts (where flow is non-zero).
|
||||
* Added 'Extra params' in vid2vid mode for more fine-grain controls of the processing pipeline.
|
||||
* Better default parameters set for vid2vid pipeline.
|
||||
* In txt2vid mode after the first frame is generated the seed is now automatically set to -1 to prevent blurring issues.
|
||||
* Added an option to save resulting frames into a folder alongside the video.
|
||||
* Added ability to export current parameters in a human readable form as a json.
|
||||
* Interpolation mode in the flow-applying stage is set to ‘nearest’ to reduce overtime image blurring.
|
||||
* Added ControlNet to txt2vid mode as well as fixing #86 issue, thanks to [@mariaWitch](https://github.com/mariaWitch)
|
||||
* Fixed a major issue when ControlNet used wrong input images. Because of this vid2vid results were way worse than they should be.
|
||||
* Text to video mode now supports video as a guidance for ControlNet. It allows to create much stronger video stylizations.
|
||||
|
||||
## Running the script
|
||||
This script works on top of [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via API. To run this script you have to set it up first. You also should have [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed. You need to have control_hed-fp16 model installed. If you have web-ui with ControlNet working correctly do the following:
|
||||
1. Go to the web-ui settings -> ControlNet tab -> Set "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more than 2 -> press "Apply settings"
|
||||
2. Run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
|
||||
```bash webui.sh --xformers --api```
|
||||
3. Go to the script.py file and change main parameters (INPUT_VIDEO, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. The script is pretty simple so you may change other parameters as well, although I would recommend to leave them as is for the first time.
|
||||
4. Run the script with ```python3 script.py```
|
||||
|
||||
## Last version changes 0.3
|
||||
* Flow estimation algorithm is updated to [RAFT](https://github.com/princeton-vl/RAFT) method.
|
||||
* Difference map now computed as per-pixel maximum of warped first and second frame of the original video and occlusion map that is computed from forward and backward flow estimation.
|
||||
* Added keyframe detection that eliminates ghosting artifacts between the scenes.
|
||||
|
||||
## Potential improvements
|
||||
There are several ways overall quality of animation may be improved:
|
||||
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
|
||||
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
|
||||
* The quality of flow estimation might be greatly improved with proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
|
||||
* It is possible to lower VRAM requirements if precompute flows maps beforehand.
|
||||
<!--
|
||||
* ControlNet with preprocessors like "reference_only", "reference_adain", "reference_adain+attn" are not reset with video frames, to have an ability to control the style of the video.
|
||||
* Fixed an issue because of which the 'processing_strength' UI parameter did not actually affect denoising strength at the first processing step.
|
||||
* Fixed issue #112. It will not try to reinstall requirements at every start of webui.
|
||||
* Some improvements in text 2 video method.
|
||||
* Parameters used to generate a video are now automatically saved in the video's folder.
|
||||
* Added ability to control what frame will be sent to CN in text to video mode.
|
||||
-->
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
scikit-image
|
||||
284
script.py
|
|
@ -1,284 +0,0 @@
|
|||
import requests
|
||||
import cv2
|
||||
import base64
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
#RAFT dependencies
|
||||
import sys
|
||||
sys.path.append('RAFT/core')
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder
|
||||
|
||||
|
||||
# Path to the source video to stylize and path for the resulting video.
INPUT_VIDEO = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Fallout_noir_2/benny.mp4"
OUTPUT_VIDEO = "result.mp4"

# Positive and negative prompts sent with every img2img request.
PROMPT = "RAW photo, Matthew Perry wearing suit and pants in the wasteland, cinematic light, dramatic light, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
N_PROMPT = "blur, blurred, unfocus, obscure, dim, fade, obscure, muddy, black and white image, old, naked, black person, green face, green skin, black and white, slanted eyes, red eyes, blood eyes, deformed, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, ((((mutated hands and fingers)))), watermark, watermarked, oversaturated, censored, distorted hands, amputation, missing hands, obese, doubled face, double hands"
# -1 requests a random seed from the webui for each generation.
SEED = -1
w,h = 512, 704 # Width and height of the processed image. Note that actual image processed would be a W x 2H resolution. You should have enough VRAM to process it.

START_FROM_IND = 0 # index of a frame to start a processing from. Might be helpful with long animations where you need to restart the script multiple times
SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations

# Gaussian blur kernel size and sigma applied to the difference masks.
BLUR_SIZE = (15, 15)
BLUR_SIGMA = 12
|
||||
|
||||
def to_b64(img):
    """Encode an image array as a base64 PNG string for the webui API."""
    ok, png_buffer = cv2.imencode('.png', img)
    encoded = base64.b64encode(png_buffer)
    return encoded.decode("utf-8")
|
||||
|
||||
class controlnetRequest():
    """Builds and sends an img2img request with a single ControlNet (HED) unit.

    The payload uses the module-level PROMPT, N_PROMPT and SEED constants;
    per-call knobs are the denoising strength, output size and optional
    inpainting mask.
    """

    def __init__(self, b64_cur_img, b64_hed_img, ds = 0.35, w=w, h=h, mask = None):
        self.url = "http://localhost:7860/sdapi/v1/img2img"

        # ControlNet unit fed with the HED guidance image.
        cn_unit = {
            "input_image": b64_hed_img,
            "module": "hed",
            "model": "control_hed-fp16 [13fee50b]",
            "weight": 1,
            "resize_mode": "Just Resize",
            "lowvram": False,
            "processor_res": 512,
            "guidance": 1,
            "guessmode": False,
        }

        # Main img2img payload; 'mask' may be None for plain img2img.
        self.body = {
            "init_images": [b64_cur_img],
            "mask": mask,
            "mask_blur": 0,
            "inpainting_fill": 1,
            "inpainting_mask_invert": 0,
            "prompt": PROMPT,
            "negative_prompt": N_PROMPT,
            "seed": SEED,
            "subseed": -1,
            "subseed_strength": 0,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 15,
            "cfg_scale": 7,
            "denoising_strength": ds,
            "width": w,
            "height": h,
            "restore_faces": False,
            "eta": 0,
            "sampler_index": "DPM++ 2S a",
            "control_net_enabled": True,
            "alwayson_scripts": {"ControlNet": {"args": [cn_unit]}},
        }

    def sendRequest(self):
        """POST the body to the local webui API and return the parsed JSON."""
        response = requests.post(self.url, json=self.body)
        return response.json()
|
||||
|
||||
# Device used for RAFT inference; the script assumes a CUDA-capable GPU.
DEVICE = 'cuda'
# Lazily-initialized RAFT model, created on the first call below.
RAFT_model = None

def RAFT_estimate_flow_diff(frame1, frame2, frame1_styled):
    """Warp frames with RAFT optical flow and compute a difference/occlusion map.

    Parameters:
        frame1, frame2: consecutive original video frames as HxWx3 uint8
            numpy arrays (BGR order presumed, since cv2 handles them — TODO confirm).
        frame1_styled: stylized version of frame1 to warp toward frame2.

    Returns:
        (warped_frame, diff_frame, warped_frame_styled):
          warped_frame        - frame1 warped by the backward flow
          diff_frame          - uint8 map highlighting occlusions and mismatches
          warped_frame_styled - frame1_styled warped with the same flow
    """
    global RAFT_model
    # Build and cache the model on first use.
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': 'RAFT/models/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        RAFT_model = RAFT_model.module
        RAFT_model.to(DEVICE)
        RAFT_model.eval()

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(DEVICE)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(DEVICE)

        # Pad inputs to the dimensions RAFT expects.
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate and apply optical flow (forward and backward passes)
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1,2,0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1,2,0).cpu().numpy()

        # Turn the backward flow into an absolute sampling map for cv2.remap.
        flow_map = prev_flow.copy()
        h, w = flow_map.shape[:2]
        flow_map[:,:,0] += np.arange(w)
        flow_map[:,:,1] += np.arange(h)[:,np.newaxis]

        warped_frame = cv2.remap(frame1, flow_map, None, cv2.INTER_LINEAR)
        warped_frame_styled = cv2.remap(frame1_styled, flow_map, None, cv2.INTER_LINEAR)

        # compute occlusion mask: forward and backward flow cancel out for
        # non-occluded pixels, so a large residual norm marks occlusions.
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None]

        # Per-pixel maximum channel difference between warped frame and target,
        # normalized to [0, 1].
        diff_mask = np.abs(warped_frame.astype(np.float32) - frame2.astype(np.float32)) / 255
        diff_mask = diff_mask.max(axis = -1, keepdims=True)

        # Combine occlusion and difference evidence into a 3-channel map.
        diff = np.maximum(occlusion_mask * 0.2, diff_mask * 4).repeat(3, axis = -1)
        #diff = diff * 1.5

        # Blur and add back so mask edges fade out smoothly.
        diff_blured = cv2.GaussianBlur(diff, BLUR_SIZE, BLUR_SIGMA, cv2.BORDER_DEFAULT)
        diff_frame = np.clip((diff + diff_blured) * 255, 0, 255).astype(np.uint8)

    return warped_frame, diff_frame, warped_frame_styled
|
||||
|
||||
''' # old flow estimation algorithm, might be useful later
|
||||
def estimate_flow_diff(frame1, frame2, frame1_styled):
|
||||
prvs = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
|
||||
next = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# estimate and apply optical flow
|
||||
flow = cv2.calcOpticalFlowFarneback(prvs, next, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
||||
h, w = flow.shape[:2]
|
||||
flow_data = -flow
|
||||
flow_data[:,:,0] += np.arange(w)
|
||||
flow_data[:,:,1] += np.arange(h)[:,np.newaxis]
|
||||
#map_x, map_y = cv2.convertMaps(flow_data, 0, -1, True)
|
||||
warped_frame = cv2.remap(frame1, flow_data, None, cv2.INTER_LINEAR)
|
||||
warped_frame_styled = cv2.remap(frame1_styled, flow_data, None, cv2.INTER_LINEAR)
|
||||
|
||||
# compute occlusion mask
|
||||
flow_back = cv2.calcOpticalFlowFarneback(next, prvs, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
||||
fb_flow = flow + flow_back
|
||||
fb_norm = np.linalg.norm(fb_flow, axis=2)
|
||||
occlusion_mask = fb_norm[..., None]
|
||||
|
||||
diff_mask = np.abs(warped_frame.astype(np.float32) - frame2.astype(np.float32)) / 255
|
||||
diff_mask = diff_mask.max(axis = -1, keepdims=True)
|
||||
|
||||
diff = np.maximum(occlusion_mask, diff_mask).repeat(3, axis = -1)
|
||||
#diff = diff * 1.5
|
||||
|
||||
#diff = cv2.GaussianBlur(diff, BLUR_SIZE, BLUR_SIGMA, cv2.BORDER_DEFAULT)
|
||||
diff_frame = np.clip(diff * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
return warped_frame, diff_frame, warped_frame_styled
|
||||
'''
|
||||
|
||||
cv2.namedWindow('Out img')

# Open the input video file
input_video = cv2.VideoCapture(INPUT_VIDEO)

# Get useful info from the source video
fps = int(input_video.get(cv2.CAP_PROP_FPS))
total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))

# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'MP4V'), fps, (w, h))

prev_frame = None

for ind in tqdm(range(total_frames)):
    # Read the next frame from the input video
    if not input_video.isOpened(): break
    ret, cur_frame = input_video.read()
    if not ret: break

    if ind + 1 < START_FROM_IND: continue

    # Keyframe detection: a large mean difference between consecutive source
    # frames is treated as a scene change, so the frame is generated from scratch.
    is_keyframe = True
    if prev_frame is not None:
        # Compute absolute difference between current and previous frame
        frames_diff = cv2.absdiff(cur_frame, prev_frame)
        # Compute mean of absolute difference
        mean_diff = cv2.mean(frames_diff)[0]
        # Check if mean difference is above threshold
        is_keyframe = mean_diff > 30

    if is_keyframe:
        # Scene change: generate a fresh stylized frame at high denoising strength.
        frame = cv2.resize(cur_frame, (w, h))

        # Sending request to the web-ui
        data_js = controlnetRequest(to_b64(frame), to_b64(frame), 0.65, w, h, mask=None).sendRequest()

        # Decode the base64 result into a cv2 image
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        diff_mask = out_image.copy()
        diff_mask_blured = out_image.copy()
    else:
        # Scene continuation: generate a coarse version of the current frame with
        # the previous stylized frame as a reference, then fix changed regions.
        frame = cv2.resize(cur_frame, (w, h))
        prev_frame = cv2.resize(prev_frame, (w, h))

        # Sending request to the web-ui
        data_js = controlnetRequest(to_b64(frame), to_b64(frame), 0.35, w, h, mask=None).sendRequest()

        # Decode the base64 result into a cv2 image
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

        _, diff_mask, warped_styled = RAFT_estimate_flow_diff(prev_frame, frame, prev_frame_styled)

        # Blend the coarse generation into the warped previous stylized frame
        # only where the difference mask reports changed content.
        alpha = diff_mask.astype(np.float32) / 255.0
        pr_image = out_image * alpha + warped_styled * (1 - alpha)

        diff_mask_blured = cv2.GaussianBlur(alpha * 255, BLUR_SIZE, BLUR_SIGMA, cv2.BORDER_DEFAULT)
        diff_mask_blured = np.clip(diff_mask_blured + diff_mask, 5, 255).astype(np.uint8)

        # Second pass: inpaint only the masked regions at higher strength.
        data_js = controlnetRequest(to_b64(pr_image), to_b64(frame), 0.65, w, h, mask=to_b64(diff_mask_blured)).sendRequest()

        # Decode the base64 result into a cv2 image
        image_bytes = base64.b64decode(data_js["images"][0])
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)
        out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

    # Write the frame to the output video exactly once.
    # BUG FIX: the original loop called output_video.write(frame_out) a second
    # time after the preview block, duplicating every frame in the result.
    frame_out = out_image[:h]
    output_video.write(frame_out)

    # show the last written frame - useful to catch any issue with the process
    img_show = cv2.hconcat([out_image, diff_mask, diff_mask_blured])
    cv2.imshow('Out img', img_show)
    if cv2.waitKey(1) & 0xFF == ord('q'): exit()  # press Q to close the script while processing

    prev_frame = cur_frame.copy()
    prev_frame_styled = frame_out.copy()

    if SAVE_FRAMES:
        if not os.path.isdir('out'): os.makedirs('out')
        cv2.imwrite(f'out/{ind+1:05d}.png', frame_out)

# Release the input and output video files
input_video.release()
output_video.release()

# Close all windows
cv2.destroyAllWindows()
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
import sys, os
|
||||
|
||||
import gradio as gr
|
||||
import modules
|
||||
from types import SimpleNamespace
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
|
||||
from modules.ui_components import ToolButton, FormRow, FormGroup
|
||||
from modules.ui import create_override_settings_dropdown
|
||||
import modules.scripts as scripts
|
||||
|
||||
from modules.sd_samplers import samplers_for_img2img
|
||||
|
||||
from scripts.core import vid2vid, txt2vid, utils
|
||||
import traceback
|
||||
|
||||
def V2VArgs():
    """Return the default UI parameter values for the vid2vid mode as a dict."""
    return {
        'seed': -1,
        'width': 1024,
        'height': 576,
        'cfg_scale': 5.5,
        'steps': 15,
        'prompt': "",
        'n_prompt': "text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
        'processing_strength': 0.85,
        'fix_frame_strength': 0.15,
    }
|
||||
|
||||
def T2VArgs():
    """Return the default UI parameter values for the txt2vid mode as a dict."""
    return {
        'seed': -1,
        'width': 768,
        'height': 512,
        'cfg_scale': 5.5,
        'steps': 15,
        'prompt': "",
        'n_prompt': "((blur, blurr, blurred, blurry, fuzzy, unclear, unfocus, bocca effect)), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
        'processing_strength': 0.75,
        'fix_frame_strength': 0.35,
    }
|
||||
|
||||
def setup_common_values(mode, d):
    """Create the UI controls shared by the vid2vid and txt2vid tabs.

    Args:
        mode: tab name ('vid2vid' or 'txt2vid'), used to build unique elem_ids.
        d: namespace of default values (see V2VArgs / T2VArgs).

    Returns the created components in a fixed order:
    (width, height, prompt, n_prompt, cfg_scale, seed,
     processing_strength, fix_frame_strength, sampler_index, steps)
    """
    with gr.Row():
        width = gr.Slider(label='Width', minimum=64, maximum=2048, step=64, value=d.width, interactive=True)
        height = gr.Slider(label='Height', minimum=64, maximum=2048, step=64, value=d.height, interactive=True)
    with gr.Row(elem_id=f'{mode}_prompt_toprow'):
        prompt = gr.Textbox(label='Prompt', lines=3, interactive=True, elem_id=f"{mode}_prompt", placeholder="Enter your prompt here...")
    with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
        n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
    with gr.Row():
        cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
    with gr.Row():
        # BUG FIX: the kwarg was previously spelled 'Interactive=True' (capital I),
        # which gradio does not recognize; the correct keyword is 'interactive'.
        seed = gr.Number(label='Seed (this parameter controls how the first frame looks like and the color distribution of the consecutive frames as they are dependent on the first one)', value=d.seed, interactive=True, precision=0)
    with gr.Row():
        processing_strength = gr.Slider(label="Processing strength (Step 1)", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
        fix_frame_strength = gr.Slider(label="Fix frame strength (Step 2)", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
    with gr.Row():
        sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{mode}_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index", interactive=True)
        steps = gr.Slider(label="Sampling steps", minimum=1, maximum=150, step=1, elem_id=f"{mode}_steps", value=d.steps, interactive=True)

    return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength, sampler_index, steps
|
||||
|
||||
def inputs_ui():
    """Build the vid2vid / txt2vid input tabs.

    Returns locals() so the caller can look every created component up by its
    variable name (the names are referenced via utils.get_component_names()).
    """
    v2v_args = SimpleNamespace(**V2VArgs())
    t2v_args = SimpleNamespace(**T2VArgs())
    with gr.Tabs():
        # Tracks which tab (processing mode) is currently selected.
        glo_sdcn_process_mode = gr.State(value='vid2vid')

        with gr.Tab('vid2vid') as tab_vid2vid:
            with gr.Row():
                gr.HTML('Input video (each frame will be used as initial image for SD and as input image to CN): *REQUIRED')
            with gr.Row():
                v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")

            v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength, v2v_sampler_index, v2v_steps = setup_common_values('vid2vid', v2v_args)

            with gr.Accordion("Extra settings", open=False):
                gr.HTML('# Occlusion mask params:')
                with gr.Row():
                    with gr.Column(scale=1, variant='compact'):
                        v2v_occlusion_mask_blur = gr.Slider(label='Occlusion blur strength', minimum=0, maximum=10, step=0.1, value=3, interactive=True)
                        gr.HTML('')
                        v2v_occlusion_mask_trailing = gr.Checkbox(label="Occlusion trailing", info="Reduce ghosting but adds more flickering to the video", value=True, interactive=True)
                    with gr.Column(scale=1, variant='compact'):
                        v2v_occlusion_mask_flow_multiplier = gr.Slider(label='Occlusion flow multiplier', minimum=0, maximum=10, step=0.1, value=5, interactive=True)
                        v2v_occlusion_mask_difo_multiplier = gr.Slider(label='Occlusion diff origin multiplier', minimum=0, maximum=10, step=0.1, value=2, interactive=True)
                        v2v_occlusion_mask_difs_multiplier = gr.Slider(label='Occlusion diff styled multiplier', minimum=0, maximum=10, step=0.1, value=0, interactive=True)

                with gr.Row():
                    with gr.Column(scale=1, variant='compact'):
                        gr.HTML('# Step 1 params:')
                        # BUG FIX: 'Interactive=True' (capital I) replaced with the
                        # correct 'interactive=True' keyword recognized by gradio.
                        v2v_step_1_seed = gr.Number(label='Seed', value=-1, interactive=True, precision=0)
                        gr.HTML('<br>')
                        v2v_step_1_blend_alpha = gr.Slider(label='Warped prev frame vs Current frame blend alpha', minimum=0, maximum=1, step=0.1, value=1, interactive=True)
                        v2v_step_1_processing_mode = gr.Radio(["Process full image then blend in occlusions", "Inpaint occlusions"], type="index", \
                            label="Processing mode", value="Process full image then blend in occlusions", interactive=True)

                    with gr.Column(scale=1, variant='compact'):
                        gr.HTML('# Step 2 params:')
                        # BUG FIX: same 'Interactive' -> 'interactive' correction.
                        v2v_step_2_seed = gr.Number(label='Seed', value=8888, interactive=True, precision=0)

            with FormRow(elem_id="vid2vid_override_settings_row") as row:
                v2v_override_settings = create_override_settings_dropdown("vid2vid", row)

            with FormGroup(elem_id=f"script_container"):
                v2v_custom_inputs = scripts.scripts_img2img.setup_ui()

        with gr.Tab('txt2vid') as tab_txt2vid:
            with gr.Row():
                gr.HTML('Control video (each frame will be used as input image to CN): *NOT REQUIRED')
            with gr.Row():
                t2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="tex_to_vid_chosen_file")
                t2v_init_image = gr.Image(label="Input image", interactive=True, file_count="single", file_types=["image"], elem_id="tex_to_vid_init_image")

            t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)

            with gr.Row():
                t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
                t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)

            gr.HTML('<br>')
            t2v_cn_frame_send = gr.Radio(["None", "Current generated frame", "Previous generated frame", "Current reference video frame"], type="index", \
                label="What frame should be send to CN?", value="None", interactive=True)

            with FormRow(elem_id="txt2vid_override_settings_row") as row:
                t2v_override_settings = create_override_settings_dropdown("txt2vid", row)

            with FormGroup(elem_id=f"script_container"):
                t2v_custom_inputs = scripts.scripts_txt2img.setup_ui()

        # Keep the mode state in sync with the visible tab.
        tab_vid2vid.select(fn=lambda: 'vid2vid', inputs=[], outputs=[glo_sdcn_process_mode])
        tab_txt2vid.select(fn=lambda: 'txt2vid', inputs=[], outputs=[glo_sdcn_process_mode])

    return locals()
|
||||
|
||||
def process(*args):
    """Dispatch a generation run to the selected mode and stream UI updates.

    args[0] is the processing mode name; the remaining args are forwarded to
    the mode's start_process generator. A final yield always resets the UI
    (re-enables Generate, disables Interrupt) and reports the outcome.
    """
    msg = 'Done'
    mode = args[0]
    handlers = {
        'vid2vid': vid2vid.start_process,
        'txt2vid': txt2vid.start_process,
    }
    try:
        handler = handlers.get(mode)
        if handler is None:
            msg = f"Unsupported processing mode: '{mode}'"
            raise Exception(msg)
        yield from handler(*args)
    except Exception as error:
        # Report the failure in the progress area instead of crashing the UI.
        msg = f"An exception occurred while trying to process the frame: {error}"
        print(msg)
        traceback.print_exc()

    yield msg, gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Video.update(), gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||
|
||||
def stop_process(*args):
    """Signal the running generation to stop and disable the Interrupt button.

    The processing loop checks utils.shared.is_interrupted and exits cleanly.
    """
    utils.shared.is_interrupted = True
    return gr.Button.update(interactive=False)
|
||||
|
||||
|
||||
|
||||
def on_ui_tabs():
    """Build the SD-CN-Animation tab UI and wire up its actions.

    Returns the (interface, title, elem_id) triple expected by
    script_callbacks.on_ui_tabs.
    """
    # Temporarily switch the global scripts context so img2img scripts can
    # render their UI inside this extension's tab; restored to None below.
    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)

    with gr.Blocks(analytics_enabled=False) as sdcnanim_interface:
        components = {}

        #dv = SimpleNamespace(**T2VOutputArgs())
        with gr.Row(elem_id='sdcn-core', equal_height=False, variant='compact'):
            # Left column: all generation inputs.
            with gr.Column(scale=1, variant='panel'):
                #with gr.Tabs():
                components = inputs_ui()

                with gr.Accordion("Export settings", open=False):
                    export_settings_button = gr.Button('Export', elem_id=f"sdcn_export_settings_button")
                    export_setting_json = gr.Code(value='')

            # Right column: run controls, previews and resulting video.
            with gr.Column(scale=1, variant='compact'):
                with gr.Row(variant='compact'):
                    run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
                    stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False)

                save_frames_check = gr.Checkbox(label="Save frames into a folder nearby a video (check it before running the generation if you also want to save frames separately)", value=True, interactive=True)
                gr.HTML('<br>')

                with gr.Column(variant="panel"):
                    sp_progress = gr.HTML(elem_id="sp_progress", value="")

                    with gr.Row(variant='compact'):
                        img_preview_curr_frame = gr.Image(label='Current frame', elem_id=f"img_preview_curr_frame", type='pil', height=240)
                        img_preview_curr_occl = gr.Image(label='Current occlusion', elem_id=f"img_preview_curr_occl", type='pil', height=240)
                    with gr.Row(variant='compact'):
                        # NOTE(review): elem_id duplicates 'img_preview_curr_frame'
                        # above — looks like a copy/paste slip; confirm intended id.
                        img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil', height=240)
                        img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil', height=240)

                    video_preview = gr.Video(interactive=False)

                with gr.Row(variant='compact'):
                    dummy_component = gr.Label(visible=False)

        components['glo_save_frames_check'] = save_frames_check

        # Define parameters for the action methods.
        # Custom script inputs sizes are stored so the processing code can split
        # the flat args tuple back into per-mode script arguments.
        utils.shared.v2v_custom_inputs_size = len(components['v2v_custom_inputs'])
        utils.shared.t2v_custom_inputs_size = len(components['t2v_custom_inputs'])
        #print('v2v_custom_inputs', len(components['v2v_custom_inputs']), components['v2v_custom_inputs'])
        #print('t2v_custom_inputs', len(components['t2v_custom_inputs']), components['t2v_custom_inputs'])
        method_inputs = [components[name] for name in utils.get_component_names()] + components['v2v_custom_inputs'] + components['t2v_custom_inputs']

        # Output order must match what process() yields.
        method_outputs = [
            sp_progress,
            img_preview_curr_frame,
            img_preview_curr_occl,
            img_preview_prev_warp,
            img_preview_processed,
            video_preview,
            run_button,
            stop_button,
        ]

        run_button.click(
            fn=process, #wrap_gradio_gpu_call(start_process, extra_outputs=[None, '', '']),
            inputs=method_inputs,
            outputs=method_outputs,
            show_progress=True,
        )

        stop_button.click(
            fn=stop_process,
            outputs=[stop_button],
            show_progress=False
        )

        export_settings_button.click(
            fn=utils.export_settings,
            inputs=method_inputs,
            outputs=[export_setting_json],
            show_progress=False
        )

        modules.scripts.scripts_current = None

    # define queue - required for generators
    sdcnanim_interface.queue()

    return [(sdcnanim_interface, "SD-CN-Animation", "sd_cn_animation_interface")]
|
||||
|
||||
|
||||
# Register the SD-CN-Animation tab with the webui at startup.
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
import sys, os
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import argparse
|
||||
from RAFT.raft import RAFT
|
||||
from RAFT.utils.utils import InputPadder
|
||||
|
||||
import modules.paths as ph
|
||||
import gc
|
||||
|
||||
# Lazily-created RAFT model shared across calls (see RAFT_estimate_flow).
RAFT_model = None
# MOG2 background subtractor used by background_subtractor().
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
|
||||
|
||||
def background_subtractor(frame, fgbg):
    """Return *frame* with background pixels zeroed out.

    *fgbg* is an OpenCV background-subtractor instance; its foreground mask is
    applied to the frame so only foreground pixels survive.
    """
    foreground_mask = fgbg.apply(frame)
    return cv2.bitwise_and(frame, frame, mask=foreground_mask)
|
||||
|
||||
def RAFT_clear_memory():
    """Release the cached RAFT model and free GPU memory.

    Safe to call repeatedly: RAFT_model is reset to None at the end, so the
    next RAFT_estimate_flow call re-creates the model.
    """
    global RAFT_model
    # Drop the reference, force a GC pass, then return cached CUDA blocks to
    # the driver.
    del RAFT_model
    gc.collect()
    torch.cuda.empty_cache()
    RAFT_model = None
|
||||
|
||||
def RAFT_estimate_flow(frame1, frame2, device='cuda'):
    """Estimate forward and backward optical flow between two frames using RAFT.

    Downloads the raft-things checkpoint on first use and lazily builds the
    model, caching it in the module-level RAFT_model global.

    Parameters:
        frame1, frame2: consecutive frames as HxWx3 uint8 numpy arrays
            (BGR order presumed since cv2 handles them — TODO confirm).
        device: torch device string used for inference.

    Returns:
        (next_flow, prev_flow, occlusion_mask) where next_flow maps frame1 to
        frame2 and prev_flow maps frame2 to frame1, both resized back to the
        original frame size. occlusion_mask is the forward-backward consistency
        norm repeated over 3 channels.
        NOTE(review): occlusion_mask is returned at the model's working size,
        not resized back — verify callers expect that.
    """
    global RAFT_model

    # Remember the original size and round the working size down to a multiple
    # of 16 for the network input.
    org_size = frame1.shape[1], frame1.shape[0]
    size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
    frame1 = cv2.resize(frame1, size)
    frame2 = cv2.resize(frame2, size)

    model_path = ph.models_path + '/RAFT/raft-things.pth'
    remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'

    # Download the checkpoint on first use.
    if not os.path.isfile(model_path):
        from basicsr.utils.download_util import load_file_from_url
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # NOTE(review): file_name is given a full path here — confirm that
        # load_file_from_url treats it as the destination path rather than a
        # bare file name.
        load_file_from_url(remote_model_path, file_name=model_path)

    # Lazily build the model once and cache it.
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': ph.models_path + '/RAFT/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        # Pad inputs to the dimensions RAFT expects.
        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow (forward and backward passes)
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # Forward-backward consistency: flows cancel out for non-occluded
        # pixels, so a large residual norm marks occlusions.
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

        # Map flows back to the original frame size.
        next_flow = cv2.resize(next_flow, org_size)
        prev_flow = cv2.resize(prev_flow, org_size)

    return next_flow, prev_flow, occlusion_mask
|
||||
|
||||
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled, args_dict):
    """Build an occlusion/difference alpha mask and a flow-warped styled frame.

    Combines three difference signals (forward-backward flow inconsistency,
    warped-original vs current, warped-styled vs current), each scaled by its
    multiplier from *args_dict*, into a [0, 1] per-pixel alpha mask.

    Returns ``(alpha_mask, warped_frame_styled)``.
    """
    h, w = cur_frame.shape[:2]
    # NOTE(review): shape[:2] is (rows, cols), so fl_w actually holds the flow
    # height and fl_h the width; the [fl_h, fl_w] divisor below therefore
    # normalizes (x, y) flow channels by (width, height) as intended — only
    # the variable names are swapped.
    fl_w, fl_h = next_flow.shape[:2]

    # normalize flow to the unit square
    next_flow = next_flow / np.array([fl_h,fl_w])
    prev_flow = prev_flow / np.array([fl_h,fl_w])

    # compute occlusion mask from forward-backward inconsistency
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow , axis=2)

    # Suppress the inconsistency signal where the flow is near zero
    # (static regions produce noisy fb_norm values).
    zero_flow_mask = np.clip(1 - np.linalg.norm(prev_flow, axis=-1)[...,None] * 20, 0, 1)
    diff_mask_flow = fb_norm[..., None] * zero_flow_mask

    # resize flow to the current frame size and convert back to pixel units
    next_flow = cv2.resize(next_flow, (w, h))
    next_flow = (next_flow * np.array([h,w])).astype(np.float32)
    prev_flow = cv2.resize(prev_flow, (w, h))
    prev_flow = (prev_flow * np.array([h,w])).astype(np.float32)

    # Generate sampling grids for torch grid_sample (normalized to [-1, 1])
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
    flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
    flow_grid = flow_grid.unsqueeze(0)
    flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
    flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
    flow_grid = flow_grid.permute(0, 2, 3, 1)

    prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
    prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W

    # Warp both the original and the styled previous frame along prev_flow.
    warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
    warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()

    #warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
    #warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)

    # Per-pixel max-channel differences, scaled to [0, 1].
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

    # Combine the three signals by taking the per-pixel maximum.
    alpha_mask = np.maximum.reduce([diff_mask_flow * args_dict['occlusion_mask_flow_multiplier'] * 10, \
                                    diff_mask_org * args_dict['occlusion_mask_difo_multiplier'], \
                                    diff_mask_stl * args_dict['occlusion_mask_difs_multiplier']]) #
    alpha_mask = alpha_mask.repeat(3, axis = -1)

    #alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
    if args_dict['occlusion_mask_blur'] > 0:
        # `| 1` forces an odd kernel size, as GaussianBlur requires.
        blur_filter_size = min(w,h) // 15 | 1
        alpha_mask = cv2.GaussianBlur(alpha_mask, (blur_filter_size, blur_filter_size) , args_dict['occlusion_mask_blur'], cv2.BORDER_REFLECT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
|
||||
|
||||
def frames_norm(frame):
    """Map 8-bit frame values from [0, 255] to [-1, 1]."""
    return frame / 127.5 - 1


def flow_norm(flow):
    """Scale raw flow values down by 255."""
    return flow / 255


def occl_norm(occl):
    """Map 8-bit occlusion values from [0, 255] to [-1, 1]."""
    return occl / 127.5 - 1


def frames_renorm(frame):
    """Inverse of frames_norm: map [-1, 1] back to [0, 255]."""
    return (frame + 1) * 127.5


def flow_renorm(flow):
    """Inverse of flow_norm: scale flow values back up by 255."""
    return flow * 255


def occl_renorm(occl):
    """Inverse of occl_norm: map [-1, 1] back to [0, 255]."""
    return (occl + 1) * 127.5
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
import sys, os
|
||||
|
||||
import torch
|
||||
import gc
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import modules.paths as ph
|
||||
from modules import devices
|
||||
|
||||
from scripts.core import utils, flow_utils
|
||||
from FloweR.model import FloweR
|
||||
|
||||
import skimage
|
||||
import datetime
|
||||
import cv2
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
FloweR_model = None
|
||||
DEVICE = 'cpu'
|
||||
def FloweR_clear_memory():
    """Drop the cached FloweR model and hand memory back to the allocator."""
    global FloweR_model
    # Rebinding to None releases the last reference to the model weights,
    # letting gc reclaim them before the CUDA cache is emptied.
    FloweR_model = None
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
def FloweR_load_model(w, h):
    """Load the FloweR frame-prediction model for a (w, h) input size.

    Downloads the checkpoint on first use, then stores the ready-to-use model
    in the module-level ``FloweR_model`` global (and the chosen device in
    ``DEVICE``) for the processing loop to consume.
    """
    global DEVICE, FloweR_model
    DEVICE = devices.get_optimal_device()

    model_path = ph.models_path + '/FloweR/FloweR_0.1.2.pth'
    #remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv' #FloweR_0.1.1.pth
    remote_model_path = 'https://drive.google.com/uc?id=1-UYsTXkdUkHLgtPK1Y5_7kKzCgzL_Z6o' #FloweR_0.1.2.pth

    # Download the checkpoint once, into the webui models folder.
    if not os.path.isfile(model_path):
        from basicsr.utils.download_util import load_file_from_url
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        load_file_from_url(remote_model_path, file_name=model_path)

    # FloweR takes (height, width) as its input size.
    FloweR_model = FloweR(input_size = (h, w))
    FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    # Move the model to the device
    FloweR_model = FloweR_model.to(DEVICE)
    FloweR_model.eval()
|
||||
|
||||
def read_frame_from_video(input_video):
    """Read the next frame from a cv2.VideoCapture and convert it to RGB.

    Returns None when *input_video* is None, is not opened, or is exhausted.
    When the stream is exhausted the capture is released.
    """
    if input_video is None:
        return None

    # BUGFIX: cur_frame was previously unbound (NameError at the return)
    # whenever the capture object existed but was not opened.
    cur_frame = None

    # Reading video file
    if input_video.isOpened():
        ret, cur_frame = input_video.read()
        if cur_frame is not None:
            # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
            cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
        else:
            # End of stream: free the underlying decoder.
            input_video.release()

    return cur_frame
|
||||
|
||||
def start_process(*args):
    """Generator driving the txt2vid pipeline.

    Produces the first frame with txt2img (or from a user init image), then
    iteratively predicts the next frame with FloweR and refines it with two
    img2img passes (inpaint + fix pass). Yields a UI-update tuple
    (status, current, occlusion, predicted, processed, video path, button
    states) after every frame; writes all frames to an mp4 as it goes.
    """
    processing_start_time = time.time()
    args_dict = utils.args_to_dict(*args)
    args_dict = utils.get_mode_args('t2v', args_dict)

    # Open the input video file
    input_video = None
    if args_dict['file'] is not None:
        input_video = cv2.VideoCapture(args_dict['file'].name)

    # Create an output video file with the same fps, width, and height as the input video
    output_video_name = f'outputs/sd-cn-animation/txt2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    output_video_folder = os.path.splitext(output_video_name)[0]
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)

    #if args_dict['save_frames_check']:
    os.makedirs(output_video_folder, exist_ok=True)

    # Writing to current params to params.json
    setts_json = utils.export_settings(*args)
    with open(os.path.join(output_video_folder, "params.json"), "w") as outfile:
        outfile.write(setts_json)

    curr_frame = None
    prev_frame = None

    def save_result_to_image(image, ind):
        # Optionally dump each processed frame as a zero-padded PNG
        # next to the output video.
        if args_dict['save_frames_check']:
            cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    def set_cn_frame_input():
        # Feed the selected frame into the ControlNet units, depending on the
        # cn_frame_send mode. Reads curr_frame/prev_frame/input_video from the
        # enclosing scope, so it always sees the latest values.
        if args_dict['cn_frame_send'] == 0: # send nothing
            pass
        elif args_dict['cn_frame_send'] == 1: # Current generated frame
            if curr_frame is not None:
                utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame), set_references=True)
        elif args_dict['cn_frame_send'] == 2: # Previous generated frame
            if prev_frame is not None:
                utils.set_CNs_input_image(args_dict, Image.fromarray(prev_frame), set_references=True)
        elif args_dict['cn_frame_send'] == 3: # Current reference video frame
            if input_video is not None:
                curr_video_frame = read_frame_from_video(input_video)
                curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
                utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame), set_references=True)
            else:
                raise Exception('There is no input video! Set it up first.')
        else:
            raise Exception('Incorrect cn_frame_send mode!')

    set_cn_frame_input()

    # First frame: either the user-supplied init image or a fresh txt2img result.
    if args_dict['init_image'] is not None:
        #resize array to args_dict['width'], args_dict['height']
        image_array=args_dict['init_image']#this is a numpy array
        init_frame = np.array(Image.fromarray(image_array).resize((args_dict['width'], args_dict['height'])).convert('RGB'))
        processed_frame = init_frame.copy()
    else:
        processed_frames, _, _, _ = utils.txt2img(args_dict)
        processed_frame = np.array(processed_frames[0])[...,:3]
        #if input_video is not None:
        #  processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
        processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
        init_frame = processed_frame.copy()

    output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
    output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))

    stat = f"Frame: 1 / {args_dict['length']}; " + utils.get_time_left(1, args_dict['length'], processing_start_time)
    utils.shared.is_interrupted = False

    save_result_to_image(processed_frame, 1)
    yield stat, init_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    # FloweR works on sizes snapped down to multiples of 128; predictions are
    # resized back to org_size afterwards.
    org_size = args_dict['width'], args_dict['height']
    size = args_dict['width'] // 128 * 128, args_dict['height'] // 128 * 128
    FloweR_load_model(size[0], size[1])

    # Rolling window of the last 4 frames fed to FloweR.
    clip_frames = np.zeros((4, size[1], size[0], 3), dtype=np.uint8)

    prev_frame = init_frame

    for ind in range(args_dict['length'] - 1):
        if utils.shared.is_interrupted: break

        # Re-read the UI args each iteration so live changes take effect.
        args_dict = utils.args_to_dict(*args)
        args_dict = utils.get_mode_args('t2v', args_dict)

        clip_frames = np.roll(clip_frames, -1, axis=0)
        clip_frames[-1] = cv2.resize(prev_frame[...,:3], size)
        clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))

        with torch.no_grad():
            pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]

        # FloweR packs flow (2ch), occlusion (1ch) and next-frame (3ch)
        # predictions into one tensor along the last axis.
        pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
        pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
        pred_next = flow_utils.frames_renorm(pred_data[...,3:6]).cpu().numpy()

        pred_occl = np.clip(pred_occl * 10, 0, 255).astype(np.uint8)
        pred_next = np.clip(pred_next, 0, 255).astype(np.uint8)

        pred_flow = cv2.resize(pred_flow, org_size)
        pred_occl = cv2.resize(pred_occl, org_size)
        pred_next = cv2.resize(pred_next, org_size)

        curr_frame = pred_next.copy()

        '''
        pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
        pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)

        pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
        pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
        pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)

        flow_map = pred_flow.copy()
        flow_map[:,:,0] += np.arange(args_dict['width'])
        flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]

        warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT_101)
        alpha_mask = pred_occl / 255.
        #alpha_mask = np.clip(alpha_mask + np.random.normal(0, 0.4, size = alpha_mask.shape), 0, 1)
        curr_frame = pred_next.astype(float) * alpha_mask + warped_frame.astype(float) * (1 - alpha_mask)
        curr_frame = np.clip(curr_frame, 0, 255).astype(np.uint8)
        #curr_frame = warped_frame.copy()
        '''

        set_cn_frame_input()

        # Pass 1: inpaint the occluded regions of the predicted frame
        # (mode 4 = inpaint with uploaded mask).
        args_dict['mode'] = 4
        args_dict['init_img'] = Image.fromarray(pred_next)
        args_dict['mask_img'] = Image.fromarray(pred_occl)
        args_dict['seed'] = -1
        args_dict['denoising_strength'] = args_dict['processing_strength']

        processed_frames, _, _, _ = utils.img2img(args_dict)
        processed_frame = np.array(processed_frames[0])[...,:3]
        #if input_video is not None:
        #  processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
        #else:
        # Keep the palette stable across frames by matching the first frame.
        processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
        processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

        # Pass 2: a light whole-frame img2img pass to fix seams (mode 0).
        args_dict['mode'] = 0
        args_dict['init_img'] = Image.fromarray(processed_frame)
        args_dict['mask_img'] = None
        args_dict['seed'] = -1
        args_dict['denoising_strength'] = args_dict['fix_frame_strength']

        #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
        processed_frames, _, _, _ = utils.img2img(args_dict)
        processed_frame = np.array(processed_frames[0])[...,:3]
        #if input_video is not None:
        #  processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
        #else:
        processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
        processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)

        output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
        prev_frame = processed_frame.copy()

        save_result_to_image(processed_frame, ind + 2)
        stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
        yield stat, curr_frame, pred_occl, pred_next, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    if input_video is not None: input_video.release()
    output_video.release()
    FloweR_clear_memory()

    # Final yield: leave all image panes untouched and flip the buttons back.
    curr_frame = gr.Image.update()
    occlusion_mask = gr.Image.update()
    warped_styled_frame_ = gr.Image.update()
    processed_frame = gr.Image.update()

    # print('TOTAL TIME:', int(time.time() - processing_start_time))

    yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)
|
||||
|
|
@ -0,0 +1,432 @@
|
|||
class shared:
    """Module-level state shared between the UI callbacks and the processing loops."""
    # Set True by the Stop button; checked by the generation loops to abort.
    is_interrupted = False
    # Number of extra script (e.g. ControlNet) inputs appended after the named
    # components for the vid2vid and txt2vid tabs; used by args_to_dict to
    # slice the flat positional argument list.
    v2v_custom_inputs_size = 0
    t2v_custom_inputs_size = 0
|
||||
|
||||
def get_component_names():
    """Ordered names of the UI components whose values are passed positionally.

    The order here must match the order the components are wired into the
    Gradio callbacks; args_to_dict relies on it to map values back to names.
    """
    global_prefix = ['glo_sdcn_process_mode']

    vid2vid = [
        'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
        'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
        'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing', 'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier', 'v2v_occlusion_mask_difs_multiplier',
        'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha', 'v2v_step_1_seed', 'v2v_step_2_seed',
    ]

    txt2vid = [
        't2v_file','t2v_init_image', 't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
        't2v_sampler_index', 't2v_steps', 't2v_length', 't2v_fps', 't2v_cn_frame_send',
    ]

    return global_prefix + vid2vid + txt2vid + ['glo_save_frames_check']
|
||||
|
||||
def args_to_dict(*args): # converts the flat list of UI arguments into a dictionary for easier handling
    """Map the positional UI values onto named settings with defaults.

    The first len(get_component_names()) positional values correspond 1:1 to
    the component names; the remainder are custom script inputs split between
    vid2vid and txt2vid using the sizes recorded on ``shared``.
    """
    args_list = get_component_names()

    # set default values for params that were not specified
    args_dict = {
        # video to video params
        'v2v_mode': 0,
        'v2v_prompt': '',
        'v2v_n_prompt': '',
        'v2v_prompt_styles': [],
        'v2v_init_video': None, # Always required

        'v2v_steps': 15,
        'v2v_sampler_index': 0, # 'Euler a'
        'v2v_mask_blur': 0,

        'v2v_inpainting_fill': 1, # original
        'v2v_restore_faces': False,
        'v2v_tiling': False,
        'v2v_n_iter': 1,
        'v2v_batch_size': 1,
        'v2v_cfg_scale': 5.5,
        'v2v_image_cfg_scale': 1.5,
        'v2v_denoising_strength': 0.75,
        'v2v_processing_strength': 0.85,
        'v2v_fix_frame_strength': 0.15,
        'v2v_seed': -1,
        'v2v_subseed': -1,
        'v2v_subseed_strength': 0,
        'v2v_seed_resize_from_h': 512,
        'v2v_seed_resize_from_w': 512,
        'v2v_seed_enable_extras': False,
        'v2v_height': 512,
        'v2v_width': 512,
        'v2v_resize_mode': 1,
        'v2v_inpaint_full_res': True,
        'v2v_inpaint_full_res_padding': 0,
        'v2v_inpainting_mask_invert': False,

        # text to video params
        't2v_mode': 4,
        't2v_prompt': '',
        't2v_n_prompt': '',
        't2v_prompt_styles': [],
        't2v_init_img': None,
        't2v_mask_img': None,

        't2v_steps': 15,
        't2v_sampler_index': 0, # 'Euler a'
        't2v_mask_blur': 0,

        't2v_inpainting_fill': 1, # original
        't2v_restore_faces': False,
        't2v_tiling': False,
        't2v_n_iter': 1,
        't2v_batch_size': 1,
        't2v_cfg_scale': 5.5,
        't2v_image_cfg_scale': 1.5,
        't2v_denoising_strength': 0.75,
        't2v_processing_strength': 0.85,
        't2v_fix_frame_strength': 0.15,
        't2v_seed': -1,
        't2v_subseed': -1,
        't2v_subseed_strength': 0,
        't2v_seed_resize_from_h': 512,
        't2v_seed_resize_from_w': 512,
        't2v_seed_enable_extras': False,
        't2v_height': 512,
        't2v_width': 512,
        't2v_resize_mode': 1,
        't2v_inpaint_full_res': True,
        't2v_inpaint_full_res_padding': 0,
        't2v_inpainting_mask_invert': False,

        't2v_override_settings': [],
        #'t2v_script_inputs': [0],

        't2v_fps': 12,
    }

    args = list(args)

    # A None positional value keeps the default (if one exists);
    # anything else overwrites it.
    for i in range(len(args_list)):
        if (args[i] is None) and (args_list[i] in args_dict):
            #args[i] = args_dict[args_list[i]]
            pass
        else:
            args_dict[args_list[i]] = args[i]

    # Trailing values are the custom script inputs: first the vid2vid block,
    # then everything else belongs to txt2vid.
    args_dict['v2v_script_inputs'] = args[len(args_list):len(args_list)+shared.v2v_custom_inputs_size]
    #print('v2v_script_inputs', args_dict['v2v_script_inputs'])
    args_dict['t2v_script_inputs'] = args[len(args_list)+shared.v2v_custom_inputs_size:]
    #print('t2v_script_inputs', args_dict['t2v_script_inputs'])
    return args_dict
|
||||
|
||||
def get_mode_args(mode, args_dict):
    """Select the settings for *mode* ('v2v' or 't2v') plus the global ones.

    Keys prefixed with the mode or with 'glo' are kept, with the 4-character
    prefix ('v2v_', 'glo_', ...) stripped from the returned keys.
    """
    return {
        key[4:]: value
        for key, value in args_dict.items()
        if key[:3] in (mode, 'glo')
    }
|
||||
|
||||
def set_CNs_input_image(args_dict, image, set_references = False):
    """Feed *image* into every ControlNet unit among the script inputs.

    Reference-type units ('reference_only', 'reference_adain',
    'reference_adain+attn') are skipped unless *set_references* is True.
    ControlNet units are recognized by their class name to avoid importing
    the extension directly.
    """
    reference_modules = ("reference_only", "reference_adain", "reference_adain+attn")
    for unit in args_dict['script_inputs']:
        if type(unit).__name__ != 'UiControlNetUnit':
            continue
        if set_references or unit.module not in reference_modules:
            unit.image = np.array(image)
            unit.batch_images = [np.array(image)]
|
||||
|
||||
import time
|
||||
import datetime
|
||||
|
||||
def get_time_left(ind, length, processing_start_time):
    """Format elapsed time and a linear ETA after *ind* of *length* frames.

    The estimate simply scales the average per-frame time so far by the
    number of frames remaining.
    """
    elapsed_s = int(time.time() - processing_start_time)
    remaining_s = int(elapsed_s / ind * (length - ind))
    elapsed = datetime.timedelta(seconds=elapsed_s)
    remaining = datetime.timedelta(seconds=remaining_s)
    return f"Time elapsed: {elapsed}; Time left: {remaining};"
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
from types import SimpleNamespace
|
||||
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, process_images
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.images as images
|
||||
import modules.scripts
|
||||
from modules.shared import opts, state
|
||||
from modules import devices, sd_samplers, img2img
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
# TODO: Refactor all the code below
|
||||
|
||||
def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
    """Run an img2img processing object *p* over a single input image.

    Adapted from the webui batch-processing path; the batch/file-system logic
    is kept commented out since this pipeline always feeds exactly one
    in-memory image. Returns a list with the first generated image per input.
    """
    processing.fix_seed(p)

    #images = shared.listfiles(input_dir)
    images = [input_img]

    is_inpaint_batch = False
    #if inpaint_mask_dir:
    #    inpaint_masks = shared.listfiles(inpaint_mask_dir)
    #    is_inpaint_batch = len(inpaint_masks) > 0
    #if is_inpaint_batch:
    #    print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

    #print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

    save_normally = output_dir == ''

    # Results are returned to the caller, not written by the webui saver.
    p.do_not_save_grid = True
    p.do_not_save_samples = not save_normally

    state.job_count = len(images) * p.n_iter

    generated_images = []
    for i, image in enumerate(images):
        state.job = f"{i+1} out of {len(images)}"
        if state.skipped:
            state.skipped = False

        if state.interrupted:
            break

        img = image #Image.open(image)
        # Use the EXIF orientation of photos taken by smartphones.
        img = ImageOps.exif_transpose(img)
        p.init_images = [img] * p.batch_size

        #if is_inpaint_batch:
        #    # try to find corresponding mask for an image using simple filename matching
        #    mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
        #    # if not found use first one ("same mask for all images" use-case)
        #    if not mask_image_path in inpaint_masks:
        #        mask_image_path = inpaint_masks[0]
        #    mask_image = Image.open(mask_image_path)
        #    p.image_mask = mask_image

        # Give custom scripts (e.g. ControlNet) a chance to run the job;
        # fall back to plain processing when none claims it.
        proc = modules.scripts.scripts_img2img.run(p, *args)
        if proc is None:
            proc = process_images(p)
        generated_images.append(proc.images[0])

        #for n, processed_image in enumerate(proc.images):
        #    filename = os.path.basename(image)

        #    if n > 0:
        #        left, right = os.path.splitext(filename)
        #        filename = f"{left}-{n}{right}"

        #    if not save_normally:
        #        os.makedirs(output_dir, exist_ok=True)
        #        if processed_image.mode == 'RGBA':
        #            processed_image = processed_image.convert("RGB")
        #        processed_image.save(os.path.join(output_dir, filename))

    return generated_images
|
||||
|
||||
def img2img(args_dict):
    """Programmatic version of the webui img2img entry point.

    Builds the (image, mask) pair according to ``args_dict['mode']`` (0 plain
    img2img, 1 sketch, 2 inpaint, 3 inpaint sketch, 4 inpaint with uploaded
    mask), runs the processing pipeline once, and returns
    (generated_images, generation_info_js, info_html, comments_html).
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    is_batch = args.mode == 5

    if args.mode == 0: # img2img
        image = args.init_img.convert("RGB")
        mask = None
    elif args.mode == 1: # img2img sketch
        image = args.sketch.convert("RGB")
        mask = None
    elif args.mode == 2: # inpaint
        image, mask = args.init_img_with_mask["image"], args.init_img_with_mask["mask"]
        # Merge the drawn mask with the image's transparency (alpha channel).
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        image = image.convert("RGB")
    elif args.mode == 3: # inpaint sketch
        image = args.inpaint_color_sketch
        orig = args.inpaint_color_sketch_orig or args.inpaint_color_sketch
        # Mask = pixels the user painted over (any channel differs from orig).
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - args.mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(args.mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        image = image.convert("RGB")
    elif args.mode == 4: # inpaint upload mask (the mode used by the t2v pipeline)
        #image = args.init_img_inpaint
        #mask = args.init_mask_inpaint

        image = args.init_img.convert("RGB")
        mask = args.mask_img.convert("L")
    else:
        image = None
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= args.denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=args.prompt,
        negative_prompt=args.n_prompt,
        styles=args.prompt_styles,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        init_images=[image],
        mask=mask,
        mask_blur=args.mask_blur,
        inpainting_fill=args.inpainting_fill,
        resize_mode=args.resize_mode,
        denoising_strength=args.denoising_strength,
        image_cfg_scale=args.image_cfg_scale,
        inpaint_full_res=args.inpaint_full_res,
        inpaint_full_res_padding=args.inpaint_full_res_padding,
        inpainting_mask_invert=args.inpainting_mask_invert,
        override_settings=override_settings,
    )

    # Attach the script runner so extensions (ControlNet etc.) participate.
    p.scripts = modules.scripts.scripts_img2img
    p.script_args = args.script_inputs

    #if shared.cmd_opts.enable_console_prompts:
    #    print(f"\nimg2img: {args.prompt}", file=shared.progress_print_out)

    if mask:
        p.extra_generation_params["Mask blur"] = args.mask_blur

    '''
    if is_batch:
        ...
        # assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        # process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args.script_inputs)
        # processed = Processed(p, [], p.seed, "")
    else:
        processed = modules.scripts.scripts_img2img.run(p, *args.script_inputs)
        if processed is None:
            processed = process_images(p)
    '''

    generated_images = process_img(p, image, None, '', args.script_inputs)
    processed = Processed(p, [], p.seed, "")
    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    #if opts.samples_log_stdout:
    #    print(generation_info_js)

    #if opts.do_not_show_images:
    #    processed.images = []

    #print(generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments))
    return generated_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
||||
def txt2img(args_dict):
    """Programmatic version of the webui txt2img entry point.

    Builds a StableDiffusionProcessingTxt2Img from *args_dict*, runs any
    custom scripts (falling back to plain processing), and returns
    (images, generation_info_js, info_html, comments_html).
    """
    args = SimpleNamespace(**args_dict)
    override_settings = create_override_settings_dict(args.override_settings)

    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=args.prompt,
        styles=args.prompt_styles,
        negative_prompt=args.n_prompt,
        seed=args.seed,
        subseed=args.subseed,
        subseed_strength=args.subseed_strength,
        seed_resize_from_h=args.seed_resize_from_h,
        seed_resize_from_w=args.seed_resize_from_w,
        seed_enable_extras=args.seed_enable_extras,
        sampler_name=sd_samplers.samplers[args.sampler_index].name,
        batch_size=args.batch_size,
        n_iter=args.n_iter,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        width=args.width,
        height=args.height,
        restore_faces=args.restore_faces,
        tiling=args.tiling,
        # High-res fix is intentionally disabled for video generation.
        #enable_hr=args.enable_hr,
        #denoising_strength=args.denoising_strength if enable_hr else None,
        #hr_scale=hr_scale,
        #hr_upscaler=hr_upscaler,
        #hr_second_pass_steps=hr_second_pass_steps,
        #hr_resize_x=hr_resize_x,
        #hr_resize_y=hr_resize_y,
        override_settings=override_settings,
    )

    # Attach the script runner so extensions (ControlNet etc.) participate.
    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args.script_inputs

    #if cmd_opts.enable_console_prompts:
    #    print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    processed = modules.scripts.scripts_txt2img.run(p, *args.script_inputs)

    if processed is None:
        processed = process_images(p)

    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    #if opts.samples_log_stdout:
    #    print(generation_info_js)

    #if opts.do_not_show_images:
    #    processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
||||
|
||||
import json
|
||||
def get_json(obj):
    """Round-trip *obj* through JSON into plain dicts/lists.

    Objects that are not JSON-serializable are replaced by their __dict__
    (or their str() representation when they have no __dict__).
    """
    serialized = json.dumps(obj, default=lambda o: getattr(o, '__dict__', str(o)))
    return json.loads(serialized)
|
||||
|
||||
def export_settings(*args):
    """Serialize the current UI settings into a pretty-printed JSON string.

    ``args[0]`` selects the processing mode (``'vid2vid'`` or ``'txt2vid'``);
    the remaining values are the raw UI component inputs. Enabled ControlNet
    units are flattened into a readable ``'ControlNets'`` list, and transient
    or unimportant values are stripped before dumping.

    Raises:
        ValueError: if ``args[0]`` is not a supported processing mode.
    """
    args_dict = args_to_dict(*args)
    if args[0] == 'vid2vid':
        args_dict = get_mode_args('v2v', args_dict)
    elif args[0] == 'txt2vid':
        args_dict = get_mode_args('t2v', args_dict)
    else:
        msg = f"Unsupported processing mode: '{args[0]}'"
        raise ValueError(msg)

    # convert CN params into a readable dict
    cn_remove_list = ['low_vram', 'is_ui', 'input_mode', 'batch_images', 'output_dir', 'loopback', 'image']

    args_dict['ControlNets'] = []
    for script_input in args_dict['script_inputs']:
        # duck-typed name check avoids a hard import dependency on the ControlNet extension
        if type(script_input).__name__ == 'UiControlNetUnit':
            cn_values_dict = get_json(script_input)
            if cn_values_dict['enabled']:
                for key in cn_remove_list:
                    cn_values_dict.pop(key, None)
                args_dict['ControlNets'].append(cn_values_dict)

    # remove unimportant values
    remove_list = [
        'save_frames_check', 'restore_faces', 'prompt_styles', 'mask_blur', 'inpainting_fill',
        'tiling', 'n_iter', 'batch_size', 'subseed', 'subseed_strength', 'seed_resize_from_h',
        'seed_resize_from_w', 'seed_enable_extras', 'resize_mode', 'inpaint_full_res',
        'inpaint_full_res_padding', 'inpainting_mask_invert', 'file', 'denoising_strength',
        'override_settings', 'script_inputs', 'init_img', 'mask_img', 'mode', 'init_video',
    ]
    for key in remove_list:
        args_dict.pop(key, None)

    return json.dumps(args_dict, indent=2, default=lambda o: getattr(o, '__dict__', str(o)))
|
@ -0,0 +1,275 @@
|
|||
import sys, os
|
||||
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from modules import devices, sd_samplers
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
try:
|
||||
from modules import lowvram
|
||||
except ImportError:
|
||||
lowvram = None
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
import gc
|
||||
import cv2
|
||||
import gradio as gr
|
||||
|
||||
import time
|
||||
import skimage
|
||||
import datetime
|
||||
|
||||
from scripts.core.flow_utils import RAFT_estimate_flow, RAFT_clear_memory, compute_diff_map
|
||||
from scripts.core import utils
|
||||
|
||||
class sdcn_anim_tmp:
    """Module-level scratch state shared across the vid2vid processing steps."""

    # progress counters for the two phases (note: 'prepear' spelling is kept —
    # it is referenced by this name elsewhere in the module)
    process_counter = 0
    prepear_counter = 0

    # OpenCV capture/writer handles for the source and result videos
    input_video = None
    output_video = None

    # frame state carried between generator iterations
    curr_frame = None
    prev_frame = None
    prev_frame_styled = None
    prev_frame_alpha_mask = None

    # source-video properties
    fps = None
    total_frames = None

    # buffers of pre-computed frames and optical flows
    prepared_frames = None
    prepared_next_flows = None
    prepared_prev_flows = None
    frames_prepared = False
def read_frame_from_video():
    """Read the next frame from the shared input video.

    Returns:
        The next frame converted to RGB, or None when the capture is closed
        or the stream is exhausted (the capture is released on exhaustion).
    """
    # Fix: the original left cur_frame unbound (NameError) when the capture
    # was not opened; default to None so the caller gets a clean end signal.
    cur_frame = None
    # Reading video file
    if sdcn_anim_tmp.input_video.isOpened():
        ret, cur_frame = sdcn_anim_tmp.input_video.read()
        if cur_frame is not None:
            # OpenCV decodes frames as BGR; the rest of the pipeline works in RGB
            cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
        else:
            # end of stream: release the capture so later calls fail fast
            sdcn_anim_tmp.input_video.release()

    return cur_frame
def get_cur_stat():
    """Build the short progress summary shown in the UI status line."""
    total = sdcn_anim_tmp.total_frames
    prepared = f'Frames prepared: {sdcn_anim_tmp.prepear_counter + 1} / {total}; '
    processed = f'Frames processed: {sdcn_anim_tmp.process_counter + 1} / {total}; '
    return prepared + processed
def clear_memory_from_sd():
    """Unload the Stable Diffusion model and free as much memory as possible.

    Best effort: undoes the model hijack, offloads weights via the lowvram
    module when it was importable, drops the model reference, then forces
    Python GC and a torch cache cleanup.
    """
    if shared.sd_model is not None:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        if lowvram:
            try:
                lowvram.send_everything_to_cpu()
            except Exception:
                # best-effort offload; failure here must not abort the cleanup
                pass
        del shared.sd_model
        shared.sd_model = None
    gc.collect()
    devices.torch_gc()
def start_process(*args):
    """Run the vid2vid animation pipeline as a (Gradio) generator.

    Alternates between preparing batches of frames (RAFT optical-flow
    estimation between consecutive frames) and stylizing the prepared frames
    with img2img, yielding UI updates (status text, preview images, output
    video path, button interactivity) after every step.

    State is kept in the module-level ``sdcn_anim_tmp`` holder so the
    prepare/process phases can hand frames and flows to each other.
    """
    processing_start_time = time.time()
    args_dict = utils.args_to_dict(*args)
    args_dict = utils.get_mode_args('v2v', args_dict)

    sdcn_anim_tmp.process_counter = 0
    sdcn_anim_tmp.prepear_counter = 0

    # Open the input video file
    sdcn_anim_tmp.input_video = cv2.VideoCapture(args_dict['file'].name)

    # Get useful info from the source video
    sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
    sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
    # two passes per frame (prepare + process), first frame handled separately
    loop_iterations = (sdcn_anim_tmp.total_frames-1) * 2

    # Create an output video file with the same fps, width, and height as the input video
    output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
    output_video_folder = os.path.splitext(output_video_name)[0]
    os.makedirs(os.path.dirname(output_video_name), exist_ok=True)

    if args_dict['save_frames_check']:
        os.makedirs(output_video_folder, exist_ok=True)

    def save_result_to_image(image, ind):
        # Optionally dump each processed RGB frame as a zero-padded PNG.
        if args_dict['save_frames_check']:
            cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    sdcn_anim_tmp.output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), sdcn_anim_tmp.fps, (args_dict['width'], args_dict['height']))

    # --- First frame: stylize directly (no flow available yet) ---
    curr_frame = read_frame_from_video()
    curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
    # 11 frame slots / 10 flow slots: one batch of 10 transitions plus the anchor frame
    sdcn_anim_tmp.prepared_frames = np.zeros((11, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
    sdcn_anim_tmp.prepared_next_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_prev_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
    sdcn_anim_tmp.prepared_frames[0] = curr_frame

    args_dict['init_img'] = Image.fromarray(curr_frame)
    utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
    processed_frames, _, _, _ = utils.img2img(args_dict)
    processed_frame = np.array(processed_frames[0])[...,:3]
    # keep the stylized frame's palette close to the source frame
    processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
    processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
    #print('Processed frame ', 0)

    sdcn_anim_tmp.curr_frame = curr_frame
    sdcn_anim_tmp.prev_frame = curr_frame.copy()
    sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
    utils.shared.is_interrupted = False

    save_result_to_image(processed_frame, 1)
    stat = get_cur_stat() + utils.get_time_left(1, loop_iterations, processing_start_time)
    yield stat, sdcn_anim_tmp.curr_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    # --- Main loop: alternate between preparing a batch and processing it ---
    for step in range(loop_iterations):
        if utils.shared.is_interrupted: break

        # re-read the UI args each step so live changes take effect
        args_dict = utils.args_to_dict(*args)
        args_dict = utils.get_mode_args('v2v', args_dict)

        occlusion_mask = None
        prev_frame = None
        curr_frame = sdcn_anim_tmp.curr_frame
        warped_styled_frame_ = gr.Image.update()
        processed_frame = gr.Image.update()

        prepare_steps = 10
        if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
            #clear_memory_from_sd()
            device = devices.get_optimal_device()

            curr_frame = read_frame_from_video()
            if curr_frame is not None:
                curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
                prev_frame = sdcn_anim_tmp.prev_frame.copy()

                # forward/backward optical flow and the derived occlusion mask
                next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, curr_frame, device=device)
                occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)

                cn = sdcn_anim_tmp.prepear_counter % 10
                if sdcn_anim_tmp.prepear_counter % 10 == 0:
                    # first slot of a fresh batch: store the anchor frame
                    sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
                sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
                sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
                sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
                #print('Prepared frame ', cn+1)

                sdcn_anim_tmp.prev_frame = curr_frame.copy()

            sdcn_anim_tmp.prepear_counter += 1
            if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
                sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
                curr_frame is None:
                # Remove RAFT from memory
                RAFT_clear_memory()
                sdcn_anim_tmp.frames_prepared = True
        else:
            # process frame
            sdcn_anim_tmp.frames_prepared = False

            cn = sdcn_anim_tmp.process_counter % 10
            curr_frame = sdcn_anim_tmp.prepared_frames[cn+1][...,:3]
            prev_frame = sdcn_anim_tmp.prepared_frames[cn][...,:3]
            next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
            prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]

            ### STEP 1: warp the previous styled frame along the flow and
            ### compute the alpha (occlusion) map of regions to regenerate
            alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled, args_dict)
            warped_styled_frame_ = warped_styled_frame.copy()

            #fl_w, fl_h = prev_flow.shape[:2]
            #prev_flow_n = prev_flow / np.array([fl_h,fl_w])
            #flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None] * 20, 0, 1)
            #alpha_mask = alpha_mask * flow_mask

            # optionally let the previous occlusion mask trail into this frame
            if sdcn_anim_tmp.process_counter > 0 and args_dict['occlusion_mask_trailing']:
                alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
            sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask

            # alpha_mask = np.round(alpha_mask * 8) / 8 #> 0.3
            alpha_mask = np.clip(alpha_mask, 0, 1)
            occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)

            # Fix duplication artifacts in the warped styled frame that occur
            # where the flow is zero only because there was nowhere to sample
            # the color from: fall back to the current source frame there.
            warped_styled_frame = curr_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

            # process current frame
            # TODO: convert args_dict into separate dict that stores only params necessery for img2img processing
            img2img_args_dict = args_dict #copy.deepcopy(args_dict)
            img2img_args_dict['denoising_strength'] = args_dict['processing_strength']
            if args_dict['step_1_processing_mode'] == 0: # Process full image then blend in occlusions
                img2img_args_dict['mode'] = 0
                img2img_args_dict['mask_img'] = None #Image.fromarray(occlusion_mask)
            elif args_dict['step_1_processing_mode'] == 1: # Inpaint occlusions
                img2img_args_dict['mode'] = 4
                img2img_args_dict['mask_img'] = Image.fromarray(occlusion_mask)
            else:
                raise Exception('Incorrect step 1 processing mode!')

            # blend the warped styled frame with the raw source frame as init image
            blend_alpha = args_dict['step_1_blend_alpha']
            init_img = warped_styled_frame * (1 - blend_alpha) + curr_frame * blend_alpha
            img2img_args_dict['init_img'] = Image.fromarray(np.clip(init_img, 0, 255).astype(np.uint8))
            img2img_args_dict['seed'] = args_dict['step_1_seed']
            utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
            processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
            processed_frame = np.array(processed_frames[0])[...,:3]

            # normalizing the colors
            processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
            # keep regenerated content only in occluded regions, warped elsewhere
            processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)

            #processed_frame = processed_frame * 0.94 + curr_frame * 0.06
            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

            ### STEP 2: optional low-strength full-frame img2img pass to fix artifacts
            if args_dict['fix_frame_strength'] > 0:
                img2img_args_dict = args_dict #copy.deepcopy(args_dict)
                img2img_args_dict['mode'] = 0
                img2img_args_dict['init_img'] = Image.fromarray(processed_frame)
                img2img_args_dict['mask_img'] = None
                img2img_args_dict['denoising_strength'] = args_dict['fix_frame_strength']
                img2img_args_dict['seed'] = args_dict['step_2_seed']
                utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
                processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
                processed_frame = np.array(processed_frames[0])
                processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)

            processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
            warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)

            # Write the frame to the output video
            frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
            frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
            sdcn_anim_tmp.output_video.write(frame_out)

            sdcn_anim_tmp.process_counter += 1
            #if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
            #    sdcn_anim_tmp.input_video.release()
            #    sdcn_anim_tmp.output_video.release()
            #    sdcn_anim_tmp.prev_frame = None

            save_result_to_image(processed_frame, sdcn_anim_tmp.process_counter + 1)

        stat = get_cur_stat() + utils.get_time_left(step+2, loop_iterations+1, processing_start_time)
        yield stat, curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)

    # --- Cleanup: free RAFT, close both video handles, restore buttons ---
    RAFT_clear_memory()

    sdcn_anim_tmp.input_video.release()
    sdcn_anim_tmp.output_video.release()

    curr_frame = gr.Image.update()
    occlusion_mask = gr.Image.update()
    warped_styled_frame_ = gr.Image.update()
    processed_frame = gr.Image.update()

    yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)