animatediff live previews and implement interpolate

pull/2585/head
Vladimir Mandic 2023-12-01 13:16:23 -05:00
parent 52f453400a
commit 2552555678
8 changed files with 166 additions and 185 deletions

View File

@ -621,23 +621,31 @@ def save_image(image, path, basename='', seed=None, prompt=None, extension=share
return params.filename, filename_txt
def save_video_atomic(images, video_type, filename, duration, loop):
def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3):
try:
import cv2
except Exception as e:
shared.log.error(f'Save video: cv2: {e}')
return
os.makedirs(os.path.dirname(filename), exist_ok=True)
if video_type == 'mp4':
video_frames = [np.array(frame) for frame in images]
if video_type.lower() == 'mp4':
frames = images
if interpolate > 0:
try:
import modules.rife
frames = modules.rife.interpolate(images, count=interpolate, scale=scale, pad=pad, change=change)
except Exception as e:
shared.log.error(f'RIFE interpolation: {e}')
errors.display(e, 'RIFE interpolation')
video_frames = [np.array(frame) for frame in frames]
fourcc = "mp4v"
h, w, _c = video_frames[0].shape
video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(images)/duration, frameSize=(w, h))
video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
shared.log.info(f'Save video: file="{filename}" frames={len(images)} duration={duration} fourcc={fourcc}')
if video_type == 'gif':
shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc}')
if video_type.lower() == 'gif' or video_type.lower() == 'png':
append = images.copy()
image = append.pop(0)
if loop:
@ -654,8 +662,8 @@ def save_video_atomic(images, video_type, filename, duration, loop):
shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop}')
def save_video(p, images, video_type, filename = None, duration = 2, loop = True):
if images is None or len(images) < 2:
def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3):
if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none':
return
image = images[0]
namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image)
@ -663,7 +671,8 @@ def save_video(p, images, video_type, filename = None, duration = 2, loop = True
filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]")
filename = namegen.sanitize(os.path.join(shared.opts.outdir_video, filename))
filename = namegen.sequence(filename, shared.opts.outdir_video, '')
threading.Thread(target=save_video_atomic, args=(images, video_type, f'{filename}.{video_type}', duration, loop)).start()
filename = f'{filename}.{video_type.lower()}'
threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start()
def safe_decode_string(s: bytes):

117
modules/rife/__init__.py Normal file
View File

@ -0,0 +1,117 @@
#!/bin/env python
import _thread
import os
import time
from queue import Queue
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from tqdm.rich import tqdm
from modules.rife.ssim import ssim_matlab
from modules.rife.model_rife import Model
from modules import devices, shared
# Download URL of the RIFE flownet-v46 weights; load() passes it to
# modelloader.load_file_from_url together with the "RIFE" folder under shared.models_path.
model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl'
# Module-level cache for the loaded RIFE model; stays None until load() has run.
model = None
def load(model_path: str = 'rife/flownet-v46.pkl'):
    """Load the RIFE model into the module-level ``model`` cache, once.

    The weights file is resolved via ``modelloader.load_file_from_url`` using
    ``model_url`` and the ``RIFE`` folder under ``shared.models_path``.
    If a model is already cached the call does nothing.

    Args:
        model_path: Kept for backward compatibility. The value passed in is
            not used; it is overwritten with the path returned by
            ``modelloader.load_file_from_url``.
    """
    global model # pylint: disable=global-statement
    if model is not None:
        return
    from modules import modelloader
    model_dir = os.path.join(shared.models_path, 'RIFE')
    model_path = modelloader.load_file_from_url(url=model_url, model_dir=model_dir, file_name='flownet-v46.pkl')
    shared.log.debug(f'RIFE load model: file="{model_path}"')
    # Initialise a local instance first and publish it only when every step
    # succeeded; otherwise a failed load would leave a half-initialised model
    # in the cache and later calls would skip loading entirely.
    instance = Model()
    instance.load_model(model_path, -1)
    instance.eval()
    instance.device()
    model = instance
def interpolate(images: list, count: int = 2, scale: float = 1.0, pad: int = 1, change: float = 0.3):
    """Insert RIFE-generated in-between frames into a sequence of images.

    Args:
        images: Source frames; each must expose ``.width``/``.height`` and be
            convertible with ``np.array`` (presumably RGB PIL images - confirm
            against callers).
        count: Interpolation factor; ``count - 1`` frames are generated between
            each pair of consecutive source frames.
        scale: Passed to ``model.inference`` and used to derive the size the
            frames are padded to.
        pad: Number of copies of the first frame emitted at the start and of
            the last frame appended at the end; also the number of copies of
            each neighbour emitted when a scene change is detected.
        change: SSIM threshold; a frame pair scoring below it is treated as a
            scene change and is not interpolated.

    Returns:
        List of PIL images (source frames plus generated frames), or an empty
        list when fewer than two images are supplied.
    """
    if images is None or len(images) < 2:
        return []
    import threading # local import: a joinable worker thread is needed only here
    if model is None:
        load()
    interpolated = []
    h = images[0].height
    w = images[0].width
    t0 = time.time()

    def write(buffer):
        # Consumer thread: turn queued BGR arrays back into RGB PIL images.
        # Each image is appended before the next blocking get(), so the final
        # frame is not lost; a None item ends the loop.
        item = buffer.get()
        while item is not None:
            interpolated.append(Image.fromarray(item[:, :, ::-1]))
            item = buffer.get()

    def execute(I0, I1, n):
        # Produce n intermediate frames between I0 and I1.
        if n <= 0: # nothing to generate; also stops the recursive branch from never terminating
            return []
        if model.version >= 3.9:
            # newer models take the time step directly
            return [model.inference(I0, I1, (i+1) * 1. / (n+1), scale) for i in range(n)]
        # older models only produce the midpoint, so subdivide recursively
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = execute(I0, middle, n=n//2)
        second_half = execute(middle, I1, n=n//2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    def f_pad(img):
        return F.pad(img, padding).to(devices.dtype) # pylint: disable=not-callable

    def to_tensor(bgr):
        # HWC uint8 array -> padded 1xCxHxW tensor in the 0..1 range on the target device
        return f_pad(torch.from_numpy(np.transpose(bgr, (2,0,1))).to(devices.device, non_blocking=True).unsqueeze(0).float() / 255.)

    # Pad width/height up to the next multiple of tmp (extra pixels go on the right/bottom);
    # presumably required by the model architecture - TODO confirm.
    tmp = max(128, int(128 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)
    buffer = Queue(maxsize=8192)
    writer = threading.Thread(target=write, args=(buffer,), daemon=True)
    writer.start()
    try:
        frame = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
        # NOTE(review): with pad=0 no starting frame is queued and the first loop
        # iteration compares images[0] with itself, so it is presumably skipped as a
        # duplicate and the first frame never reaches the output - confirm intended.
        for _i in range(pad): # fill starting frames
            buffer.put(frame)
        I1 = to_tensor(frame)
        with torch.no_grad():
            with tqdm(total=len(images), desc='Interpolate', unit='frame') as pbar:
                for image in images:
                    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    I0 = I1
                    I1 = to_tensor(frame)
                    # similarity is measured on 32x32 thumbnails of the first three channels
                    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False).to(torch.float32)
                    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
                    if ssim > 0.99: # skip duplicate frames
                        pbar.update(1) # still count the frame so the bar can reach its total
                        continue
                    if ssim < change:
                        output = []
                        for _i in range(pad): # fill frames if change rate is above threshold
                            output.append(I0)
                        for _i in range(pad):
                            output.append(I1)
                    else:
                        output = execute(I0, I1, count-1)
                    for mid in output:
                        mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
                        buffer.put(mid[:h, :w]) # crop the padding off again
                    buffer.put(frame)
                    pbar.update(1)
        for _i in range(pad): # fill ending frames
            buffer.put(frame)
    finally:
        # Always stop the writer and wait until it has drained the queue; polling
        # buffer.empty() would race with the conversion of the last item.
        buffer.put(None)
        writer.join()
    t1 = time.time()
    shared.log.info(f'RIFE interpolate: input={len(images)} frames={len(interpolated)} resolution={w}x{h} interpolate={count} scale={scale} pad={pad} change={change} time={round(t1 - t0, 2)}')
    return interpolated

View File

@ -1,8 +1,8 @@
import torch
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from model_ifnet import IFNet
from loss import EPE, SOBEL
from modules.rife.model_ifnet import IFNet
from modules.rife.loss import EPE, SOBEL
from modules import devices
@ -31,11 +31,7 @@ class Model:
def load_model(self, model_file, rank=0):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
return { k.replace("module.", ""): v for k, v in param.items() if "module." in k }
else:
return param
if rank <= 0:

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from warplayer import warp
from modules.rife.warplayer import warp
c = 16

View File

@ -1,153 +0,0 @@
#!/bin/env python
import _thread
import argparse
import os
import time
import tempfile
from queue import Queue
import filetype
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from tqdm.rich import tqdm
from ssim import ssim_matlab
from model_rife import Model
from modules import devices
# RIFE model instance; interpolate() calls load() while this is still None.
model = None
# Running index of frames written to disk; the writer inside interpolate() uses it
# to build output file names and it is reported in the final summary line.
count = 0
def load(model_path: str = 'rife/flownet-v46.pkl'):
    """Create the RIFE model from ``model_path`` and store it in the global ``model``.

    Args:
        model_path: Path of the weights file handed to ``Model.load_model``.
    """
    global model # pylint: disable=global-statement
    # Initialise a local instance and publish it only after every step succeeded,
    # so a failed load does not leave a half-initialised model in the global.
    instance = Model()
    instance.load_model(model_path, -1)
    instance.eval()
    instance.device()
    model = instance
def interpolate(args): # pylint: disable=redefined-outer-name
    """Interpolate a directory of images with RIFE and write the frames as JPEG files.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``input``, ``output``, ``seq``, ``scale``, ``multi``, ``buffer``,
            ``change`` and ``fp16``.

    Output frames are written to ``args.output`` as ``NNNNNN.jpg``, numbered by
    the module-level ``count``.
    """
    import threading # local import: a joinable worker thread is needed only here
    print('start interpolate')
    t0 = time.time()
    if model is None:
        load(args.model)
    videogen = []
    if args.seq is None:
        # take every image file found in the input directory
        for f in os.listdir(args.input):
            fn = os.path.join(args.input, f)
            if os.path.isfile(fn) and filetype.is_image(fn):
                videogen.append(fn)
    else:
        # take files whose numeric prefix (text before the first '-') continues the sequence
        files = sorted(os.listdir(args.input))
        current = args.seq
        for f in files:
            seq = os.path.basename(f).split('-')[0]
            if seq.isdigit() and int(seq) == current:
                fn = os.path.join(args.input, f)
                videogen.append(fn)
                current += 1
    videogen = sorted(videogen)
    print(f'inputs: {len(videogen)} {[os.path.basename(f) for f in videogen]}')
    if not videogen: # nothing to do; avoids an IndexError on videogen[0] below
        print('end interpolate: no input images found')
        return
    frame = cv2.imread(videogen[0], cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
    h, w, _ = frame.shape
    os.makedirs(args.output, exist_ok=True) # also creates missing parent folders

    def write(output_dir, buffer):
        # Consumer thread: save queued frames until a None item arrives.
        global count # pylint: disable=global-statement
        item = buffer.get()
        while item is not None:
            cv2.imwrite(f'{output_dir}/{count:0>6d}.jpg', item[:, :, ::-1])
            count += 1 # count right after writing so the final total includes the last frame
            item = buffer.get()

    def execute(I0, I1, n):
        # Produce n intermediate frames between I0 and I1.
        if n <= 0: # nothing to generate; also stops the recursive branch from never terminating
            return []
        if model.version >= 3.9:
            # newer models take the time step directly
            return [model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale) for i in range(n)]
        # older models only produce the midpoint, so subdivide recursively
        middle = model.inference(I0, I1, args.scale)
        if n == 1:
            return [middle]
        first_half = execute(I0, middle, n=n//2)
        second_half = execute(middle, I1, n=n//2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    def pad(img):
        return F.pad(img, padding).half() if args.fp16 else F.pad(img, padding) # pylint: disable=not-callable

    def to_tensor(rgb):
        # HWC uint8 array -> padded 1xCxHxW tensor in the 0..1 range on the target device
        return pad(torch.from_numpy(np.transpose(rgb, (2,0,1))).to(devices.device, non_blocking=True).unsqueeze(0).float() / 255.)

    # Pad width/height up to the next multiple of tmp (extra pixels go on the right/bottom).
    tmp = max(128, int(128 / args.scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)
    buffer = Queue(maxsize=8192)
    writer = threading.Thread(target=write, args=(args.output, buffer), daemon=True)
    writer.start()
    try:
        print(f'padded start: frames={args.buffer}')
        for _i in range(args.buffer): # fill starting frames
            buffer.put(frame)
        I1 = to_tensor(frame)
        with torch.no_grad():
            with tqdm(total=len(videogen), desc='interpolate', unit='frame') as pbar:
                for f in videogen:
                    frame = cv2.imread(f, cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
                    I0 = I1
                    I1 = to_tensor(frame)
                    # similarity is measured on 32x32 thumbnails of the first three channels
                    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
                    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
                    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
                    if ssim > 0.99: # skip duplicate frames
                        pbar.update(1) # still count the frame so the bar can reach its total
                        continue
                    if ssim < args.change:
                        output = []
                        for _i in range(args.buffer): # fill frames if change rate is above threshold
                            output.append(I0)
                        for _i in range(args.buffer):
                            output.append(I1)
                    else:
                        output = execute(I0, I1, args.multi-1)
                    for mid in output:
                        mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
                        buffer.put(mid[:h, :w]) # crop the padding off again
                    buffer.put(frame)
                    pbar.update(1)
        print(f'padded end: frames={args.buffer}')
        for _i in range(args.buffer): # fill ending frames
            buffer.put(frame)
    finally:
        # Stop the writer and wait for it; returning while it is still inside
        # imwrite could leave the last file incomplete when the process exits.
        buffer.put(None)
        writer.join()
    t1 = time.time()
    print(f'end interpolate: input={len(videogen)} frames={count} time={round(t1 - t0, 2)}')
if __name__ == "__main__":
    # Command-line entry point: parse options, validate them and run interpolate().
    print('starting rife')
    # default output folder is a timestamped directory under the system temp dir
    tmp_folder = os.path.join(tempfile.gettempdir(), f'rife-{time.strftime("%Y%m%d-%H%M%S")}')
    parser = argparse.ArgumentParser(description='interpolate video frames using RIFE')
    parser.add_argument('--model', type=str, default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'model/flownet-v46.pkl')), help='path to model, default: %(default)s')
    parser.add_argument('--input', type=str, required=True, default=None, help='input directory containing images, default: %(default)s')
    parser.add_argument('--output', type=str, default=tmp_folder, help='output directory for interpolated images, default: %(default)s')
    parser.add_argument('--scale', type=float, default=1.0, help='scale factor for interpolated images, default: %(default)s')
    parser.add_argument('--multi', type=int, default=4, help='number of frames to interpolate between two input images, default: %(default)s')
    parser.add_argument('--buffer', type=int, default=2, help='number of frames to buffer on scene change, default: %(default)s')
    parser.add_argument('--change', type=float, default=0.3, help='scene change threshold (lower is more sensitive, default: %(default)s')
    parser.add_argument('--fp16', action='store_true', help='use float16 precision instead of float32, default: %(default)s')
    parser.add_argument('--fps', type=int, default=25, help='desired framerate, default: %(default)s')
    parser.add_argument('--seq', type=int, default=None, help='image sequence start number, default: %(default)s')
    parser.add_argument('--rm', action='store_true', help='remove interpolated images, default: %(default)s')
    args = parser.parse_args()
    print('args', args)
    # Validate explicitly instead of using assert, which is stripped under "python -O".
    valid_scales = (0.25, 0.5, 1.0, 2.0, 4.0)
    if args.scale not in valid_scales:
        parser.error(f'--scale must be one of {valid_scales}, got {args.scale}')
    interpolate(args)

View File

@ -13,5 +13,5 @@ def warp(tenInput, tenFlow):
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(devices.device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
grid = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1).to(devices.dtype)
return torch.nn.functional.grid_sample(input=tenInput, grid=grid, mode='bilinear', padding_mode='border', align_corners=True)

View File

@ -31,12 +31,15 @@ def setup_img2img_steps(p, steps=None):
def single_sample_to_image(sample, approximation=None):
# sample should be [4,64,64]
if approximation is None:
approximation = approximation_indexes.get(shared.opts.show_progress_type, None)
if approximation is None:
warn_once('Unknown decode type, please reset preview method')
approximation = 0
if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent
sample = sample.permute(1, 0, 2, 3)[0]
if approximation == 0: # Simple
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
elif approximation == 1: # Approximate

View File

@ -106,8 +106,16 @@ class Script(scripts.Script):
def show(self, _is_img2img):
return scripts.AlwaysVisible if shared.backend == shared.Backend.DIFFUSERS else False
# return signature is array of gradio components
def ui(self, _is_img2img):
def video_type_change(video_type):
return [
gr.update(visible=video_type != 'None'),
gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
gr.update(visible=video_type == 'MP4'),
gr.update(visible=video_type == 'MP4'),
]
with gr.Accordion('AnimateDiff', open=False, elem_id='animatediff'):
with gr.Row():
adapter_index = gr.Dropdown(label='Adapter', choices=list(ADAPTERS), value='None')
@ -118,19 +126,22 @@ class Script(scripts.Script):
with gr.Row():
latent_mode = gr.Checkbox(label='Latent mode', value=False)
with gr.Row():
create_gif = gr.Checkbox(label='GIF', value=False)
create_mp4 = gr.Checkbox(label='MP4', value=False)
loop = gr.Checkbox(label='Loop', value=True)
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2)
return [adapter_index, frames, lora_index, strength, latent_mode, create_gif, create_mp4, duration, loop]
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
with gr.Row():
gif_loop = gr.Checkbox(label='Loop', value=True, visible=False)
mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate])
return [adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate]
def process(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lora_index, strength, latent_mode, create_gif, create_mp4, duration, loop): # pylint: disable=arguments-differ, unused-argument
def process(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument
adapter = ADAPTERS[adapter_index]
lora = LORAS[lora_index]
set_adapter(adapter)
if motion_adapter is None:
return
shared.log.debug(f'AnimateDiff: adapter="{adapter}" lora="{lora}" strength={strength} gif={create_gif} mp4={create_mp4}')
shared.log.debug(f'AnimateDiff: adapter="{adapter}" lora="{lora}" strength={strength} video={video_type}')
if lora is not None and lora != 'None':
shared.sd_model.load_lora_weights(lora, adapter_name=lora)
shared.sd_model.set_adapters([lora], adapter_weights=[strength])
@ -142,9 +153,7 @@ class Script(scripts.Script):
if not latent_mode:
p.task_args['output_type'] = 'np'
def postprocess(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, adapter_index, frames, lora_index, strength, latent_mode, create_gif, create_mp4, duration, loop): # pylint: disable=arguments-differ, unused-argument
def postprocess(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument
from modules.images import save_video
if create_gif:
save_video(p, images=processed.images, video_type='gif', duration=duration, loop=loop)
if create_mp4:
save_video(p, images=processed.images, video_type='mp4', duration=duration, loop=loop)
if video_type != 'None':
save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate)