Compare commits

...

99 Commits
v0.4 ... main

Author SHA1 Message Date
Alexey Borsky 05e265ff3f
Fix typo in maintenance warning message 2026-04-02 01:00:50 +03:00
Alexey Borsky f2709f3990
Update README with maintenance warning and link
Added warning about repository maintenance status and provided a link to a modern version.
2026-04-02 01:00:31 +03:00
Alexey Borsky 2e5e09f3c3
Merge pull request #212 from doctorjei/main
This commit removes deprecated elements that break in some WebUI variants while maintaining compatibility.
2026-01-12 00:56:22 +03:00
Jeremiah Blanchard 052fdec082 -Removed deprecated usage in Gradio (remains compatible back to AUTOMATIC1111)
-Removed unused includes that break compatibility with other SD WebUI variants
-Added check for 'lowvram' module include to allow use in WebUI Force Neo / Classic
2026-01-11 00:24:35 -05:00
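The 'lowvram' check mentioned in the commit above can be done with a guarded import; the sketch below is illustrative only (module path and helper name are assumptions, not the exact code from this pull request):

# Hypothetical sketch: tolerate WebUI variants that do not ship a 'lowvram' module.
try:
    from modules import lowvram  # present in AUTOMATIC1111-style WebUIs
except ImportError:
    lowvram = None  # variants without the module can still load the extension

def lowvram_module_available() -> bool:
    # Use the module only when it actually exists.
    return lowvram is not None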
Alexey Borsky 4f07b5e6d2
Update LICENSE 2023-11-22 00:08:51 +03:00
Alexey Borsky ab1f1583b7
Merge pull request #191 from alexbofa/main
Update utils.py & txt2vid.py & vid2vid.py
2023-09-01 09:12:19 +03:00
alexbofa 48d8348de4
Update base_ui.py 2023-09-01 00:04:39 +03:00
alexbofa 3ccf7373ad
Update utils.py 2023-08-31 18:43:10 +03:00
alexbofa de1ff473f7
Update txt2vid.py 2023-08-31 18:42:48 +03:00
alexbofa 6a04d70241
Update vid2vid.py 2023-08-31 18:42:07 +03:00
Alexey Borsky 2e257bbfc3 better control for txt2vid 2023-05-30 19:49:03 +03:00
Alexey Borsky 111711fc7b fix type 2023-05-28 15:57:57 +03:00
Alexey Borsky 8bc0954ac8 remove exact version from scikit-image 2023-05-28 15:55:21 +03:00
Alexey Borsky c3c0972c0a Fixed issue #112 2023-05-27 15:38:14 +03:00
Alexey Borsky 89ba89d949 few comments for the future 2023-05-27 10:15:02 +03:00
Alexey Borsky ba3c17ef7e
Merge pull request #132 from nagolinc/main
add init_frame to txt2vid
2023-05-27 07:43:05 +03:00
Logan zoellner b26628a5be add init_frame to txt2vid 2023-05-24 11:39:31 -04:00
Alexey Borsky 3c36b8e7c5 fixed 'processing_strength' related issue 2023-05-20 04:15:16 +03:00
Alexey Borsky c3e4b42d98
Update readme.md 2023-05-18 05:33:51 +03:00
Alexey Borsky d97ead6ab8 issue #104 fix attempt 2 2023-05-16 11:45:12 +03:00
Alexey Borsky 00fbf01831 issue #104 fix attempt 2023-05-16 11:29:05 +03:00
Alexey Borsky 9e08b4c7d3 reference_only exception 2023-05-16 10:52:11 +03:00
Alexey Borsky 8d63dd5471
Update readme.md 2023-05-15 21:02:24 +03:00
Alexey Borsky e16f728512 better histogram matching. Issue #57 2023-05-14 17:25:17 +03:00
Alexey Borsky a35f446b69 Issue #92 fix 2023-05-14 16:59:16 +03:00
Alexey Borsky 9849e6389e critical fixes 2023-05-14 05:58:37 +03:00
Alexey Borsky cd400ea7b1 readme update 2023-05-14 04:49:28 +03:00
Alexey Borsky 5d1b65ee48
Merge pull request #87 from mariaWitch/patch-1
Fix Typo in Utils.py
2023-05-14 04:13:54 +03:00
Alexey Borsky 9ab15587d8
Merge branch 'main' into patch-1 2023-05-14 04:13:44 +03:00
Alexey Borsky e67bcb9264 issue #86 fix 2023-05-14 04:11:51 +03:00
Maria f1d47c954a Update base_ui.py
Typos
2023-05-14 04:11:51 +03:00
Maria 4ef0097751 Update utils.py
Fully solves the script issue by just populating txt2vid script inputs as well.
2023-05-14 04:11:51 +03:00
Maria 98fe91ceae Partial Fix in Base_ui 2023-05-14 04:11:51 +03:00
Maria 33897b7e2f Fix Typo in Utils.py
Partially Solves #86
2023-05-14 04:11:51 +03:00
Alexey Borsky 027faf1612 v0.9.2 2023-05-14 02:59:45 +03:00
Alexey Borsky 63efd3e0e3 v0.9.1 2023-05-13 23:27:27 +03:00
Maria b9d080ff4e
Merge branch 'main' into patch-1 2023-05-12 22:43:40 -04:00
Maria be9f056657
Update base_ui.py
Typos
2023-05-12 22:36:59 -04:00
Alexey Borsky eddf1a4c8f v0.9 2023-05-13 05:30:03 +03:00
Maria ca5fdb1151
Update utils.py
Fully solves the script issue by just populating txt2vid script inputs as well.
2023-05-12 17:50:05 -04:00
Maria 9c5611355d
Partial Fix in Base_ui 2023-05-12 17:47:21 -04:00
Maria 85cd721e83
Fix Typo in Utils.py
Partially Solves #86
2023-05-12 17:10:48 -04:00
Alexey Borsky 19ac530bb0 cn settings preview update 2023-05-12 03:10:48 +03:00
Alexey Borsky ea0b5e19fc best CN params preview 2023-05-12 02:40:13 +03:00
Alexey Borsky 46ae16e4cb Issue #76 fix 2023-05-12 00:02:48 +03:00
Alexey Borsky 14534d3174 v0.8 2023-05-11 07:34:29 +03:00
Alexey Borsky c987f645d6 Note 2023-05-08 01:19:50 +03:00
Alexey Borsky f33a508d3f better preview 2023-05-08 01:08:26 +03:00
Alexey Borsky 4adcaf026b v0.7 2023-05-07 22:32:30 +03:00
Alexey Borsky a553de32db comment legacy code 2023-05-06 04:13:48 +03:00
Alexey Borsky 31a8dc71d8 forbidden characters fix 2023-05-06 03:55:09 +03:00
Alexey Borsky ef5dca6d99 comment enable_console_prompts 2023-05-06 03:40:45 +03:00
Alexey Borsky eff3046d81 slight ui change 2023-05-06 02:54:30 +03:00
Alexey Borsky 115dbeacb1 text fix 2023-05-05 06:00:45 +03:00
Alexey Borsky 9b88a8f04e fixed issue with running ControlNet 2023-05-05 05:37:25 +03:00
Alexey Borsky 7e55f4d781 path references updated 2023-05-05 05:37:25 +03:00
Alexey Borsky 82d0679a51 minor fixes 2023-05-05 05:37:25 +03:00
Alexey Borsky bcfd6f994d Add only necessary RAFT code 2023-05-05 05:37:25 +03:00
Alexey Borsky 50724e8056 Delete link to RAFT 2023-05-05 05:37:25 +03:00
Alexey Borsky bb4353a264 v0.6 2023-05-05 05:37:25 +03:00
Alexey Borsky c34cfe2976
Update readme.md 2023-05-01 12:59:39 +03:00
Alexey Borsky dc2be7ba28
Update LICENSE 2023-05-01 12:58:44 +03:00
Alexey Borsky 0cb020b157 better examples 2023-04-22 01:55:41 +03:00
Alexey Borsky 3c69efe7a8 less ghosting + stable colors 2023-04-21 23:00:46 +03:00
Alexey Borsky f46b73b6f4
Update LICENSE 2023-04-21 22:21:33 +03:00
Alexey Borsky f073b25440 added -rb flag 2023-04-21 12:10:56 +03:00
Alexey Borsky 10430c5d0f readme update 2023-04-19 03:52:53 +03:00
Alexey Borsky 3e0de3b84d readme update 2023-04-19 03:52:01 +03:00
Alexey Borsky 12dcbef475 readme update 2023-04-19 03:30:48 +03:00
Alexey Borsky b112d6b2fc Merge branch 'CaptnSeraph-main' 2023-04-19 03:23:55 +03:00
CaptnSeraph bccf317b03 Update flow_utils.py
added background removal (not that its great)
2023-04-19 03:23:44 +03:00
Alexey Borsky 0b7fb7c252 code clean up 2023-04-19 03:23:44 +03:00
Alexey Borsky a8c298b7e4 readme update 2023-04-19 03:23:44 +03:00
Alexey Borsky 6a6ca687a5 readme update 2023-04-19 03:23:44 +03:00
Alexey Borsky 05260bce59 readme update 2023-04-19 03:23:44 +03:00
Alexey Borsky fac580c3b9 readme update 2023-04-19 03:23:44 +03:00
Alexey Borsky ae605b4299 added link to FloweR 2023-04-19 03:23:44 +03:00
Alexey Borsky 621e18ea56 Text to video script added 2023-04-19 03:23:44 +03:00
CaptnSeraph 7e051bbe13 Update compute_flow.py
added background removal (its still crap though)
2023-04-19 03:21:39 +03:00
CaptnSeraph cb737361d0 Update flow_utils.py
added background removal (not that its great)
2023-04-19 03:21:39 +03:00
CaptnSeraph 81c425a429 Delete vid2vid_interactive.py
dont want
2023-04-19 03:21:22 +03:00
theseraphim 22bcebb64e Update vid2vid_interactive.py
changed subprocess to call python instead of python3
2023-04-19 03:21:22 +03:00
theseraphim 22d909ac59 Update vid2vid_interactive.py
fixed cfg_scale message and added sampler selector
2023-04-19 03:21:22 +03:00
theseraphim 5d2f58151d Create vid2vid_interactive.py
created interactive version of vid2vid that includes the optional running of the computer_flow.py as a subprocess and populates the width and height based on user provided video.
also includes some error handling for a more graceful exit.
2023-04-19 03:21:22 +03:00
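The commit above outlines the approach taken in vid2vid_interactive.py: optionally run the flow pre-computation script as a subprocess and fill width and height from the user-provided video. A rough sketch of that idea, with assumed function and argument names rather than the actual file contents:

import subprocess
import cv2

def prepare_flow(input_video, flow_file, run_compute_flow=True):
    # Read frame size from the user-provided video instead of asking for it.
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise SystemExit(f'Could not open video: {input_video}')  # graceful exit
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    if run_compute_flow:
        # Optionally pre-compute optical flow by calling compute_flow.py as a subprocess
        # (using "python" rather than "python3", as a later commit changes it).
        subprocess.run(['python', 'compute_flow.py',
                        '-i', input_video, '-o', flow_file,
                        '-W', str(width), '-H', str(height)],
                       check=True)
    return width, height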
Alexey Borsky 13a39359ad code clean up 2023-04-19 03:01:35 +03:00
Alexey Borsky 0e8039a6d2 readme update 2023-04-19 03:00:14 +03:00
Alexey Borsky c78c9ae907 readme update 2023-04-19 02:54:29 +03:00
Alexey Borsky 729909e533 readme update 2023-04-19 02:36:25 +03:00
Alexey Borsky c0e876f4b5 readme update 2023-04-19 02:31:57 +03:00
Alexey Borsky 453bf74188 added link to FloweR 2023-04-19 02:31:00 +03:00
Alexey Borsky e5fd243585 Text to video script added 2023-04-19 02:26:53 +03:00
CaptnSeraph 085f865658
Update compute_flow.py
added background removal (its still crap though)
2023-04-16 02:26:10 +01:00
CaptnSeraph 87834fc192
Update flow_utils.py
added background removal (not that its great)
2023-04-16 02:25:47 +01:00
CaptnSeraph 3ce461efcc
Delete vid2vid_interactive.py
dont want
2023-04-16 02:25:13 +01:00
theseraphim 8287c45e31
Update vid2vid_interactive.py
changed subprocess to call python instead of python3
2023-04-13 02:09:09 +01:00
theseraphim 6dda2aaf9d
Update vid2vid_interactive.py
fixed cfg_scale message and added sampler selector
2023-04-13 02:06:52 +01:00
theseraphim e90e7e3690
Create vid2vid_interactive.py
created interactive version of vid2vid that includes the optional running of the computer_flow.py as a subprocess and populates the width and height based on user provided video.
also includes some error handling for a more graceful exit.
2023-04-13 01:48:20 +01:00
Alexey Borsky b65ad41836
Merge pull request #13 from theseraphim/patch-1
fix mp4 type error
2023-04-13 01:55:47 +03:00
theseraphim 24f77ed2fc
fix mp4 type error
MP4V to mp4v, fixes the 
"OpenCV: FFMPEG: tag 0x5634504d/'MP4V' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'"
error that ive been having. (which was only a warning... but clean is clean
2023-04-12 23:49:37 +01:00
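The fix above amounts to passing the lowercase fourcc tag when opening OpenCV's video writer; a minimal illustration (output path, fps and frame size are placeholders):

import cv2

# 'mp4v' (lowercase) is the tag FFMPEG accepts for the .mp4 container;
# uppercase 'MP4V' produces the warning quoted in the commit message.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('result.mp4', fourcc, 30.0, (1024, 576))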
48 changed files with 3446 additions and 160 deletions

5  .gitignore  vendored
@@ -1,3 +1,6 @@
__pycache__/
out/
result.mp4
videos/
FP_Res/
result.mp4
*.pth

191  FloweR/model.py  Normal file
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
import torch.functional as F

# Define the model
class FloweR(nn.Module):
    def __init__(self, input_size = (384, 384), window_size = 4):
        super(FloweR, self).__init__()

        self.input_size = input_size
        self.window_size = window_size

        # 2 channels for optical flow
        # 1 channel for occlusion mask
        # 3 channels for next frame prediction
        self.out_channels = 6

        #INPUT: 384 x 384 x 4 * 3

        ### DOWNSCALE ###
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(3 * self.window_size, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 384 x 384 x 128

        self.conv_block_2 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 192 x 192 x 128

        self.conv_block_3 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 96 x 96 x 128

        self.conv_block_4 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 48 x 48 x 128

        self.conv_block_5 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 24 x 24 x 128

        self.conv_block_6 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 12 x 12 x 128

        self.conv_block_7 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 6 x 6 x 128

        self.conv_block_8 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 3 x 3 x 128 - 9 input tokens

        ### Transformer part ###
        # To be done

        ### UPSCALE ###
        self.conv_block_9 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 6 x 6 x 128

        self.conv_block_10 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 12 x 12 x 128

        self.conv_block_11 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 24 x 24 x 128

        self.conv_block_12 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 48 x 48 x 128

        self.conv_block_13 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 96 x 96 x 128

        self.conv_block_14 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 192 x 192 x 128

        self.conv_block_15 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        ) # 384 x 384 x 128

        self.conv_block_16 = nn.Conv2d(128, self.out_channels, kernel_size=3, stride=1, padding='same')

    def forward(self, input_frames):
        if input_frames.size(1) != self.window_size:
            raise Exception(f'Shape of the input is not compatable. There should be exactly {self.window_size} frames in an input video.')

        h, w = self.input_size

        # batch, frames, height, width, colors
        input_frames_permuted = input_frames.permute((0, 1, 4, 2, 3))
        # batch, frames, colors, height, width

        in_x = input_frames_permuted.reshape(-1, self.window_size * 3, self.input_size[0], self.input_size[1])

        ### DOWNSCALE ###
        block_1_out = self.conv_block_1(in_x)        # 384 x 384 x 128
        block_2_out = self.conv_block_2(block_1_out) # 192 x 192 x 128
        block_3_out = self.conv_block_3(block_2_out) # 96 x 96 x 128
        block_4_out = self.conv_block_4(block_3_out) # 48 x 48 x 128
        block_5_out = self.conv_block_5(block_4_out) # 24 x 24 x 128
        block_6_out = self.conv_block_6(block_5_out) # 12 x 12 x 128
        block_7_out = self.conv_block_7(block_6_out) # 6 x 6 x 128
        block_8_out = self.conv_block_8(block_7_out) # 3 x 3 x 128

        ### UPSCALE ###
        block_9_out = block_7_out + self.conv_block_9(block_8_out)    # 6 x 6 x 128
        block_10_out = block_6_out + self.conv_block_10(block_9_out)  # 12 x 12 x 128
        block_11_out = block_5_out + self.conv_block_11(block_10_out) # 24 x 24 x 128
        block_12_out = block_4_out + self.conv_block_12(block_11_out) # 48 x 48 x 128
        block_13_out = block_3_out + self.conv_block_13(block_12_out) # 96 x 96 x 128
        block_14_out = block_2_out + self.conv_block_14(block_13_out) # 192 x 192 x 128
        block_15_out = block_1_out + self.conv_block_15(block_14_out) # 384 x 384 x 128

        block_16_out = self.conv_block_16(block_15_out) # 384 x 384 x (2 + 1 + 3)
        out = block_16_out.reshape(-1, self.out_channels, self.input_size[0], self.input_size[1])

        ### for future model training ###
        device = out.get_device()

        pred_flow = out[:,:2,:,:] * 255        # (-255, 255)
        pred_occl = (out[:,2:3,:,:] + 1) / 2   # [0, 1]
        pred_next = out[:,3:6,:,:]

        # Generate sampling grids
        # Create grid to upsample input
        '''
        d = torch.linspace(-1, 1, 8)
        meshx, meshy = torch.meshgrid((d, d))
        grid = torch.stack((meshy, meshx), 2)
        grid = grid.unsqueeze(0) '''

        grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
        flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
        flow_grid = flow_grid.unsqueeze(0).to(device=device)
        flow_grid = flow_grid + pred_flow

        flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
        flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
        # batch, flow_chanels, height, width
        flow_grid = flow_grid.permute(0, 2, 3, 1)
        # batch, height, width, flow_chanels

        previous_frame = input_frames_permuted[:, -1, :, :, :]
        sampling_mode = "bilinear" if self.training else "nearest"
        warped_frame = torch.nn.functional.grid_sample(previous_frame, flow_grid, mode=sampling_mode, padding_mode="reflection", align_corners=False)

        alpha_mask = torch.clip(pred_occl * 10, 0, 1) * 0.04
        pred_next = torch.clip(pred_next, -1, 1)
        warped_frame = torch.clip(warped_frame, -1, 1)
        next_frame = pred_next * alpha_mask + warped_frame * (1 - alpha_mask)

        res = torch.cat((pred_flow / 255, pred_occl * 2 - 1, next_frame), dim=1)

        # batch, channels, height, width
        res = res.permute((0, 2, 3, 1))
        # batch, height, width, channels
        return res

LICENSE
@@ -1,4 +1,4 @@
-License
+MIT License
 Copyright (c) 2023 Alexey Borsky
@@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-This repository can only be used for personal/research/non-commercial purposes.
-However, for commercial requests, please contact us directly at
-borsky.alexey@gmail.com

1  RAFT
@@ -1 +0,0 @@
Subproject commit aac9dd54726caf2cf81d8661b07663e220c5586d

29  RAFT/LICENSE  Normal file
@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2020, princeton-vl
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

91  RAFT/corr.py  Normal file
@@ -0,0 +1,91 @@
import torch
import torch.nn.functional as F
from RAFT.utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())

267  RAFT/extractor.py  Normal file
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x

144  RAFT/raft.py  Normal file
@@ -0,0 +1,144 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from RAFT.update import BasicUpdateBlock, SmallUpdateBlock
from RAFT.extractor import BasicEncoder, SmallEncoder
from RAFT.corr import CorrBlock, AlternateCorrBlock
from RAFT.utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8, device=img.device)
coords1 = coords_grid(N, H//8, W//8, device=img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions

139  RAFT/update.py  Normal file
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balence gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow

0  RAFT/utils/__init__.py  Normal file

246  RAFT/utils/augmentor.py  Normal file
@@ -0,0 +1,246 @@
import numpy as np
import random
import math
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid

132  RAFT/utils/flow_viz.py  Normal file
@@ -0,0 +1,132 @@
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
import numpy as np
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
Code follows the the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
Expects a two dimensional flow image of shape.
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)

137  RAFT/utils/frame_utils.py  Normal file
@@ -0,0 +1,137 @@
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
""" Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def writeFlow(filename,uv,v=None):
""" Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert(uv.ndim == 3)
assert(uv.shape[2] == 2)
u = uv[:,:,0]
v = uv[:,:,1]
else:
u = uv
assert(u.shape == v.shape)
height,width = u.shape
f = open(filename,'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width*nBands))
tmp[:,np.arange(width)*2] = u
tmp[:,np.arange(width)*2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
flow = flow[:,:,::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []

82  RAFT/utils/utils.py  Normal file
@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self,x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)

BIN  examples/bonefire_1.mp4   Normal file (binary, not shown)
BIN  examples/bonfire_1.gif    Normal file (binary, not shown; 902 KiB)
BIN  examples/cn_settings.png  Normal file (binary, not shown; 89 KiB)
BIN  examples/diamond_4.gif    Normal file (binary, not shown; 451 KiB)
BIN  examples/diamond_4.mp4    Normal file (binary, not shown)
BIN  examples/flower_1.gif     Normal file (binary, not shown; 1.6 MiB)
BIN  examples/flower_1.mp4     Normal file (binary, not shown)
BIN  examples/flower_11.mp4    Normal file (binary, not shown)
BIN  examples/girl_org.gif     Normal file (binary, not shown; 3.0 MiB)
BIN  examples/girl_to_jc.gif   Normal file (binary, not shown; 3.2 MiB)
BIN  examples/girl_to_jc.mp4   Normal file (binary, not shown)
BIN  examples/girl_to_wc.gif   Normal file (binary, not shown; 3.2 MiB)
BIN  examples/girl_to_wc.mp4   Normal file (binary, not shown)
BIN  examples/gold_1.gif       Normal file (binary, not shown; 1.3 MiB)
BIN  examples/gold_1.mp4       Normal file (binary, not shown)
BIN  examples/macaroni_1.gif   Normal file (binary, not shown; 1.1 MiB)
BIN  examples/macaroni_1.mp4   Normal file (binary, not shown)
BIN  examples/tree_2.gif       Normal file (binary, not shown; 1.5 MiB)
BIN  examples/tree_2.mp4       Normal file (binary, not shown)
BIN  examples/ui_preview.png   Normal file (binary, not shown; 865 KiB)

@@ -1,87 +0,0 @@
import numpy as np
import cv2

#RAFT dependencies
import sys
sys.path.append('RAFT/core')

from collections import namedtuple
import torch
import argparse
from raft import RAFT
from utils.utils import InputPadder

RAFT_model = None

def RAFT_estimate_flow(frame1, frame2, device = 'cuda'):
    global RAFT_model
    if RAFT_model is None:
        args = argparse.Namespace(**{
            'model': 'RAFT/models/raft-things.pth',
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        RAFT_model = torch.nn.DataParallel(RAFT(args))
        RAFT_model.load_state_dict(torch.load(args.model))

        RAFT_model = RAFT_model.module
        RAFT_model.to(device)
        RAFT_model.eval()

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1,2,0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1,2,0).cpu().numpy()

        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis = -1)

    return next_flow, prev_flow, occlusion_mask

def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
    h, w = cur_frame.shape[:2]

    next_flow = cv2.resize(next_flow, (w, h))
    prev_flow = cv2.resize(prev_flow, (w, h))

    flow_map = -next_flow.copy()
    flow_map[:,:,0] += np.arange(w)
    flow_map[:,:,1] += np.arange(h)[:,np.newaxis]

    warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST)
    warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST)

    # compute occlusion mask
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)
    occlusion_mask = fb_norm[..., None]

    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

    alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4, diff_mask_stl * 2)
    alpha_mask = alpha_mask.repeat(3, axis = -1)

    #alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
    alpha_mask = cv2.GaussianBlur(alpha_mask, (51,51), 5, cv2.BORDER_DEFAULT)

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled

20
install.py Normal file
View File

@ -0,0 +1,20 @@
import launch
import os
import pkg_resources
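# install or update every package listed in the extension's requirements.txt via the webui launcher helpers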
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
with open(req_file) as file:
for package in file:
try:
package = package.strip()
if '==' in package:
package_name, package_version = package.split('==')
installed_version = pkg_resources.get_distribution(package_name).version
if installed_version != package_version:
launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: changing {package_name} version from {installed_version} to {package_version}")
elif not launch.is_installed(package):
launch.run_pip(f"install {package}", f"SD-CN-Animation requirement: {package}")
except Exception as e:
print(e)
print(f'Warning: Failed to install {package}.')

View File

@ -36,7 +36,7 @@ def main(args):
cur_frame = cv2.resize(cur_frame, (W, H))
if prev_frame is not None:
next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, cur_frame)
next_flow, prev_flow, occlusion_mask, frame1_bg_removed, frame2_bg_removed = RAFT_estimate_flow(prev_frame, cur_frame, subtract_background=args.remove_background)
# write data into a file
flow_maps.resize(ind, axis=0)
@ -47,7 +47,10 @@ def main(args):
if args.visualize:
# show the last written frame - useful to catch any issue with the process
img_show = cv2.hconcat([cur_frame, occlusion_mask])
if args.remove_background:
img_show = cv2.hconcat([cur_frame, frame2_bg_removed, occlusion_mask])
else:
img_show = cv2.hconcat([cur_frame, occlusion_mask])
cv2.imshow('Out img', img_show)
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing
@ -60,12 +63,13 @@ def main(args):
if args.visualize: cv2.destroyAllWindows()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
parser.add_argument('-v', '--visualize', action='store_true', help='Show processed images and occlusion maps')
args = parser.parse_args()
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_video', help="Path to input video file", required=True)
parser.add_argument('-o', '--output_file', help="Path to output flow file. Stored in *.h5 format", required=True)
parser.add_argument('-W', '--width', help='Width of the generated flow maps', default=1024, type=int)
parser.add_argument('-H', '--height', help='Height of the generated flow maps', default=576, type=int)
parser.add_argument('-v', '--visualize', action='store_true', help='Show processed images and occlusion maps')
parser.add_argument('-rb', '--remove_background', action='store_true', help='Remove background of the image')
args = parser.parse_args()
main(args)
main(args)

139
old_scripts/flow_utils.py Normal file
View File

@ -0,0 +1,139 @@
import numpy as np
import cv2
# RAFT dependencies
import sys
sys.path.append('RAFT/core')
from collections import namedtuple
import torch
import argparse
from raft import RAFT
from utils.utils import InputPadder
RAFT_model = None
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
def background_subtractor(frame, fgbg):
fgmask = fgbg.apply(frame)
return cv2.bitwise_and(frame, frame, mask=fgmask)
def RAFT_estimate_flow(frame1, frame2, device='cuda', subtract_background=True):
global RAFT_model
if RAFT_model is None:
args = argparse.Namespace(**{
'model': 'RAFT/models/raft-things.pth',
'mixed_precision': True,
'small': False,
'alternate_corr': False,
'path': ""
})
RAFT_model = torch.nn.DataParallel(RAFT(args))
RAFT_model.load_state_dict(torch.load(args.model))
RAFT_model = RAFT_model.module
RAFT_model.to(device)
RAFT_model.eval()
if subtract_background:
frame1 = background_subtractor(frame1, fgbg)
frame2 = background_subtractor(frame2, fgbg)
with torch.no_grad():
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)
padder = InputPadder(frame1_torch.shape)
image1, image2 = padder.pad(frame1_torch, frame2_torch)
# estimate optical flow
_, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
_, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)
next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()
fb_flow = next_flow + prev_flow
fb_norm = np.linalg.norm(fb_flow, axis=2)
occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)
return next_flow, prev_flow, occlusion_mask, frame1, frame2
# ... rest of the file ...
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled):
h, w = cur_frame.shape[:2]
#print(np.amin(next_flow), np.amax(next_flow))
#exit()
fl_w, fl_h = next_flow.shape[:2]
# normalize flow
next_flow = next_flow / np.array([fl_h,fl_w])
prev_flow = prev_flow / np.array([fl_h,fl_w])
# remove low value noise (@alexfredo suggestion)
next_flow[np.abs(next_flow) < 0.05] = 0
prev_flow[np.abs(prev_flow) < 0.05] = 0
# resize flow
next_flow = cv2.resize(next_flow, (w, h))
next_flow = (next_flow * np.array([h,w])).astype(np.float32)
prev_flow = cv2.resize(prev_flow, (w, h))
prev_flow = (prev_flow * np.array([h,w])).astype(np.float32)
# Generate sampling grids
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
flow_grid = flow_grid.unsqueeze(0)
flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
flow_grid = flow_grid.permute(0, 2, 3, 1)
prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()
warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy()
#warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
#warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
# compute occlusion mask
fb_flow = next_flow + prev_flow
fb_norm = np.linalg.norm(fb_flow, axis=2)
occlusion_mask = fb_norm[..., None]
diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)
diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4, diff_mask_stl * 2)
alpha_mask = alpha_mask.repeat(3, axis = -1)
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
alpha_mask = cv2.GaussianBlur(alpha_mask, (51,51), 5, cv2.BORDER_REFLECT)
alpha_mask = np.clip(alpha_mask, 0, 1)
return alpha_mask, warped_frame_styled
def frames_norm(occl): return occl / 127.5 - 1
def flow_norm(flow): return flow / 255
def occl_norm(occl): return occl / 127.5 - 1
def flow_renorm(flow): return flow * 255
def occl_renorm(occl): return (occl + 1) * 127.5

133
old_scripts/readme.md Normal file
View File

@ -0,0 +1,133 @@
# SD-CN-Animation
This project allows you to automate the video stylization task using StableDiffusion and ControlNet. In contrast to other current text2video methods, it also allows you to generate completely new videos from text at any resolution and length, using any Stable Diffusion model as a backbone, including custom ones. It uses the '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and to create an inpainting mask that is used to generate the next frame. In text to video mode it relies on the 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
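At a high level, vid2vid warps the previous stylized frame along the estimated optical flow, builds an occlusion mask from the forward-backward flow inconsistency, and lets Stable Diffusion regenerate the masked regions. Below is a minimal sketch of that loop; `read_frames`, `sd_img2img` and `write_frame` are hypothetical placeholders, while `RAFT_estimate_flow` and `compute_diff_map` are the helpers from flow_utils.py:
```
prev_frame, prev_styled = None, None
for cur_frame in read_frames(INPUT_VIDEO):                 # hypothetical frame reader
    if prev_frame is None:
        styled = sd_img2img(cur_frame, PROMPT)             # stylize the first frame directly
    else:
        next_flow, prev_flow, _ = RAFT_estimate_flow(prev_frame, cur_frame)
        mask, warped_styled = compute_diff_map(
            next_flow, prev_flow, prev_frame, cur_frame, prev_styled)
        # regenerate only the occluded/changed regions on top of the warped frame
        styled = sd_img2img(warped_styled, PROMPT, mask=mask)
    write_frame(styled)
    prev_frame, prev_styled = cur_frame, styled
```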
### Video to Video Examples:
<!--
[![IMAGE_ALT](https://img.youtube.com/vi/j-0niEMm6DU/0.jpg)](https://youtu.be/j-0niEMm6DU)
This script can also be used to swap the person in the video, as in this example: https://youtube.com/shorts/be93_dIeZWU
-->
</table>
<table class="center">
<tr>
<td><img src="examples/girl_org.gif" raw=true></td>
<td><img src="examples/girl_to_jc.gif" raw=true></td>
<td><img src="examples/girl_to_wc.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">Original video</td>
<td width=33% align="center">"Jessica Chastain"</td>
<td width=33% align="center">"Watercolor painting"</td>
</tr>
</table>
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropped, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
### Text to Video Examples:
</table>
<table class="center">
<tr>
<td><img src="examples/flower_1.gif" raw=true></td>
<td><img src="examples/bonfire_1.gif" raw=true></td>
<td><img src="examples/diamond_4.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">"close up of a flower"</td>
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
<td width=33% align="center">"close up of a diamond laying on the table"</td>
</tr>
<tr>
<td><img src="examples/macaroni_1.gif" raw=true></td>
<td><img src="examples/gold_1.gif" raw=true></td>
<td><img src="examples/tree_2.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">"close up of macaroni on the plate"</td>
<td width=33% align="center">"close up of golden sphere"</td>
<td width=33% align="center">"a tree standing in the winter forest"</td>
</tr>
</table>
All the examples shown here were originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. The actual prompts followed the format "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"; only the 'subject' part is listed in the table above.
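In code form, the full prompt for a given subject from the table could be assembled like this (illustrative only):
```
subject = "close up of a flower"
prompt = f"RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
```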
## Dependencies
To install all the necessary dependencies, run this command:
```
pip install opencv-python opencv-contrib-python numpy tqdm h5py scikit-image
```
You have to set up the RAFT repository as described here: https://github.com/princeton-vl/RAFT . Basically it just comes down to running "./download_models.sh" in the RAFT folder to download the models.
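Assuming you clone RAFT next to the scripts (they append 'RAFT/core' to sys.path and expect the weights under 'RAFT/models/'), the setup might look like this:
```
git clone https://github.com/princeton-vl/RAFT.git
cd RAFT
./download_models.sh
```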
## Running the scripts
This script works on top of the [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via its API. To run this script you have to set that up first. You should also have the [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed, with the control_hed-fp16 model. If you have web-ui with ControlNet working correctly, you also have to allow the API to work with ControlNet. To do so, go to the web-ui settings -> ControlNet tab -> set the "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more than 2 -> press "Apply settings".
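A quick way to check that the API (and the ControlNet routes) are reachable before running the scripts, assuming web-ui runs at the default localhost:7860 address, is something like:
```
import requests
# both requests should return HTTP 200 once web-ui is started with the --api flag
print(requests.get("http://localhost:7860/sdapi/v1/sd-models").status_code)
print(requests.get("http://localhost:7860/controlnet/model_list").status_code)
```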
### Video To Video
#### Step 1.
To process a video, you first need to precompute the optical flow data before running web-ui, using this command:
```
python3 compute_flow.py -i "path to your video" -o "path to output file with *.h5 format" -v -W width_of_the_flow_map -H height_of_the_flow_map
```
The main reason to do this step separately is to save precious GPU memory that will be useful to generate better quality images. Choose the W and H parameters as high as your GPU can handle, keeping the proportions of the original video resolution. Do not worry if they are higher or lower than the processing resolution; flow maps will be scaled accordingly at the processing stage. This will generate quite a large file that may take up to several gigabytes of disk space even for a minute-long video. If you want to process a long video, consider splitting it into several parts beforehand.
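For example, a 1024x576 flow file for a clip named 'input.mp4' could be generated with:
```
python3 compute_flow.py -i input.mp4 -o flow.h5 -v -W 1024 -H 576
```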
#### Step 2.
Run web-ui with the '--api' flag. It is also better to use the '--xformers' flag, as you will want the highest resolution possible and the xformers memory optimization helps greatly with that.
```
bash webui.sh --xformers --api
```
#### Step 3.
Go to the **vid2vid.py** file and change the main parameters (INPUT_VIDEO, FLOW_MAPS, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project, as shown in the example after the command below. The FLOW_MAPS parameter should contain a path to the flow file that you generated at the first step. The script is pretty simple, so you may change other parameters as well, although I would recommend leaving them as they are the first time. Finally, run the script with the command:
```
python3 vid2vid.py
```
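For reference, the parameter block at the top of **vid2vid.py** might end up looking like this (the values below are purely illustrative):
```
INPUT_VIDEO = "input.mp4"     # source video
FLOW_MAPS = "flow.h5"         # file produced by compute_flow.py at step 1
OUTPUT_VIDEO = "result.mp4"   # where the stylized video is written
PROMPT = "RAW photo, marble statue, 8k uhd, dslr, soft lighting, high quality, film grain"
N_PROMPT = "text, letters, logo, worst quality, low quality, jpeg artifacts, blurry"
w, h = 1024, 576              # processing resolution
```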
### Text To Video
This method is still in development and works on top of Stable Diffusion and 'FloweR' - an optical flow reconstruction method that is also at an early development stage. Do not expect much from it, as it is more of a proof of concept than a complete solution.
#### Step 1.
Download 'FloweR_0.1.pth' model from here: [Google drive link](https://drive.google.com/file/d/1WhzoVIw6Kdg4EjfK9LaTLqFm5dF-IJ7F/view?usp=share_link) and place it in the 'FloweR' folder.
#### Step 2.
Same as in the vid2vid case, run web-ui with the '--api' flag. It is also better to use the '--xformers' flag, as you will want the highest resolution possible and the xformers memory optimization helps greatly with that.
```
bash webui.sh --xformers --api
```
#### Step 3.
Go to the **txt2vid.py** file and change the main parameters (OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. Again, the script is simple, so you may change other parameters if you want to. Finally, run the script with the command:
```
python3 txt2vid.py
```
## Last version changes: v0.5
* Fixed an issue with the wrong direction of an optical flow applied to an image.
* Added text to video mode within txt2vid.py script. Make sure to update new dependencies for this script to work!
* Added a threshold for the optical flow before processing the frame, to remove white noise that might appear, as suggested by [@alexfredo](https://github.com/alexfredo).
* Background removal at the flow computation stage, implemented by [@CaptnSeraph](https://github.com/CaptnSeraph); it should reduce the ghosting effect in most of the videos processed with the vid2vid script.
<!--
## Last version changes: v0.6
* Added separate flag '-rb' for background removal process at the flow computation stage in the compute_flow.py script.
* Added flow normalization before rescaling it, so the magnitude of the flow is computed correctly at different resolutions.
* Less ghosting and color change in vid2vid mode
-->
<!--
## Potential improvements
There are several ways overall quality of animation may be improved:
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
* The quality of flow estimation might be greatly improved with a proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
-->
## Licence
This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact me directly at borsky.alexey@gmail.com

208
old_scripts/txt2vid.py Normal file
View File

@ -0,0 +1,208 @@
import requests
import cv2
import base64
import numpy as np
from tqdm import tqdm
import os
import sys
sys.path.append('FloweR/')
sys.path.append('RAFT/core')
import torch
from model import FloweR
from utils import flow_viz
from flow_utils import *
import skimage
import datetime
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'
PROMPT = "people looking at flying robots. Future. People looking to the sky. Stars in the background. Dramatic light, Cinematic light. Soft lighting, high quality, film grain."
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
w,h = 768, 512 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.
SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations
PROCESSING_STRENGTH = 0.85
FIX_STRENGTH = 0.35
CFG_SCALE = 5.5
APPLY_TEMPORALNET = False
APPLY_COLOR = False
VISUALIZE = True
DEVICE = 'cuda'
def to_b64(img):
img_cliped = np.clip(img, 0, 255).astype(np.uint8)
_, buffer = cv2.imencode('.png', img_cliped)
b64img = base64.b64encode(buffer).decode("utf-8")
return b64img
class controlnetRequest():
def __init__(self, b64_init_img = None, b64_prev_img = None, b64_color_img = None, ds = 0.35, w=w, h=h, mask = None, seed=-1, mode='img2img'):
self.url = f"http://localhost:7860/sdapi/v1/{mode}"
self.body = {
"init_images": [b64_init_img],
"mask": mask,
"mask_blur": 0,
"inpainting_fill": 1,
"inpainting_mask_invert": 0,
"prompt": PROMPT,
"negative_prompt": N_PROMPT,
"seed": seed,
"subseed": -1,
"subseed_strength": 0,
"batch_size": 1,
"n_iter": 1,
"steps": 15,
"cfg_scale": CFG_SCALE,
"denoising_strength": ds,
"width": w,
"height": h,
"restore_faces": False,
"eta": 0,
"sampler_index": "DPM++ 2S a",
"control_net_enabled": True,
"alwayson_scripts": {
"ControlNet":{"args": []}
},
}
if APPLY_TEMPORALNET:
self.body["alwayson_scripts"]["ControlNet"]["args"].append({
"input_image": b64_prev_img,
"module": "none",
"model": "diff_control_sd15_temporalnet_fp16 [adc6bd97]",
"weight": 0.65,
"resize_mode": "Just Resize",
"lowvram": False,
"processor_res": 512,
"guidance_start": 0,
"guidance_end": 0.65,
"guessmode": False
})
if APPLY_COLOR:
self.body["alwayson_scripts"]["ControlNet"]["args"].append({
"input_image": b64_prev_img,
"module": "color",
"model": "t2iadapter_color_sd14v1 [8522029d]",
"weight": 0.65,
"resize_mode": "Just Resize",
"lowvram": False,
"processor_res": 512,
"guidance_start": 0,
"guidance_end": 0.65,
"guessmode": False
})
def sendRequest(self):
# Request to web-ui
data_js = requests.post(self.url, json=self.body).json()
# Convert the byte array to a NumPy array
image_bytes = base64.b64decode(data_js["images"][0])
np_array = np.frombuffer(image_bytes, dtype=np.uint8)
# Convert the NumPy array to a cv2 image
out_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
return out_image
if VISUALIZE: cv2.namedWindow('Out img')
# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), 15, (w, h))
prev_frame = None
prev_frame_styled = None
# Instantiate the model
model = FloweR(input_size = (h, w))
model.load_state_dict(torch.load('FloweR/FloweR_0.1.1.pth'))
# Move the model to the device
model = model.to(DEVICE)
init_frame = controlnetRequest(mode='txt2img', ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()
output_video.write(init_frame)
prev_frame = init_frame
clip_frames = np.zeros((4, h, w, 3), dtype=np.uint8)
color_shift = np.zeros((0, 3))
color_scale = np.zeros((0, 3))
for ind in tqdm(range(450)):
clip_frames = np.roll(clip_frames, -1, axis=0)
clip_frames[-1] = prev_frame
clip_frames_torch = frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
with torch.no_grad():
pred_data = model(clip_frames_torch.unsqueeze(0))[0]
pred_flow = flow_renorm(pred_data[...,:2]).cpu().numpy()
pred_occl = occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
flow_map = pred_flow.copy()
flow_map[:,:,0] += np.arange(w)
flow_map[:,:,1] += np.arange(h)[:,np.newaxis]
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_CUBIC, borderMode = cv2.BORDER_REFLECT_101)
out_image = warped_frame.copy()
out_image = controlnetRequest(
b64_init_img = to_b64(out_image),
b64_prev_img = to_b64(prev_frame),
b64_color_img = to_b64(warped_frame),
mask = to_b64(pred_occl),
ds=PROCESSING_STRENGTH, w=w, h=h).sendRequest()
out_image = controlnetRequest(
b64_init_img = to_b64(out_image),
b64_prev_img = to_b64(prev_frame),
b64_color_img = to_b64(warped_frame),
mask = None,
ds=FIX_STRENGTH, w=w, h=h).sendRequest()
# This step is necessary to reduce the color drift of the image that some models may cause
out_image = skimage.exposure.match_histograms(out_image, init_frame, multichannel=True, channel_axis=-1)
output_video.write(out_image)
if SAVE_FRAMES:
if not os.path.isdir('out'): os.makedirs('out')
cv2.imwrite(f'out/{ind+1:05d}.png', out_image)
pred_flow_img = flow_viz.flow_to_image(pred_flow)
frames_img = cv2.hconcat(list(clip_frames))
data_img = cv2.hconcat([pred_flow_img, pred_occl, warped_frame, out_image])
cv2.imshow('Out img', cv2.vconcat([frames_img, data_img]))
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing
prev_frame = out_image.copy()
# Release the input and output video files
output_video.release()
# Close all windows
if VISUALIZE: cv2.destroyAllWindows()

View File

@ -8,22 +8,25 @@ import os
import h5py
from flow_utils import compute_diff_map
INPUT_VIDEO = "input.mp4"
FLOW_MAPS = "flow.h5"
OUTPUT_VIDEO = "result.mp4"
import skimage
import datetime
PROMPT = "marble statue"
INPUT_VIDEO = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.mp4"
FLOW_MAPS = "/media/alex/ded3efe6-5825-429d-ac89-7ded676a2b6d/media/Peter_Gabriel/pexels-monstera-5302599-4096x2160-30fps.h5"
OUTPUT_VIDEO = f'videos/result_{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.mp4'
PROMPT = "Underwater shot Peter Gabriel with closed eyes in Peter Gabriel's music video. 80's music video. VHS style. Dramatic light, Cinematic light. RAW photo, 8k uhd, dslr, soft lighting, high quality, film grain."
N_PROMPT = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
w,h = 1152, 640 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.
w,h = 1088, 576 # Width and height of the processed image. Note that actual image processed would be a W x H resolution.
START_FROM_IND = 0 # index of a frame to start a processing from. Might be helpful with long animations where you need to restart the script multiple times
SAVE_FRAMES = True # saves individual frames into 'out' folder if set True. Again might be helpful with long animations
PROCESSING_STRENGTH = 0.85
PROCESSING_STRENGTH = 0.95
BLUR_FIX_STRENGTH = 0.15
APPLY_HED = False
APPLY_CANNY = True
APPLY_HED = True
APPLY_CANNY = False
APPLY_DEPTH = False
GUESSMODE = False
@ -72,12 +75,12 @@ class controlnetRequest():
"input_image": b64_hed_img,
"module": "hed",
"model": "control_hed-fp16 [13fee50b]",
"weight": 0.85,
"weight": 0.65,
"resize_mode": "Just Resize",
"lowvram": False,
"processor_res": 512,
"guidance_start": 0,
"guidance_end": 0.85,
"guidance_end": 0.65,
"guessmode": GUESSMODE
})
@ -136,10 +139,11 @@ fps = int(input_video.get(cv2.CAP_PROP_FPS))
total_frames = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT))
# Create an output video file with the same fps, width, and height as the input video
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'MP4V'), fps, (w, h))
output_video = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
prev_frame = None
prev_frame_styled = None
#init_image = None
# reading flow maps in a stream manner
with h5py.File(FLOW_MAPS, 'r') as f:
@ -172,6 +176,8 @@ with h5py.File(FLOW_MAPS, 'r') as f:
alpha_img = out_image.copy()
out_image_ = out_image.copy()
warped_styled = out_image.copy()
#init_image = out_image.copy()
else:
# Resize the frame to proper resolution
frame = cv2.resize(cur_frame, (w, h))
@ -188,13 +194,18 @@ with h5py.File(FLOW_MAPS, 'r') as f:
alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
alpha_img = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
# normalizing the colors
out_image = skimage.exposure.match_histograms(out_image, frame, multichannel=False, channel_axis=-1)
out_image = out_image.astype(float) * alpha_mask + warped_styled.astype(float) * (1 - alpha_mask)
out_image_ = (out_image * 0.65 + warped_styled * 0.35)
#out_image = skimage.exposure.match_histograms(out_image, prev_frame, multichannel=True, channel_axis=-1)
#out_image_ = (out_image * 0.65 + warped_styled * 0.35)
# Bluring issue fix via additional processing
out_image_fixed = controlnetRequest(to_b64(out_image_), to_b64(frame), BLUR_FIX_STRENGTH, w, h, mask = None, seed=8888).sendRequest()
out_image_fixed = controlnetRequest(to_b64(out_image), to_b64(frame), BLUR_FIX_STRENGTH, w, h, mask = None, seed=8888).sendRequest()
# Write the frame to the output video
frame_out = np.clip(out_image_fixed, 0, 255).astype(np.uint8)
@ -202,8 +213,11 @@ with h5py.File(FLOW_MAPS, 'r') as f:
if VISUALIZE:
# show the last written frame - useful to catch any issue with the process
img_show = cv2.hconcat([frame_out, alpha_img])
cv2.imshow('Out img', img_show)
warped_styled = np.clip(warped_styled, 0, 255).astype(np.uint8)
img_show_top = cv2.hconcat([frame, warped_styled])
img_show_bot = cv2.hconcat([frame_out, alpha_img])
cv2.imshow('Out img', cv2.vconcat([img_show_top, img_show_bot]))
cv2.setWindowTitle("Out img", str(ind+1))
if cv2.waitKey(1) & 0xFF == ord('q'): exit() # press Q to close the script while processing

125
readme.md
View File

@ -1,49 +1,92 @@
# SD-CN-Animation
This script allows you to automate the video stylization task using StableDiffusion and ControlNet. It uses a simple optical flow estimation algorithm to keep the animation stable and create an inpainting mask that is used to generate the next frame. Here is an example of a video made with this script:
> [!Warning]
> This repository is no longer maintained. If you are looking for a more modern, ComfyUI-based version of this tool, check out this repository instead: https://github.com/pxl-pshr/ComfyUI-SD-CN-Animation
[![IMAGE_ALT](https://img.youtube.com/vi/j-0niEMm6DU/0.jpg)](https://youtu.be/j-0niEMm6DU)
This project allows you to automate the video stylization task using StableDiffusion and ControlNet. In contrast to other current text2video methods, it also allows you to generate completely new videos from text at any resolution and length, using any Stable Diffusion model as a backbone, including custom ones. It uses the '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and to create an occlusion mask that is used to generate the next frame. In text to video mode it relies on the 'FloweR' method (work in progress) that predicts optical flow from the previous frames.
This script can also be used to swap the person in the video, as in this example: https://youtube.com/shorts/be93_dIeZWU
![sd-cn-animation ui preview](examples/ui_preview.png)
## Dependencies
To install all the necessary dependencies, run this command:
**In vid2vid mode do not forget to activate a ControlNet model to achieve better results. Without it the resulting video might be quite choppy. Do not put any images into ControlNet; the frames are passed automatically from the video.**
Here are CN parameters that seem to give the best results so far:
![sd-cn-animation cn params](examples/cn_settings.png)
### Video to Video Examples:
</table>
<table class="center">
<tr>
<td><img src="examples/girl_org.gif" raw=true></td>
<td><img src="examples/girl_to_jc.gif" raw=true></td>
<td><img src="examples/girl_to_wc.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">Original video</td>
<td width=33% align="center">"Jessica Chastain"</td>
<td width=33% align="center">"Watercolor painting"</td>
</tr>
</table>
Examples presented are generated at 1024x576 resolution using the 'realisticVisionV13_v13' model as a base. They were cropped, downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder.
### Text to Video Examples:
</table>
<table class="center">
<tr>
<td><img src="examples/flower_1.gif" raw=true></td>
<td><img src="examples/bonfire_1.gif" raw=true></td>
<td><img src="examples/diamond_4.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">"close up of a flower"</td>
<td width=33% align="center">"bonfire near the camp in the mountains at night"</td>
<td width=33% align="center">"close up of a diamond laying on the table"</td>
</tr>
<tr>
<td><img src="examples/macaroni_1.gif" raw=true></td>
<td><img src="examples/gold_1.gif" raw=true></td>
<td><img src="examples/tree_2.gif" raw=true></td>
</tr>
<tr>
<td width=33% align="center">"close up of macaroni on the plate"</td>
<td width=33% align="center">"close up of golden sphere"</td>
<td width=33% align="center">"a tree standing in the winter forest"</td>
</tr>
</table>
All the examples shown here were originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. The actual prompts followed the format "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"; only the 'subject' part is listed in the table above.
## Installing the extension
To install the extension, go to the 'Extensions' tab in [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then go to the 'Install from URL' tab. In the 'URL for extension's git repository' field enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation.git'. Leave the 'Local directory name' field empty. Then just press the 'Install' button. Restart web-ui and a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.
## Known issues
* If you see an error like ```IndexError: list index out of range```, try restarting webui; it should fix the issue. If the issue is still prevalent, try to uninstall and reinstall scikit-image==0.19.2 with the --no-cache-dir flag, like this:
```
pip install opencv-python opencv-contrib-python numpy tqdm h5py
pip uninstall scikit-image
pip install scikit-image==0.19.2 --no-cache-dir
```
You have to set up the RAFT repository as described here: https://github.com/princeton-vl/RAFT . Basically it just comes down to running "./download_models.sh" in the RAFT folder to download the models.
* The extension might work incorrectly if 'Apply color correction to img2img results to match original colors.' option is enabled. Make sure to disable it in 'Settings' tab -> 'Stable Diffusion' section.
* If you have an error like 'Need to enable queue to use generators.', please update webui to the latest version. Beware that only [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is fully supported.
* The extension is not compatible with Macs. If the extension does work for you, or you know how to make it compatible, please open a new discussion.
## Running the script
This script works on top of the [Automatic1111/web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) interface via its API. To run this script you have to set that up first. You should also have the [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension installed, with the control_hed-fp16 model. If you have web-ui with ControlNet working correctly, you also have to allow the API to work with ControlNet. To do so, go to the web-ui settings -> ControlNet tab -> set the "Allow other script to control this extension" checkbox to active and set "Multi ControlNet: Max models amount (requires restart)" to more than 2 -> press "Apply settings".
## Last version changes: v0.9
* Fixed issues #69, #76, #91, #92.
* Fixed an issue in vid2vid mode when an occlusion mask computed from the optical flow may include unnecessary parts (where flow is non-zero).
* Added 'Extra params' in vid2vid mode for more fine-grain controls of the processing pipeline.
* Better default parameters set for vid2vid pipeline.
* In txt2vid mode after the first frame is generated the seed is now automatically set to -1 to prevent blurring issues.
* Added an option to save resulting frames into a folder alongside the video.
* Added ability to export current parameters in a human readable form as a json.
* Interpolation mode in the flow-applying stage is set to nearest to reduce image blurring over time.
* Added ControlNet to txt2vid mode, as well as fixing issue #86, thanks to [@mariaWitch](https://github.com/mariaWitch)
* Fixed a major issue where ControlNet used the wrong input images. Because of this, vid2vid results were much worse than they should have been.
* Text to video mode now supports video as a guidance for ControlNet. This allows creating much stronger video stylizations.
### Step 1.
To process the video, first of all you would need to precompute optical flow data before running web-ui with this command:
```
python3 compute_flow.py -i {path to your video} -o {path to output *.h5} -v -W {width of the flow map} -H {height of the flow map}
```
The main reason to do this step separately is to save precious GPU memory that will be useful to generate better quality images. Choose the W and H parameters as high as your GPU can handle, keeping the proportions of the original video resolution. Do not worry if they are higher or lower than the processing resolution; flow maps will be scaled accordingly at the processing stage. This will generate quite a large file that may take up to several gigabytes of disk space even for a minute-long video. If you want to process a long video, consider splitting it into several parts beforehand.
### Step 2.
Run web-ui with '--api' flag. It is also better to use '--xformers' flag, as you would need to have the highest resolution possible and using xformers memory optimization will greatly help.
```
bash webui.sh --xformers --api
```
### Step 3.
Go to the **vid2vid.py** file and change the main parameters (INPUT_VIDEO, FLOW_MAPS, OUTPUT_VIDEO, PROMPT, N_PROMPT, W, H) to the ones you need for your project. The FLOW_MAPS parameter should contain a path to the flow file that you generated at the first step. The script is pretty simple, so you may change other parameters as well, although I would recommend leaving them as they are the first time. Finally, run the script with the command:
```
python3 vid2vid.py
```
## Last version changes: v0.4
* Fixed issue with extreme blur accumulating at the static parts of the video.
* The order of processing was changed to achieve the best quality at different domains.
* Optical flow computation isolated into a separate script for better GPU memory management. Check out the instruction for a new processing pipeline.
## Potential improvements
There are several ways overall quality of animation may be improved:
* You may use a separate processing for each camera position to get a more consistent style of the characters and less ghosting.
* Because the quality of the video depends on how good optical flow was estimated it might be beneficial to use high frame rate video as a source, so it would be easier to guess the flow properly.
* The quality of flow estimation might be greatly improved with a proper flow estimation model like this one: https://github.com/autonomousvision/unimatch .
## Licence
This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact me directly at borsky.alexey@gmail.com
<!--
* ControlNet preprocessors like "reference_only", "reference_adain", "reference_adain+attn" are not reset between video frames, to retain the ability to control the style of the video.
* Fixed an issue because of which the 'processing_strength' UI parameter did not actually affect the denoising strength at the first processing step.
* Fixed issue #112. It will not try to reinstall requirements at every start of webui.
* Some improvements in the text to video method.
* Parameters used to generate a video are now automatically saved in the video's folder.
* Added ability to control what frame will be sent to CN in text to video mode.
-->

1
requirements.txt Normal file
View File

@ -0,0 +1 @@
scikit-image

252
scripts/base_ui.py Normal file
View File

@ -0,0 +1,252 @@
import sys, os
import gradio as gr
import modules
from types import SimpleNamespace
from modules import script_callbacks, shared
from modules.shared import cmd_opts, opts
from modules.call_queue import wrap_gradio_gpu_call
from modules.ui_components import ToolButton, FormRow, FormGroup
from modules.ui import create_override_settings_dropdown
import modules.scripts as scripts
from modules.sd_samplers import samplers_for_img2img
from scripts.core import vid2vid, txt2vid, utils
import traceback
def V2VArgs():
seed = -1
width = 1024
height = 576
cfg_scale = 5.5
steps = 15
prompt = ""
n_prompt = "text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
processing_strength = 0.85
fix_frame_strength = 0.15
return locals()
def T2VArgs():
seed = -1
width = 768
height = 512
cfg_scale = 5.5
steps = 15
prompt = ""
n_prompt = "((blur, blurr, blurred, blurry, fuzzy, unclear, unfocus, bocca effect)), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
processing_strength = 0.75
fix_frame_strength = 0.35
return locals()
def setup_common_values(mode, d):
with gr.Row():
width = gr.Slider(label='Width', minimum=64, maximum=2048, step=64, value=d.width, interactive=True)
height = gr.Slider(label='Height', minimum=64, maximum=2048, step=64, value=d.height, interactive=True)
with gr.Row(elem_id=f'{mode}_prompt_toprow'):
prompt = gr.Textbox(label='Prompt', lines=3, interactive=True, elem_id=f"{mode}_prompt", placeholder="Enter your prompt here...")
with gr.Row(elem_id=f'{mode}_n_prompt_toprow'):
n_prompt = gr.Textbox(label='Negative prompt', lines=3, interactive=True, elem_id=f"{mode}_n_prompt", value=d.n_prompt)
with gr.Row():
cfg_scale = gr.Slider(label='CFG scale', minimum=1, maximum=100, step=1, value=d.cfg_scale, interactive=True)
with gr.Row():
seed = gr.Number(label='Seed (this parameter controls how the first frame looks and the color distribution of the consecutive frames, as they depend on the first one)', value = d.seed, interactive = True, precision=0)
with gr.Row():
processing_strength = gr.Slider(label="Processing strength (Step 1)", value=d.processing_strength, minimum=0, maximum=1, step=0.05, interactive=True)
fix_frame_strength = gr.Slider(label="Fix frame strength (Step 2)", value=d.fix_frame_strength, minimum=0, maximum=1, step=0.05, interactive=True)
with gr.Row():
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{mode}_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index", interactive=True)
steps = gr.Slider(label="Sampling steps", minimum=1, maximum=150, step=1, elem_id=f"{mode}_steps", value=d.steps, interactive=True)
return width, height, prompt, n_prompt, cfg_scale, seed, processing_strength, fix_frame_strength, sampler_index, steps
def inputs_ui():
v2v_args = SimpleNamespace(**V2VArgs())
t2v_args = SimpleNamespace(**T2VArgs())
with gr.Tabs():
glo_sdcn_process_mode = gr.State(value='vid2vid')
with gr.Tab('vid2vid') as tab_vid2vid:
with gr.Row():
gr.HTML('Input video (each frame will be used as initial image for SD and as input image to CN): *REQUIRED')
with gr.Row():
v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
v2v_width, v2v_height, v2v_prompt, v2v_n_prompt, v2v_cfg_scale, v2v_seed, v2v_processing_strength, v2v_fix_frame_strength, v2v_sampler_index, v2v_steps = setup_common_values('vid2vid', v2v_args)
with gr.Accordion("Extra settings",open=False):
gr.HTML('# Occlusion mask params:')
with gr.Row():
with gr.Column(scale=1, variant='compact'):
v2v_occlusion_mask_blur = gr.Slider(label='Occlusion blur strength', minimum=0, maximum=10, step=0.1, value=3, interactive=True)
gr.HTML('')
v2v_occlusion_mask_trailing = gr.Checkbox(label="Occlusion trailing", info="Reduce ghosting but adds more flickering to the video", value=True, interactive=True)
with gr.Column(scale=1, variant='compact'):
v2v_occlusion_mask_flow_multiplier = gr.Slider(label='Occlusion flow multiplier', minimum=0, maximum=10, step=0.1, value=5, interactive=True)
v2v_occlusion_mask_difo_multiplier = gr.Slider(label='Occlusion diff origin multiplier', minimum=0, maximum=10, step=0.1, value=2, interactive=True)
v2v_occlusion_mask_difs_multiplier = gr.Slider(label='Occlusion diff styled multiplier', minimum=0, maximum=10, step=0.1, value=0, interactive=True)
with gr.Row():
with gr.Column(scale=1, variant='compact'):
gr.HTML('# Step 1 params:')
v2v_step_1_seed = gr.Number(label='Seed', value = -1, interactive = True, precision=0)
gr.HTML('<br>')
v2v_step_1_blend_alpha = gr.Slider(label='Warped prev frame vs Current frame blend alpha', minimum=0, maximum=1, step=0.1, value=1, interactive=True)
v2v_step_1_processing_mode = gr.Radio(["Process full image then blend in occlusions", "Inpaint occlusions"], type="index", \
label="Processing mode", value="Process full image then blend in occlusions", interactive=True)
with gr.Column(scale=1, variant='compact'):
gr.HTML('# Step 2 params:')
v2v_step_2_seed = gr.Number(label='Seed', value = 8888, interactive = True, precision=0)
with FormRow(elem_id="vid2vid_override_settings_row") as row:
v2v_override_settings = create_override_settings_dropdown("vid2vid", row)
with FormGroup(elem_id=f"script_container"):
v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
with gr.Tab('txt2vid') as tab_txt2vid:
with gr.Row():
gr.HTML('Control video (each frame will be used as input image to CN): *NOT REQUIRED')
with gr.Row():
t2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="tex_to_vid_chosen_file")
t2v_init_image = gr.Image(label="Input image", interactive=True, file_count="single", file_types=["image"], elem_id="tex_to_vid_init_image")
t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)
with gr.Row():
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
gr.HTML('<br>')
t2v_cn_frame_send = gr.Radio(["None", "Current generated frame", "Previous generated frame", "Current reference video frame"], type="index", \
label="What frame should be send to CN?", value="None", interactive=True)
with FormRow(elem_id="txt2vid_override_settings_row") as row:
t2v_override_settings = create_override_settings_dropdown("txt2vid", row)
with FormGroup(elem_id=f"script_container"):
t2v_custom_inputs = scripts.scripts_txt2img.setup_ui()
tab_vid2vid.select(fn=lambda: 'vid2vid', inputs=[], outputs=[glo_sdcn_process_mode])
tab_txt2vid.select(fn=lambda: 'txt2vid', inputs=[], outputs=[glo_sdcn_process_mode])
return locals()
def process(*args):
msg = 'Done'
try:
if args[0] == 'vid2vid':
yield from vid2vid.start_process(*args)
elif args[0] == 'txt2vid':
yield from txt2vid.start_process(*args)
else:
msg = f"Unsupported processing mode: '{args[0]}'"
raise Exception(msg)
except Exception as error:
# handle the exception
msg = f"An exception occurred while trying to process the frame: {error}"
print(msg)
traceback.print_exc()
yield msg, gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Video.update(), gr.Button.update(interactive=True), gr.Button.update(interactive=False)
def stop_process(*args):
utils.shared.is_interrupted = True
return gr.Button.update(interactive=False)
def on_ui_tabs():
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as sdcnanim_interface:
components = {}
#dv = SimpleNamespace(**T2VOutputArgs())
with gr.Row(elem_id='sdcn-core', equal_height=False, variant='compact'):
with gr.Column(scale=1, variant='panel'):
#with gr.Tabs():
components = inputs_ui()
with gr.Accordion("Export settings", open=False):
export_settings_button = gr.Button('Export', elem_id=f"sdcn_export_settings_button")
export_setting_json = gr.Code(value='')
with gr.Column(scale=1, variant='compact'):
with gr.Row(variant='compact'):
run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False)
save_frames_check = gr.Checkbox(label="Save frames into a folder next to the video (check this before running the generation if you also want to save frames separately)", value=True, interactive=True)
gr.HTML('<br>')
with gr.Column(variant="panel"):
sp_progress = gr.HTML(elem_id="sp_progress", value="")
with gr.Row(variant='compact'):
img_preview_curr_frame = gr.Image(label='Current frame', elem_id=f"img_preview_curr_frame", type='pil', height=240)
img_preview_curr_occl = gr.Image(label='Current occlusion', elem_id=f"img_preview_curr_occl", type='pil', height=240)
with gr.Row(variant='compact'):
img_preview_prev_warp = gr.Image(label='Previous frame warped', elem_id=f"img_preview_curr_frame", type='pil', height=240)
img_preview_processed = gr.Image(label='Processed', elem_id=f"img_preview_processed", type='pil', height=240)
video_preview = gr.Video(interactive=False)
with gr.Row(variant='compact'):
dummy_component = gr.Label(visible=False)
components['glo_save_frames_check'] = save_frames_check
# Define parameters for the action methods.
utils.shared.v2v_custom_inputs_size = len(components['v2v_custom_inputs'])
utils.shared.t2v_custom_inputs_size = len(components['t2v_custom_inputs'])
#print('v2v_custom_inputs', len(components['v2v_custom_inputs']), components['v2v_custom_inputs'])
#print('t2v_custom_inputs', len(components['t2v_custom_inputs']), components['t2v_custom_inputs'])
method_inputs = [components[name] for name in utils.get_component_names()] + components['v2v_custom_inputs'] + components['t2v_custom_inputs']
method_outputs = [
sp_progress,
img_preview_curr_frame,
img_preview_curr_occl,
img_preview_prev_warp,
img_preview_processed,
video_preview,
run_button,
stop_button,
]
run_button.click(
fn=process, #wrap_gradio_gpu_call(start_process, extra_outputs=[None, '', '']),
inputs=method_inputs,
outputs=method_outputs,
show_progress=True,
)
stop_button.click(
fn=stop_process,
outputs=[stop_button],
show_progress=False
)
export_settings_button.click(
fn=utils.export_settings,
inputs=method_inputs,
outputs=[export_setting_json],
show_progress=False
)
modules.scripts.scripts_current = None
# define queue - required for generators
sdcnanim_interface.queue()
return [(sdcnanim_interface, "SD-CN-Animation", "sd_cn_animation_interface")]
script_callbacks.on_ui_tabs(on_ui_tabs)

156
scripts/core/flow_utils.py Normal file
View File

@ -0,0 +1,156 @@
import sys, os
import numpy as np
import cv2
from collections import namedtuple
import torch
import argparse
from RAFT.raft import RAFT
from RAFT.utils.utils import InputPadder
import modules.paths as ph
import gc
RAFT_model = None
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
def background_subtractor(frame, fgbg):
fgmask = fgbg.apply(frame)
return cv2.bitwise_and(frame, frame, mask=fgmask)
def RAFT_clear_memory():
global RAFT_model
del RAFT_model
gc.collect()
torch.cuda.empty_cache()
RAFT_model = None
def RAFT_estimate_flow(frame1, frame2, device='cuda'):
global RAFT_model
org_size = frame1.shape[1], frame1.shape[0]
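# run RAFT at a resolution rounded down to a multiple of 16; the flow is resized back to org_size below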
size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
frame1 = cv2.resize(frame1, size)
frame2 = cv2.resize(frame2, size)
model_path = ph.models_path + '/RAFT/raft-things.pth'
remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'
if not os.path.isfile(model_path):
from basicsr.utils.download_util import load_file_from_url
os.makedirs(os.path.dirname(model_path), exist_ok=True)
load_file_from_url(remote_model_path, file_name=model_path)
if RAFT_model is None:
args = argparse.Namespace(**{
'model': ph.models_path + '/RAFT/raft-things.pth',
'mixed_precision': True,
'small': False,
'alternate_corr': False,
'path': ""
})
RAFT_model = torch.nn.DataParallel(RAFT(args))
RAFT_model.load_state_dict(torch.load(args.model))
RAFT_model = RAFT_model.module
RAFT_model.to(device)
RAFT_model.eval()
with torch.no_grad():
frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)
padder = InputPadder(frame1_torch.shape)
image1, image2 = padder.pad(frame1_torch, frame2_torch)
# estimate optical flow
_, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
_, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)
next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()
fb_flow = next_flow + prev_flow
fb_norm = np.linalg.norm(fb_flow, axis=2)
occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)
next_flow = cv2.resize(next_flow, org_size)
prev_flow = cv2.resize(prev_flow, org_size)
return next_flow, prev_flow, occlusion_mask
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled, args_dict):
h, w = cur_frame.shape[:2]
fl_w, fl_h = next_flow.shape[:2]
# normalize flow
next_flow = next_flow / np.array([fl_h,fl_w])
prev_flow = prev_flow / np.array([fl_h,fl_w])
# compute occlusion mask
fb_flow = next_flow + prev_flow
fb_norm = np.linalg.norm(fb_flow , axis=2)
zero_flow_mask = np.clip(1 - np.linalg.norm(prev_flow, axis=-1)[...,None] * 20, 0, 1)
diff_mask_flow = fb_norm[..., None] * zero_flow_mask
# resize flow
next_flow = cv2.resize(next_flow, (w, h))
next_flow = (next_flow * np.array([h,w])).astype(np.float32)
prev_flow = cv2.resize(prev_flow, (w, h))
prev_flow = (prev_flow * np.array([h,w])).astype(np.float32)
# Generate sampling grids
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
flow_grid = flow_grid.unsqueeze(0)
flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
flow_grid = flow_grid.permute(0, 2, 3, 1)
prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
#warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
#warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)
diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
alpha_mask = np.maximum.reduce([diff_mask_flow * args_dict['occlusion_mask_flow_multiplier'] * 10, \
diff_mask_org * args_dict['occlusion_mask_difo_multiplier'], \
diff_mask_stl * args_dict['occlusion_mask_difs_multiplier']]) #
alpha_mask = alpha_mask.repeat(3, axis = -1)
#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
if args_dict['occlusion_mask_blur'] > 0:
blur_filter_size = min(w,h) // 15 | 1
alpha_mask = cv2.GaussianBlur(alpha_mask, (blur_filter_size, blur_filter_size) , args_dict['occlusion_mask_blur'], cv2.BORDER_REFLECT)
alpha_mask = np.clip(alpha_mask, 0, 1)
return alpha_mask, warped_frame_styled
def frames_norm(frame): return frame / 127.5 - 1
def flow_norm(flow): return flow / 255
def occl_norm(occl): return occl / 127.5 - 1
def frames_renorm(frame): return (frame + 1) * 127.5
def flow_renorm(flow): return flow * 255
def occl_renorm(occl): return (occl + 1) * 127.5

240
scripts/core/txt2vid.py Normal file
View File

@ -0,0 +1,240 @@
import sys, os
import torch
import gc
import numpy as np
from PIL import Image
import modules.paths as ph
from modules import devices
from scripts.core import utils, flow_utils
from FloweR.model import FloweR
import skimage
import datetime
import cv2
import gradio as gr
import time
FloweR_model = None
DEVICE = 'cpu'
def FloweR_clear_memory():
global FloweR_model
del FloweR_model
gc.collect()
torch.cuda.empty_cache()
FloweR_model = None
def FloweR_load_model(w, h):
global DEVICE, FloweR_model
DEVICE = devices.get_optimal_device()
model_path = ph.models_path + '/FloweR/FloweR_0.1.2.pth'
#remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv' #FloweR_0.1.1.pth
remote_model_path = 'https://drive.google.com/uc?id=1-UYsTXkdUkHLgtPK1Y5_7kKzCgzL_Z6o' #FloweR_0.1.2.pth
if not os.path.isfile(model_path):
from basicsr.utils.download_util import load_file_from_url
os.makedirs(os.path.dirname(model_path), exist_ok=True)
load_file_from_url(remote_model_path, file_name=model_path)
FloweR_model = FloweR(input_size = (h, w))
FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Move the model to the device
FloweR_model = FloweR_model.to(DEVICE)
FloweR_model.eval()
def read_frame_from_video(input_video):
if input_video is None: return None
# Reading video file
if input_video.isOpened():
ret, cur_frame = input_video.read()
if cur_frame is not None:
cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
else:
cur_frame = None
input_video.release()
input_video = None
return cur_frame
def start_process(*args):
processing_start_time = time.time()
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('t2v', args_dict)
# Open the input video file
input_video = None
if args_dict['file'] is not None:
input_video = cv2.VideoCapture(args_dict['file'].name)
# Build the output video file name and create the output folders
output_video_name = f'outputs/sd-cn-animation/txt2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
output_video_folder = os.path.splitext(output_video_name)[0]
os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
#if args_dict['save_frames_check']:
os.makedirs(output_video_folder, exist_ok=True)
# Write the current params to params.json
setts_json = utils.export_settings(*args)
with open(os.path.join(output_video_folder, "params.json"), "w") as outfile:
outfile.write(setts_json)
curr_frame = None
prev_frame = None
def save_result_to_image(image, ind):
if args_dict['save_frames_check']:
cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
def set_cn_frame_input():
if args_dict['cn_frame_send'] == 0: # Do not send a frame to ControlNet
pass
elif args_dict['cn_frame_send'] == 1: # Current generated frame
if curr_frame is not None:
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame), set_references=True)
elif args_dict['cn_frame_send'] == 2: # Previous generated frame
if prev_frame is not None:
utils.set_CNs_input_image(args_dict, Image.fromarray(prev_frame), set_references=True)
elif args_dict['cn_frame_send'] == 3: # Current reference video frame
if input_video is not None:
curr_video_frame = read_frame_from_video(input_video)
curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame), set_references=True)
else:
raise Exception('There is no input video! Set it up first.')
else:
raise Exception('Incorrect cn_frame_send mode!')
set_cn_frame_input()
if args_dict['init_image'] is not None:
# resize the init image to args_dict['width'] x args_dict['height']
image_array = args_dict['init_image'] # numpy array from the UI
init_frame = np.array(Image.fromarray(image_array).resize((args_dict['width'], args_dict['height'])).convert('RGB'))
processed_frame = init_frame.copy()
else:
processed_frames, _, _, _ = utils.txt2img(args_dict)
processed_frame = np.array(processed_frames[0])[...,:3]
#if input_video is not None:
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
init_frame = processed_frame.copy()
output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
stat = f"Frame: 1 / {args_dict['length']}; " + utils.get_time_left(1, args_dict['length'], processing_start_time)
utils.shared.is_interrupted = False
save_result_to_image(processed_frame, 1)
yield stat, init_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
org_size = args_dict['width'], args_dict['height']
size = args_dict['width'] // 128 * 128, args_dict['height'] // 128 * 128
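# The working resolution is rounded down to a multiple of 128 for FloweR; predictions are resized back to org_size after each step.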
FloweR_load_model(size[0], size[1])
clip_frames = np.zeros((4, size[1], size[0], 3), dtype=np.uint8)
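# Rolling buffer of the last 4 frames (oldest first) that is fed to FloweR at every step.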
prev_frame = init_frame
for ind in range(args_dict['length'] - 1):
if utils.shared.is_interrupted: break
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('t2v', args_dict)
clip_frames = np.roll(clip_frames, -1, axis=0)
clip_frames[-1] = cv2.resize(prev_frame[...,:3], size)
clip_frames_torch = flow_utils.frames_norm(torch.from_numpy(clip_frames).to(DEVICE, dtype=torch.float32))
with torch.no_grad():
pred_data = FloweR_model(clip_frames_torch.unsqueeze(0))[0]
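# pred_data packs six channels per pixel: 0-1 optical flow, 2 occlusion mask, 3-5 predicted next frame (all in normalised space).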
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
pred_next = flow_utils.frames_renorm(pred_data[...,3:6]).cpu().numpy()
pred_occl = np.clip(pred_occl * 10, 0, 255).astype(np.uint8)
pred_next = np.clip(pred_next, 0, 255).astype(np.uint8)
pred_flow = cv2.resize(pred_flow, org_size)
pred_occl = cv2.resize(pred_occl, org_size)
pred_next = cv2.resize(pred_next, org_size)
curr_frame = pred_next.copy()
'''
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
pred_occl = cv2.GaussianBlur(pred_occl, (21,21), 2, cv2.BORDER_REFLECT_101)
pred_occl = (np.abs(pred_occl / 255) ** 1.5) * 255
pred_occl = np.clip(pred_occl * 25, 0, 255).astype(np.uint8)
flow_map = pred_flow.copy()
flow_map[:,:,0] += np.arange(args_dict['width'])
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT_101)
alpha_mask = pred_occl / 255.
#alpha_mask = np.clip(alpha_mask + np.random.normal(0, 0.4, size = alpha_mask.shape), 0, 1)
curr_frame = pred_next.astype(float) * alpha_mask + warped_frame.astype(float) * (1 - alpha_mask)
curr_frame = np.clip(curr_frame, 0, 255).astype(np.uint8)
#curr_frame = warped_frame.copy()
'''
set_cn_frame_input()
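# Pass 1: inpaint the occluded regions of the predicted frame (mode 4, uploaded mask) at processing_strength.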
args_dict['mode'] = 4
args_dict['init_img'] = Image.fromarray(pred_next)
args_dict['mask_img'] = Image.fromarray(pred_occl)
args_dict['seed'] = -1
args_dict['denoising_strength'] = args_dict['processing_strength']
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])[...,:3]
#if input_video is not None:
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
#else:
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
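# Pass 2: a light full-frame img2img pass (mode 0) at fix_frame_strength, followed by histogram matching against the initial frame to keep colours stable.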
args_dict['mode'] = 0
args_dict['init_img'] = Image.fromarray(processed_frame)
args_dict['mask_img'] = None
args_dict['seed'] = -1
args_dict['denoising_strength'] = args_dict['fix_frame_strength']
#utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])[...,:3]
#if input_video is not None:
# processed_frame = skimage.exposure.match_histograms(processed_frame, curr_video_frame, channel_axis=-1)
#else:
processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
prev_frame = processed_frame.copy()
save_result_to_image(processed_frame, ind + 2)
stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
yield stat, curr_frame, pred_occl, pred_next, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
if input_video is not None: input_video.release()
output_video.release()
FloweR_clear_memory()
curr_frame = gr.Image.update()
occlusion_mask = gr.Image.update()
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
# print('TOTAL TIME:', int(time.time() - processing_start_time))
yield 'done', curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)

scripts/core/utils.py (Normal file, 432 lines added)

@@ -0,0 +1,432 @@
class shared:
is_interrupted = False
v2v_custom_inputs_size = 0
t2v_custom_inputs_size = 0
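# *_custom_inputs_size record how many extension (e.g. ControlNet) script inputs follow the named components in *args; args_to_dict() uses them to split the tail of the argument list.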
def get_component_names():
components_list = [
'glo_sdcn_process_mode',
'v2v_file', 'v2v_width', 'v2v_height', 'v2v_prompt', 'v2v_n_prompt', 'v2v_cfg_scale', 'v2v_seed', 'v2v_processing_strength', 'v2v_fix_frame_strength',
'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing', 'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier', 'v2v_occlusion_mask_difs_multiplier',
'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha', 'v2v_step_1_seed', 'v2v_step_2_seed',
't2v_file','t2v_init_image', 't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
't2v_sampler_index', 't2v_steps', 't2v_length', 't2v_fps', 't2v_cn_frame_send',
'glo_save_frames_check'
]
return components_list
def args_to_dict(*args): # converts the flat list of UI arguments into a dictionary for easier handling
args_list = get_component_names()
# set default values for params that were not specified
args_dict = {
# video to video params
'v2v_mode': 0,
'v2v_prompt': '',
'v2v_n_prompt': '',
'v2v_prompt_styles': [],
'v2v_init_video': None, # Always required
'v2v_steps': 15,
'v2v_sampler_index': 0, # 'Euler a'
'v2v_mask_blur': 0,
'v2v_inpainting_fill': 1, # original
'v2v_restore_faces': False,
'v2v_tiling': False,
'v2v_n_iter': 1,
'v2v_batch_size': 1,
'v2v_cfg_scale': 5.5,
'v2v_image_cfg_scale': 1.5,
'v2v_denoising_strength': 0.75,
'v2v_processing_strength': 0.85,
'v2v_fix_frame_strength': 0.15,
'v2v_seed': -1,
'v2v_subseed': -1,
'v2v_subseed_strength': 0,
'v2v_seed_resize_from_h': 512,
'v2v_seed_resize_from_w': 512,
'v2v_seed_enable_extras': False,
'v2v_height': 512,
'v2v_width': 512,
'v2v_resize_mode': 1,
'v2v_inpaint_full_res': True,
'v2v_inpaint_full_res_padding': 0,
'v2v_inpainting_mask_invert': False,
# text to video params
't2v_mode': 4,
't2v_prompt': '',
't2v_n_prompt': '',
't2v_prompt_styles': [],
't2v_init_img': None,
't2v_mask_img': None,
't2v_steps': 15,
't2v_sampler_index': 0, # 'Euler a'
't2v_mask_blur': 0,
't2v_inpainting_fill': 1, # original
't2v_restore_faces': False,
't2v_tiling': False,
't2v_n_iter': 1,
't2v_batch_size': 1,
't2v_cfg_scale': 5.5,
't2v_image_cfg_scale': 1.5,
't2v_denoising_strength': 0.75,
't2v_processing_strength': 0.85,
't2v_fix_frame_strength': 0.15,
't2v_seed': -1,
't2v_subseed': -1,
't2v_subseed_strength': 0,
't2v_seed_resize_from_h': 512,
't2v_seed_resize_from_w': 512,
't2v_seed_enable_extras': False,
't2v_height': 512,
't2v_width': 512,
't2v_resize_mode': 1,
't2v_inpaint_full_res': True,
't2v_inpaint_full_res_padding': 0,
't2v_inpainting_mask_invert': False,
't2v_override_settings': [],
#'t2v_script_inputs': [0],
't2v_fps': 12,
}
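# Overwrite the defaults with any non-None values coming from the UI, then split the remaining positional args into vid2vid and txt2vid script inputs.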
args = list(args)
for i in range(len(args_list)):
if (args[i] is None) and (args_list[i] in args_dict):
#args[i] = args_dict[args_list[i]]
pass
else:
args_dict[args_list[i]] = args[i]
args_dict['v2v_script_inputs'] = args[len(args_list):len(args_list)+shared.v2v_custom_inputs_size]
#print('v2v_script_inputs', args_dict['v2v_script_inputs'])
args_dict['t2v_script_inputs'] = args[len(args_list)+shared.v2v_custom_inputs_size:]
#print('t2v_script_inputs', args_dict['t2v_script_inputs'])
return args_dict
def get_mode_args(mode, args_dict):
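# Keep only the keys for the requested mode ('v2v'/'t2v') plus the global 'glo' group, with the mode prefix stripped.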
mode_args_dict = {}
for key, value in args_dict.items():
if key[:3] in [mode, 'glo']:
mode_args_dict[key[4:]] = value
return mode_args_dict
def set_CNs_input_image(args_dict, image, set_references = False):
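# Push the given image into every ControlNet unit among the script inputs; reference_* preprocessors are only overridden when set_references is True.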
for script_input in args_dict['script_inputs']:
if type(script_input).__name__ == 'UiControlNetUnit':
if script_input.module not in ["reference_only", "reference_adain", "reference_adain+attn"] or set_references:
script_input.image = np.array(image)
script_input.batch_images = [np.array(image)]
import time
import datetime
def get_time_left(ind, length, processing_start_time):
s_passed = int(time.time() - processing_start_time)
time_passed = datetime.timedelta(seconds=s_passed)
s_left = int(s_passed / ind * (length - ind))
time_left = datetime.timedelta(seconds=s_left)
return f"Time elapsed: {time_passed}; Time left: {time_left};"
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from types import SimpleNamespace
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, process_images
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
from modules.shared import opts, state
from modules import devices, sd_samplers, img2img
from modules import shared, sd_hijack
# TODO: Refactor all the code below
def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
processing.fix_seed(p)
#images = shared.listfiles(input_dir)
images = [input_img]
is_inpaint_batch = False
#if inpaint_mask_dir:
# inpaint_masks = shared.listfiles(inpaint_mask_dir)
# is_inpaint_batch = len(inpaint_masks) > 0
#if is_inpaint_batch:
# print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
#print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
save_normally = output_dir == ''
p.do_not_save_grid = True
p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
generated_images = []
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
state.skipped = False
if state.interrupted:
break
img = image #Image.open(image)
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
#if is_inpaint_batch:
# # try to find corresponding mask for an image using simple filename matching
# mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
# # if not found use first one ("same mask for all images" use-case)
# if not mask_image_path in inpaint_masks:
# mask_image_path = inpaint_masks[0]
# mask_image = Image.open(mask_image_path)
# p.image_mask = mask_image
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
generated_images.append(proc.images[0])
#for n, processed_image in enumerate(proc.images):
# filename = os.path.basename(image)
# if n > 0:
# left, right = os.path.splitext(filename)
# filename = f"{left}-{n}{right}"
# if not save_normally:
# os.makedirs(output_dir, exist_ok=True)
# if processed_image.mode == 'RGBA':
# processed_image = processed_image.convert("RGB")
# processed_image.save(os.path.join(output_dir, filename))
return generated_images
def img2img(args_dict):
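# Essentially a trimmed-down copy of the WebUI img2img entry point, driven from the args dict instead of the Gradio UI.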
args = SimpleNamespace(**args_dict)
override_settings = create_override_settings_dict(args.override_settings)
is_batch = args.mode == 5
if args.mode == 0: # img2img
image = args.init_img.convert("RGB")
mask = None
elif args.mode == 1: # img2img sketch
image = args.sketch.convert("RGB")
mask = None
elif args.mode == 2: # inpaint
image, mask = args.init_img_with_mask["image"], args.init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert("RGB")
elif args.mode == 3: # inpaint sketch
image = args.inpaint_color_sketch
orig = args.inpaint_color_sketch_orig or args.inpaint_color_sketch
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
mask = ImageEnhance.Brightness(mask).enhance(1 - args.mask_alpha / 100)
blur = ImageFilter.GaussianBlur(args.mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
elif args.mode == 4: # inpaint upload mask
#image = args.init_img_inpaint
#mask = args.init_mask_inpaint
image = args.init_img.convert("RGB")
mask = args.mask_img.convert("L")
else:
image = None
mask = None
# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
image = ImageOps.exif_transpose(image)
assert 0. <= args.denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=args.prompt,
negative_prompt=args.n_prompt,
styles=args.prompt_styles,
seed=args.seed,
subseed=args.subseed,
subseed_strength=args.subseed_strength,
seed_resize_from_h=args.seed_resize_from_h,
seed_resize_from_w=args.seed_resize_from_w,
seed_enable_extras=args.seed_enable_extras,
sampler_name=sd_samplers.samplers_for_img2img[args.sampler_index].name,
batch_size=args.batch_size,
n_iter=args.n_iter,
steps=args.steps,
cfg_scale=args.cfg_scale,
width=args.width,
height=args.height,
restore_faces=args.restore_faces,
tiling=args.tiling,
init_images=[image],
mask=mask,
mask_blur=args.mask_blur,
inpainting_fill=args.inpainting_fill,
resize_mode=args.resize_mode,
denoising_strength=args.denoising_strength,
image_cfg_scale=args.image_cfg_scale,
inpaint_full_res=args.inpaint_full_res,
inpaint_full_res_padding=args.inpaint_full_res_padding,
inpainting_mask_invert=args.inpainting_mask_invert,
override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_img2img
p.script_args = args.script_inputs
#if shared.cmd_opts.enable_console_prompts:
# print(f"\nimg2img: {args.prompt}", file=shared.progress_print_out)
if mask:
p.extra_generation_params["Mask blur"] = args.mask_blur
'''
if is_batch:
...
# assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
# process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args.script_inputs)
# processed = Processed(p, [], p.seed, "")
else:
processed = modules.scripts.scripts_img2img.run(p, *args.script_inputs)
if processed is None:
processed = process_images(p)
'''
generated_images = process_img(p, image, None, '', args.script_inputs)
processed = Processed(p, [], p.seed, "")
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
#if opts.samples_log_stdout:
# print(generation_info_js)
#if opts.do_not_show_images:
# processed.images = []
#print(generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments))
return generated_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
def txt2img(args_dict):
args = SimpleNamespace(**args_dict)
override_settings = create_override_settings_dict(args.override_settings)
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=args.prompt,
styles=args.prompt_styles,
negative_prompt=args.n_prompt,
seed=args.seed,
subseed=args.subseed,
subseed_strength=args.subseed_strength,
seed_resize_from_h=args.seed_resize_from_h,
seed_resize_from_w=args.seed_resize_from_w,
seed_enable_extras=args.seed_enable_extras,
sampler_name=sd_samplers.samplers[args.sampler_index].name,
batch_size=args.batch_size,
n_iter=args.n_iter,
steps=args.steps,
cfg_scale=args.cfg_scale,
width=args.width,
height=args.height,
restore_faces=args.restore_faces,
tiling=args.tiling,
#enable_hr=args.enable_hr,
#denoising_strength=args.denoising_strength if enable_hr else None,
#hr_scale=hr_scale,
#hr_upscaler=hr_upscaler,
#hr_second_pass_steps=hr_second_pass_steps,
#hr_resize_x=hr_resize_x,
#hr_resize_y=hr_resize_y,
override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args.script_inputs
#if cmd_opts.enable_console_prompts:
# print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args.script_inputs)
if processed is None:
processed = process_images(p)
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
#if opts.samples_log_stdout:
# print(generation_info_js)
#if opts.do_not_show_images:
# processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
import json
def get_json(obj):
return json.loads(
json.dumps(obj, default=lambda o: getattr(o, '__dict__', str(o)))
)
def export_settings(*args):
args_dict = args_to_dict(*args)
if args[0] == 'vid2vid':
args_dict = get_mode_args('v2v', args_dict)
elif args[0] == 'txt2vid':
args_dict = get_mode_args('t2v', args_dict)
else:
msg = f"Unsupported processing mode: '{args[0]}'"
raise Exception(msg)
# convert CN params into a readable dict
cn_remove_list = ['low_vram', 'is_ui', 'input_mode', 'batch_images', 'output_dir', 'loopback', 'image']
args_dict['ControlNets'] = []
for script_input in args_dict['script_inputs']:
if type(script_input).__name__ == 'UiControlNetUnit':
cn_values_dict = get_json(script_input)
if cn_values_dict['enabled']:
for key in cn_remove_list:
if key in cn_values_dict: del cn_values_dict[key]
args_dict['ControlNets'].append(cn_values_dict)
# remove unimportant values
remove_list = ['save_frames_check', 'restore_faces', 'prompt_styles', 'mask_blur', 'inpainting_fill', 'tiling', 'n_iter', 'batch_size', 'subseed', 'subseed_strength', 'seed_resize_from_h', \
'seed_resize_from_w', 'seed_enable_extras', 'resize_mode', 'inpaint_full_res', 'inpaint_full_res_padding', 'inpainting_mask_invert', 'file', 'denoising_strength', \
'override_settings', 'script_inputs', 'init_img', 'mask_img', 'mode', 'init_video']
for key in remove_list:
if key in args_dict: del args_dict[key]
return json.dumps(args_dict, indent=2, default=lambda o: getattr(o, '__dict__', str(o)))

scripts/core/vid2vid.py (Normal file, 275 lines added)

@@ -0,0 +1,275 @@
import sys, os
import math
import traceback
import numpy as np
from PIL import Image
from modules import devices, sd_samplers
from modules import shared, sd_hijack
try:
from modules import lowvram
except ImportError:
lowvram = None
import modules.shared as shared
import gc
import cv2
import gradio as gr
import time
import skimage
import datetime
from scripts.core.flow_utils import RAFT_estimate_flow, RAFT_clear_memory, compute_diff_map
from scripts.core import utils
class sdcn_anim_tmp:
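# Scratchpad that keeps the streaming vid2vid state (video handles, frame/flow buffers, counters) between the generator's yields.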
prepear_counter = 0
process_counter = 0
input_video = None
output_video = None
curr_frame = None
prev_frame = None
prev_frame_styled = None
prev_frame_alpha_mask = None
fps = None
total_frames = None
prepared_frames = None
prepared_next_flows = None
prepared_prev_flows = None
frames_prepared = False
def read_frame_from_video():
# Reading video file
if sdcn_anim_tmp.input_video.isOpened():
ret, cur_frame = sdcn_anim_tmp.input_video.read()
if cur_frame is not None:
cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
else:
cur_frame = None
sdcn_anim_tmp.input_video.release()
return cur_frame
def get_cur_stat():
stat = f'Frames prepared: {sdcn_anim_tmp.prepear_counter + 1} / {sdcn_anim_tmp.total_frames}; '
stat += f'Frames processed: {sdcn_anim_tmp.process_counter + 1} / {sdcn_anim_tmp.total_frames}; '
return stat
def clear_memory_from_sd():
if shared.sd_model is not None:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
if lowvram:
try:
lowvram.send_everything_to_cpu()
except Exception as e:
...
del shared.sd_model
shared.sd_model = None
gc.collect()
devices.torch_gc()
def start_process(*args):
processing_start_time = time.time()
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('v2v', args_dict)
sdcn_anim_tmp.process_counter = 0
sdcn_anim_tmp.prepear_counter = 0
# Open the input video file
sdcn_anim_tmp.input_video = cv2.VideoCapture(args_dict['file'].name)
# Get useful info from the source video
sdcn_anim_tmp.fps = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FPS))
sdcn_anim_tmp.total_frames = int(sdcn_anim_tmp.input_video.get(cv2.CAP_PROP_FRAME_COUNT))
loop_iterations = (sdcn_anim_tmp.total_frames-1) * 2
# Create an output video file with the same fps, width, and height as the input video
output_video_name = f'outputs/sd-cn-animation/vid2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
output_video_folder = os.path.splitext(output_video_name)[0]
os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
if args_dict['save_frames_check']:
os.makedirs(output_video_folder, exist_ok=True)
def save_result_to_image(image, ind):
if args_dict['save_frames_check']:
cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
sdcn_anim_tmp.output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), sdcn_anim_tmp.fps, (args_dict['width'], args_dict['height']))
curr_frame = read_frame_from_video()
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
sdcn_anim_tmp.prepared_frames = np.zeros((11, args_dict['height'], args_dict['width'], 3), dtype=np.uint8)
sdcn_anim_tmp.prepared_next_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
sdcn_anim_tmp.prepared_prev_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
sdcn_anim_tmp.prepared_frames[0] = curr_frame
args_dict['init_img'] = Image.fromarray(curr_frame)
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])[...,:3]
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
#print('Processed frame ', 0)
sdcn_anim_tmp.curr_frame = curr_frame
sdcn_anim_tmp.prev_frame = curr_frame.copy()
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
utils.shared.is_interrupted = False
save_result_to_image(processed_frame, 1)
stat = get_cur_stat() + utils.get_time_left(1, loop_iterations, processing_start_time)
yield stat, sdcn_anim_tmp.curr_frame, None, None, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
for step in range(loop_iterations):
if utils.shared.is_interrupted: break
args_dict = utils.args_to_dict(*args)
args_dict = utils.get_mode_args('v2v', args_dict)
occlusion_mask = None
prev_frame = None
curr_frame = sdcn_anim_tmp.curr_frame
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
prepare_steps = 10
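# The loop alternates between two phases: RAFT flows and frames are prepared in batches of 10, then each prepared frame is stylised.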
if sdcn_anim_tmp.process_counter % prepare_steps == 0 and not sdcn_anim_tmp.frames_prepared: # prepare next 10 frames for processing
#clear_memory_from_sd()
device = devices.get_optimal_device()
curr_frame = read_frame_from_video()
if curr_frame is not None:
curr_frame = cv2.resize(curr_frame, (args_dict['width'], args_dict['height']))
prev_frame = sdcn_anim_tmp.prev_frame.copy()
next_flow, prev_flow, occlusion_mask = RAFT_estimate_flow(prev_frame, curr_frame, device=device)
occlusion_mask = np.clip(occlusion_mask * 0.1 * 255, 0, 255).astype(np.uint8)
cn = sdcn_anim_tmp.prepear_counter % 10
if sdcn_anim_tmp.prepear_counter % 10 == 0:
sdcn_anim_tmp.prepared_frames[cn] = sdcn_anim_tmp.prev_frame
sdcn_anim_tmp.prepared_frames[cn + 1] = curr_frame.copy()
sdcn_anim_tmp.prepared_next_flows[cn] = next_flow.copy()
sdcn_anim_tmp.prepared_prev_flows[cn] = prev_flow.copy()
#print('Prepared frame ', cn+1)
sdcn_anim_tmp.prev_frame = curr_frame.copy()
sdcn_anim_tmp.prepear_counter += 1
if sdcn_anim_tmp.prepear_counter % prepare_steps == 0 or \
sdcn_anim_tmp.prepear_counter >= sdcn_anim_tmp.total_frames - 1 or \
curr_frame is None:
# Remove RAFT from memory
RAFT_clear_memory()
sdcn_anim_tmp.frames_prepared = True
else:
# process frame
sdcn_anim_tmp.frames_prepared = False
cn = sdcn_anim_tmp.process_counter % 10
curr_frame = sdcn_anim_tmp.prepared_frames[cn+1][...,:3]
prev_frame = sdcn_anim_tmp.prepared_frames[cn][...,:3]
next_flow = sdcn_anim_tmp.prepared_next_flows[cn]
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]
### STEP 1
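# Warp the previous styled frame with the estimated flow and build an occlusion/alpha mask from the flow and frame differences.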
alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled, args_dict)
warped_styled_frame_ = warped_styled_frame.copy()
#fl_w, fl_h = prev_flow.shape[:2]
#prev_flow_n = prev_flow / np.array([fl_h,fl_w])
#flow_mask = np.clip(1 - np.linalg.norm(prev_flow_n, axis=-1)[...,None] * 20, 0, 1)
#alpha_mask = alpha_mask * flow_mask
if sdcn_anim_tmp.process_counter > 0 and args_dict['occlusion_mask_trailing']:
alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
# alpha_mask = np.round(alpha_mask * 8) / 8 #> 0.3
alpha_mask = np.clip(alpha_mask, 0, 1)
occlusion_mask = np.clip(alpha_mask * 255, 0, 255).astype(np.uint8)
# fix duplication artifacts in the warped styled frame that occur where the flow is zero simply because there is nowhere to sample the color from
warped_styled_frame = curr_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
# process current frame
# TODO: convert args_dict into a separate dict that stores only the params necessary for img2img processing
img2img_args_dict = args_dict #copy.deepcopy(args_dict)
img2img_args_dict['denoising_strength'] = args_dict['processing_strength']
if args_dict['step_1_processing_mode'] == 0: # Process full image then blend in occlusions
img2img_args_dict['mode'] = 0
img2img_args_dict['mask_img'] = None #Image.fromarray(occlusion_mask)
elif args_dict['step_1_processing_mode'] == 1: # Inpaint occlusions
img2img_args_dict['mode'] = 4
img2img_args_dict['mask_img'] = Image.fromarray(occlusion_mask)
else:
raise Exception('Incorrect step 1 processing mode!')
blend_alpha = args_dict['step_1_blend_alpha']
init_img = warped_styled_frame * (1 - blend_alpha) + curr_frame * blend_alpha
img2img_args_dict['init_img'] = Image.fromarray(np.clip(init_img, 0, 255).astype(np.uint8))
img2img_args_dict['seed'] = args_dict['step_1_seed']
utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
processed_frame = np.array(processed_frames[0])[...,:3]
# normalizing the colors
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
processed_frame = processed_frame.astype(float) * alpha_mask + warped_styled_frame.astype(float) * (1 - alpha_mask)
#processed_frame = processed_frame * 0.94 + curr_frame * 0.06
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
### STEP 2
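# Optional low-strength full-frame pass to clean up the blended result before writing it to the output video.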
if args_dict['fix_frame_strength'] > 0:
img2img_args_dict = args_dict #copy.deepcopy(args_dict)
img2img_args_dict['mode'] = 0
img2img_args_dict['init_img'] = Image.fromarray(processed_frame)
img2img_args_dict['mask_img'] = None
img2img_args_dict['denoising_strength'] = args_dict['fix_frame_strength']
img2img_args_dict['seed'] = args_dict['step_2_seed']
utils.set_CNs_input_image(img2img_args_dict, Image.fromarray(curr_frame))
processed_frames, _, _, _ = utils.img2img(img2img_args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, channel_axis=None)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
warped_styled_frame_ = np.clip(warped_styled_frame_, 0, 255).astype(np.uint8)
# Write the frame to the output video
frame_out = np.clip(processed_frame, 0, 255).astype(np.uint8)
frame_out = cv2.cvtColor(frame_out, cv2.COLOR_RGB2BGR)
sdcn_anim_tmp.output_video.write(frame_out)
sdcn_anim_tmp.process_counter += 1
#if sdcn_anim_tmp.process_counter >= sdcn_anim_tmp.total_frames - 1:
# sdcn_anim_tmp.input_video.release()
# sdcn_anim_tmp.output_video.release()
# sdcn_anim_tmp.prev_frame = None
save_result_to_image(processed_frame, sdcn_anim_tmp.process_counter + 1)
stat = get_cur_stat() + utils.get_time_left(step+2, loop_iterations+1, processing_start_time)
yield stat, curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
RAFT_clear_memory()
sdcn_anim_tmp.input_video.release()
sdcn_anim_tmp.output_video.release()
curr_frame = gr.Image.update()
occlusion_mask = gr.Image.update()
warped_styled_frame_ = gr.Image.update()
processed_frame = gr.Image.update()
yield get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame_, processed_frame, output_video_name, gr.Button.update(interactive=True), gr.Button.update(interactive=False)