better control for txt2vid

pull/212/head
Alexey Borsky 2023-05-30 19:49:03 +03:00
parent 111711fc7b
commit 2e257bbfc3
7 changed files with 180 additions and 85 deletions

View File

@ -10,7 +10,13 @@ class FloweR(nn.Module):
self.input_size = input_size self.input_size = input_size
self.window_size = window_size self.window_size = window_size
#INPUT: 384 x 384 x 10 * 3 # 2 channels for optical flow
# 1 channel for occlusion mask
# 3 channels for next frame prediction
self.out_channels = 6
#INPUT: 384 x 384 x 4 * 3
### DOWNSCALE ### ### DOWNSCALE ###
self.conv_block_1 = nn.Sequential( self.conv_block_1 = nn.Sequential(
@ -18,47 +24,50 @@ class FloweR(nn.Module):
nn.ReLU(), nn.ReLU(),
) # 384 x 384 x 128 ) # 384 x 384 x 128
self.conv_block_2 = nn.Sequential( # x128 self.conv_block_2 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 192 x 192 x 128 ) # 192 x 192 x 128
self.conv_block_3 = nn.Sequential( # x64 self.conv_block_3 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 96 x 96 x 128 ) # 96 x 96 x 128
self.conv_block_4 = nn.Sequential( # x32 self.conv_block_4 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 48 x 48 x 128 ) # 48 x 48 x 128
self.conv_block_5 = nn.Sequential( # x16 self.conv_block_5 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 24 x 24 x 128 ) # 24 x 24 x 128
self.conv_block_6 = nn.Sequential( # x8 self.conv_block_6 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 12 x 12 x 128 ) # 12 x 12 x 128
self.conv_block_7 = nn.Sequential( # x4 self.conv_block_7 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 6 x 6 x 128 ) # 6 x 6 x 128
self.conv_block_8 = nn.Sequential( # x2 self.conv_block_8 = nn.Sequential(
nn.AvgPool2d(2), nn.AvgPool2d(2),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding='same'),
nn.ReLU(), nn.ReLU(),
) # 3 x 3 x 128 ) # 3 x 3 x 128 - 9 input tokens
### Transformer part ###
# To be done
### UPSCALE ### ### UPSCALE ###
self.conv_block_9 = nn.Sequential( self.conv_block_9 = nn.Sequential(
@ -103,17 +112,19 @@ class FloweR(nn.Module):
nn.ReLU(), nn.ReLU(),
) # 384 x 384 x 128 ) # 384 x 384 x 128
self.conv_block_16 = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding='same') self.conv_block_16 = nn.Conv2d(128, self.out_channels, kernel_size=3, stride=1, padding='same')
def forward(self, x): def forward(self, input_frames):
if x.size(1) != self.window_size:
if input_frames.size(1) != self.window_size:
raise Exception(f'Shape of the input is not compatable. There should be exactly {self.window_size} frames in an input video.') raise Exception(f'Shape of the input is not compatable. There should be exactly {self.window_size} frames in an input video.')
h, w = self.input_size
# batch, frames, height, width, colors # batch, frames, height, width, colors
in_x = x.permute((0, 1, 4, 2, 3)) input_frames_permuted = input_frames.permute((0, 1, 4, 2, 3))
# batch, frames, colors, height, width # batch, frames, colors, height, width
in_x = in_x.reshape(-1, self.window_size * 3, self.input_size[0], self.input_size[1]) in_x = input_frames_permuted.reshape(-1, self.window_size * 3, self.input_size[0], self.input_size[1])
### DOWNSCALE ### ### DOWNSCALE ###
block_1_out = self.conv_block_1(in_x) # 384 x 384 x 128 block_1_out = self.conv_block_1(in_x) # 384 x 384 x 128
@ -134,10 +145,47 @@ class FloweR(nn.Module):
block_14_out = block_2_out + self.conv_block_14(block_13_out) # 192 x 192 x 128 block_14_out = block_2_out + self.conv_block_14(block_13_out) # 192 x 192 x 128
block_15_out = block_1_out + self.conv_block_15(block_14_out) # 384 x 384 x 128 block_15_out = block_1_out + self.conv_block_15(block_14_out) # 384 x 384 x 128
block_16_out = self.conv_block_16(block_15_out) # 384 x 384 x (2 + 1) block_16_out = self.conv_block_16(block_15_out) # 384 x 384 x (2 + 1 + 3)
out = block_16_out.reshape(-1, 3, self.input_size[0], self.input_size[1]) out = block_16_out.reshape(-1, self.out_channels, self.input_size[0], self.input_size[1])
# batch, colors, height, width ### for future model training ###
out = out.permute((0, 2, 3, 1)) device = out.get_device()
# batch, height, width, colors
return out pred_flow = out[:,:2,:,:] * 255 # (-255, 255)
pred_occl = (out[:,2:3,:,:] + 1) / 2 # [0, 1]
pred_next = out[:,3:6,:,:]
# Generate sampling grids
# Create grid to upsample input
'''
d = torch.linspace(-1, 1, 8)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0) '''
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
flow_grid = torch.stack((grid_x, grid_y), dim=0).float()
flow_grid = flow_grid.unsqueeze(0).to(device=device)
flow_grid = flow_grid + pred_flow
flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / (w - 1) - 1
flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / (h - 1) - 1
# batch, flow_chanels, height, width
flow_grid = flow_grid.permute(0, 2, 3, 1)
# batch, height, width, flow_chanels
previous_frame = input_frames_permuted[:, -1, :, :, :]
sampling_mode = "bilinear" if self.training else "nearest"
warped_frame = torch.nn.functional.grid_sample(previous_frame, flow_grid, mode=sampling_mode, padding_mode="reflection", align_corners=False)
alpha_mask = torch.clip(pred_occl * 10, 0, 1) * 0.04
pred_next = torch.clip(pred_next, -1, 1)
warped_frame = torch.clip(warped_frame, -1, 1)
next_frame = pred_next * alpha_mask + warped_frame * (1 - alpha_mask)
res = torch.cat((pred_flow / 255, pred_occl * 2 - 1, next_frame), dim=1)
# batch, channels, height, width
res = res.permute((0, 2, 3, 1))
# batch, height, width, channels
return res

View File

@ -83,4 +83,7 @@ pip install scikit-image==0.19.2 --no-cache-dir
* ControlNet units with preprocessors like "reference_only", "reference_adain", "reference_adain+attn" are no longer reset with video frames, which makes it possible to control the style of the video. * ControlNet units with preprocessors like "reference_only", "reference_adain", "reference_adain+attn" are no longer reset with video frames, which makes it possible to control the style of the video.
* Fixed an issue because of which the 'processing_strength' UI parameter did not actually affect the denoising strength at the first processing step. * Fixed an issue because of which the 'processing_strength' UI parameter did not actually affect the denoising strength at the first processing step.
* Fixed issue #112. It will not try to reinstall requirements at every start of webui. * Fixed issue #112. It will not try to reinstall requirements at every start of webui.
* Some improvements to the text-to-video method.
* Parameters used to generate a video are now automatically saved in the video's folder.
* Added the ability to control which frame is sent to CN in text-to-video mode.
--> -->

View File

@ -1 +1 @@
scikit-image>=0.19.2 scikit-image

View File

@ -37,8 +37,8 @@ def T2VArgs():
cfg_scale = 5.5 cfg_scale = 5.5
steps = 15 steps = 15
prompt = "" prompt = ""
n_prompt = "text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" n_prompt = "((blur, blurr, blurred, blurry, fuzzy, unclear, unfocus, bocca effect)), text, letters, logo, brand, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
processing_strength = 0.85 processing_strength = 0.75
fix_frame_strength = 0.35 fix_frame_strength = 0.35
return locals() return locals()
@ -122,6 +122,10 @@ def inputs_ui():
t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True) t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True) t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
gr.HTML('<br>')
t2v_cn_frame_send = gr.Radio(["None", "Current generated frame", "Previous generated frame", "Current reference video frame"], type="index", \
label="What frame should be send to CN?", value="None", interactive=True)
with FormRow(elem_id="txt2vid_override_settings_row") as row: with FormRow(elem_id="txt2vid_override_settings_row") as row:
t2v_override_settings = create_override_settings_dropdown("txt2vid", row) t2v_override_settings = create_override_settings_dropdown("txt2vid", row)
@ -155,43 +159,7 @@ def stop_process(*args):
utils.shared.is_interrupted = True utils.shared.is_interrupted = True
return gr.Button.update(interactive=False) return gr.Button.update(interactive=False)
import json
def get_json(obj):
    """Return a plain JSON-compatible copy of *obj*.

    The object is round-tripped through ``json.dumps`` / ``json.loads``.
    Anything the encoder cannot serialise natively is replaced by its
    ``__dict__`` when it has one, otherwise by ``str(value)``.
    """
    def _to_serialisable(value):
        # Fallback used by the JSON encoder for unsupported types.
        return getattr(value, '__dict__', str(value))

    return json.loads(json.dumps(obj, default=_to_serialisable))
def export_settings(*args):
    """Return the current generation settings as an indented JSON string.

    All positional arguments are forwarded to ``utils.args_to_dict``;
    ``args[0]`` selects the processing mode and the resulting dict is then
    narrowed with ``utils.get_mode_args``.

    Raises:
        Exception: if ``args[0]`` is neither 'vid2vid' nor 'txt2vid'.
    """
    args_dict = utils.args_to_dict(*args)
    if args[0] == 'vid2vid':
        args_dict = utils.get_mode_args('v2v', args_dict)
    elif args[0] == 'txt2vid':
        args_dict = utils.get_mode_args('t2v', args_dict)
    else:
        msg = f"Unsupported processing mode: '{args[0]}'"
        raise Exception(msg)

    # convert CN params into a readable dict
    cn_remove_list = ['low_vram', 'is_ui', 'input_mode', 'batch_images', 'output_dir', 'loopback', 'image']

    args_dict['ControlNets'] = []
    for script_input in args_dict['script_inputs']:
        # ControlNet units are recognised by class name only
        # (presumably because the class lives in another extension -- TODO confirm).
        if type(script_input).__name__ == 'UiControlNetUnit':
            cn_values_dict = get_json(script_input)
            # Only enabled units are exported.
            if cn_values_dict['enabled']:
                for key in cn_remove_list:
                    cn_values_dict.pop(key, None)
                args_dict['ControlNets'].append(cn_values_dict)

    # remove unimportant values
    remove_list = [
        'save_frames_check', 'restore_faces', 'prompt_styles', 'mask_blur', 'inpainting_fill', 'tiling',
        'n_iter', 'batch_size', 'subseed', 'subseed_strength', 'seed_resize_from_h',
        'seed_resize_from_w', 'seed_enable_extras', 'resize_mode', 'inpaint_full_res',
        'inpaint_full_res_padding', 'inpainting_mask_invert', 'file', 'denoising_strength',
        'override_settings', 'script_inputs', 'init_img', 'mask_img', 'mode', 'init_video',
    ]
    for key in remove_list:
        args_dict.pop(key, None)

    # Objects that are still not JSON-serialisable fall back to __dict__ / str().
    return json.dumps(args_dict, indent=2, default=lambda o: getattr(o, '__dict__', str(o)))
def on_ui_tabs(): def on_ui_tabs():
modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_current = modules.scripts.scripts_img2img
@ -216,7 +184,7 @@ def on_ui_tabs():
run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary') run_button = gr.Button('Generate', elem_id=f"sdcn_anim_generate", variant='primary')
stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False) stop_button = gr.Button('Interrupt', elem_id=f"sdcn_anim_interrupt", variant='primary', interactive=False)
save_frames_check = gr.Checkbox(label="Save frames into a folder nearby a video (check it before running the generation if you also want to save frames separately)", value=False, interactive=True) save_frames_check = gr.Checkbox(label="Save frames into a folder nearby a video (check it before running the generation if you also want to save frames separately)", value=True, interactive=True)
gr.HTML('<br>') gr.HTML('<br>')
with gr.Column(variant="panel"): with gr.Column(variant="panel"):
@ -268,7 +236,7 @@ def on_ui_tabs():
) )
export_settings_button.click( export_settings_button.click(
fn=export_settings, fn=utils.export_settings,
inputs=method_inputs, inputs=method_inputs,
outputs=[export_setting_json], outputs=[export_setting_json],
show_progress=False show_progress=False

View File

@ -116,8 +116,8 @@ def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_sty
prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2) #N, C, H, W
warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, mode="nearest", padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy() warped_frame = torch.nn.functional.grid_sample(prev_frame_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, mode="nearest", padding_mode="reflection").permute(0, 2, 3, 1)[0].numpy() warped_frame_styled = torch.nn.functional.grid_sample(prev_frame_styled_torch, flow_grid, mode="nearest", padding_mode="reflection", align_corners=True).permute(0, 2, 3, 1)[0].numpy()
#warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT) #warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
#warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT) #warped_frame_styled = cv2.remap(prev_frame_styled, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT)
@ -143,12 +143,14 @@ def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_sty
return alpha_mask, warped_frame_styled return alpha_mask, warped_frame_styled
def frames_norm(frame):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return frame / 127.5 - 1


def flow_norm(flow):
    """Scale flow values down by 255."""
    return flow / 255


def occl_norm(occl):
    """Map occlusion-mask values from [0, 255] to [-1, 1]."""
    return occl / 127.5 - 1


def frames_renorm(frame):
    """Inverse of ``frames_norm``: map [-1, 1] back to [0, 255]."""
    return (frame + 1) * 127.5


def flow_renorm(flow):
    """Inverse of ``flow_norm``: scale flow values up by 255."""
    return flow * 255


def occl_renorm(occl):
    """Inverse of ``occl_norm``: map [-1, 1] back to [0, 255]."""
    return (occl + 1) * 127.5

View File

@ -30,8 +30,9 @@ def FloweR_load_model(w, h):
global DEVICE, FloweR_model global DEVICE, FloweR_model
DEVICE = devices.get_optimal_device() DEVICE = devices.get_optimal_device()
model_path = ph.models_path + '/FloweR/FloweR_0.1.1.pth' model_path = ph.models_path + '/FloweR/FloweR_0.1.2.pth'
remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv' #remote_model_path = 'https://drive.google.com/uc?id=1K7gXUosgxU729_l-osl1HBU5xqyLsALv' #FloweR_0.1.1.pth
remote_model_path = 'https://drive.google.com/uc?id=1-UYsTXkdUkHLgtPK1Y5_7kKzCgzL_Z6o' #FloweR_0.1.2.pth
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
@ -43,6 +44,7 @@ def FloweR_load_model(w, h):
FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE)) FloweR_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Move the model to the device # Move the model to the device
FloweR_model = FloweR_model.to(DEVICE) FloweR_model = FloweR_model.to(DEVICE)
FloweR_model.eval()
def read_frame_from_video(input_video): def read_frame_from_video(input_video):
if input_video is None: return None if input_video is None: return None
@ -74,17 +76,41 @@ def start_process(*args):
output_video_folder = os.path.splitext(output_video_name)[0] output_video_folder = os.path.splitext(output_video_name)[0]
os.makedirs(os.path.dirname(output_video_name), exist_ok=True) os.makedirs(os.path.dirname(output_video_name), exist_ok=True)
if args_dict['save_frames_check']: #if args_dict['save_frames_check']:
os.makedirs(output_video_folder, exist_ok=True) os.makedirs(output_video_folder, exist_ok=True)
# Writing to current params to params.json
setts_json = utils.export_settings(*args)
with open(os.path.join(output_video_folder, "params.json"), "w") as outfile:
outfile.write(setts_json)
curr_frame = None
prev_frame = None
def save_result_to_image(image, ind): def save_result_to_image(image, ind):
if args_dict['save_frames_check']: if args_dict['save_frames_check']:
cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
def set_cn_frame_input():
    """Push the frame selected by ``args_dict['cn_frame_send']`` into the ControlNet units.

    The value is the index of the "What frame should be send to CN?" radio:
    0 = "None", 1 = "Current generated frame", 2 = "Previous generated frame",
    3 = "Current reference video frame".

    Reads ``args_dict``, ``curr_frame``, ``prev_frame`` and ``input_video``
    from the enclosing scope.

    Raises:
        Exception: if mode 3 is selected without an input video, or if the
            mode index is not one of 0-3.
    """
    if args_dict['cn_frame_send'] == 0:  # None: leave the ControlNet inputs untouched
        pass
    elif args_dict['cn_frame_send'] == 1:  # Current generated frame
        # Nothing is sent while no frame has been produced yet.
        if curr_frame is not None:
            utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame), set_references=True)
    elif args_dict['cn_frame_send'] == 2:  # Previous generated frame
        if prev_frame is not None:
            utils.set_CNs_input_image(args_dict, Image.fromarray(prev_frame), set_references=True)
    elif args_dict['cn_frame_send'] == 3:  # Current reference video frame
        if input_video is not None:
            curr_video_frame = read_frame_from_video(input_video)
            curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
            utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame), set_references=True)
        else:
            raise Exception('There is no input video! Set it up first.')
    else:
        raise Exception('Incorrect cn_frame_send mode!')
set_cn_frame_input()
if args_dict['init_image'] is not None: if args_dict['init_image'] is not None:
#resize array to args_dict['width'], args_dict['height'] #resize array to args_dict['width'], args_dict['height']
@ -131,10 +157,18 @@ def start_process(*args):
pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy() pred_flow = flow_utils.flow_renorm(pred_data[...,:2]).cpu().numpy()
pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1) pred_occl = flow_utils.occl_renorm(pred_data[...,2:3]).cpu().numpy().repeat(3, axis = -1)
pred_next = flow_utils.frames_renorm(pred_data[...,3:6]).cpu().numpy()
pred_occl = np.clip(pred_occl * 10, 0, 255).astype(np.uint8)
pred_next = np.clip(pred_next, 0, 255).astype(np.uint8)
pred_flow = cv2.resize(pred_flow, org_size) pred_flow = cv2.resize(pred_flow, org_size)
pred_occl = cv2.resize(pred_occl, org_size) pred_occl = cv2.resize(pred_occl, org_size)
pred_next = cv2.resize(pred_next, org_size)
curr_frame = pred_next.copy()
'''
pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05) pred_flow = pred_flow / (1 + np.linalg.norm(pred_flow, axis=-1, keepdims=True) * 0.05)
pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101) pred_flow = cv2.GaussianBlur(pred_flow, (31,31), 1, cv2.BORDER_REFLECT_101)
@ -147,20 +181,21 @@ def start_process(*args):
flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis] flow_map[:,:,1] += np.arange(args_dict['height'])[:,np.newaxis]
warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT_101) warped_frame = cv2.remap(prev_frame, flow_map, None, cv2.INTER_NEAREST, borderMode = cv2.BORDER_REFLECT_101)
alpha_mask = pred_occl / 255.
#alpha_mask = np.clip(alpha_mask + np.random.normal(0, 0.4, size = alpha_mask.shape), 0, 1)
curr_frame = pred_next.astype(float) * alpha_mask + warped_frame.astype(float) * (1 - alpha_mask)
curr_frame = np.clip(curr_frame, 0, 255).astype(np.uint8)
#curr_frame = warped_frame.copy()
'''
curr_frame = warped_frame.copy() set_cn_frame_input()
args_dict['mode'] = 4 args_dict['mode'] = 4
args_dict['init_img'] = Image.fromarray(curr_frame) args_dict['init_img'] = Image.fromarray(pred_next)
args_dict['mask_img'] = Image.fromarray(pred_occl) args_dict['mask_img'] = Image.fromarray(pred_occl)
args_dict['seed'] = -1 args_dict['seed'] = -1
args_dict['denoising_strength'] = args_dict['processing_strength'] args_dict['denoising_strength'] = args_dict['processing_strength']
if input_video is not None:
curr_video_frame = read_frame_from_video(input_video)
curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame))
processed_frames, _, _, _ = utils.img2img(args_dict) processed_frames, _, _, _ = utils.img2img(args_dict)
processed_frame = np.array(processed_frames[0])[...,:3] processed_frame = np.array(processed_frames[0])[...,:3]
#if input_video is not None: #if input_video is not None:
@ -189,7 +224,7 @@ def start_process(*args):
save_result_to_image(processed_frame, ind + 2) save_result_to_image(processed_frame, ind + 2)
stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time) stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
yield stat, curr_frame, pred_occl, warped_frame, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True) yield stat, curr_frame, pred_occl, pred_next, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
if input_video is not None: input_video.release() if input_video is not None: input_video.release()
output_video.release() output_video.release()

View File

@ -11,7 +11,7 @@ def get_component_names():
'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing', 'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier', 'v2v_occlusion_mask_difs_multiplier', 'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing', 'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier', 'v2v_occlusion_mask_difs_multiplier',
'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha', 'v2v_step_1_seed', 'v2v_step_2_seed', 'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha', 'v2v_step_1_seed', 'v2v_step_2_seed',
't2v_file','t2v_init_image', 't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength', 't2v_file','t2v_init_image', 't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
't2v_sampler_index', 't2v_steps', 't2v_length', 't2v_fps', 't2v_sampler_index', 't2v_steps', 't2v_length', 't2v_fps', 't2v_cn_frame_send',
'glo_save_frames_check' 'glo_save_frames_check'
] ]
@ -120,10 +120,10 @@ def get_mode_args(mode, args_dict):
return mode_args_dict return mode_args_dict
def set_CNs_input_image(args_dict, image): def set_CNs_input_image(args_dict, image, set_references = False):
for script_input in args_dict['script_inputs']: for script_input in args_dict['script_inputs']:
if type(script_input).__name__ == 'UiControlNetUnit': if type(script_input).__name__ == 'UiControlNetUnit':
if script_input.module not in ["reference_only", "reference_adain", "reference_adain+attn"]: if script_input.module not in ["reference_only", "reference_adain", "reference_adain+attn"] or set_references:
script_input.image = np.array(image) script_input.image = np.array(image)
script_input.batch_images = [np.array(image)] script_input.batch_images = [np.array(image)]
@ -391,3 +391,42 @@ def txt2img(args_dict):
# processed.images = [] # processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
import json
def get_json(obj):
    """Return a plain JSON-compatible copy of *obj*.

    The object is round-tripped through ``json.dumps`` / ``json.loads``.
    Anything the encoder cannot serialise natively is replaced by its
    ``__dict__`` when it has one, otherwise by ``str(value)``.
    """
    def _to_serialisable(value):
        # Fallback used by the JSON encoder for unsupported types.
        return getattr(value, '__dict__', str(value))

    return json.loads(json.dumps(obj, default=_to_serialisable))
def export_settings(*args):
    """Return the current generation settings as an indented JSON string.

    All positional arguments are forwarded to ``args_to_dict``; ``args[0]``
    selects the processing mode and the resulting dict is then narrowed
    with ``get_mode_args``.

    Raises:
        Exception: if ``args[0]`` is neither 'vid2vid' nor 'txt2vid'.
    """
    args_dict = args_to_dict(*args)
    if args[0] == 'vid2vid':
        args_dict = get_mode_args('v2v', args_dict)
    elif args[0] == 'txt2vid':
        args_dict = get_mode_args('t2v', args_dict)
    else:
        msg = f"Unsupported processing mode: '{args[0]}'"
        raise Exception(msg)

    # convert CN params into a readable dict
    cn_remove_list = ['low_vram', 'is_ui', 'input_mode', 'batch_images', 'output_dir', 'loopback', 'image']

    args_dict['ControlNets'] = []
    for script_input in args_dict['script_inputs']:
        # ControlNet units are recognised by class name only
        # (presumably because the class lives in another extension -- TODO confirm).
        if type(script_input).__name__ == 'UiControlNetUnit':
            cn_values_dict = get_json(script_input)
            # Only enabled units are exported.
            if cn_values_dict['enabled']:
                for key in cn_remove_list:
                    cn_values_dict.pop(key, None)
                args_dict['ControlNets'].append(cn_values_dict)

    # remove unimportant values
    remove_list = [
        'save_frames_check', 'restore_faces', 'prompt_styles', 'mask_blur', 'inpainting_fill', 'tiling',
        'n_iter', 'batch_size', 'subseed', 'subseed_strength', 'seed_resize_from_h',
        'seed_resize_from_w', 'seed_enable_extras', 'resize_mode', 'inpaint_full_res',
        'inpaint_full_res_padding', 'inpainting_mask_invert', 'file', 'denoising_strength',
        'override_settings', 'script_inputs', 'init_img', 'mask_img', 'mode', 'init_video',
    ]
    for key in remove_list:
        args_dict.pop(key, None)

    # Objects that are still not JSON-serialisable fall back to __dict__ / str().
    return json.dumps(args_dict, indent=2, default=lambda o: getattr(o, '__dict__', str(o)))