fix last frame (#116)

pull/117/head v1.5.2
Chengsong Zhang 2023-09-22 23:19:44 -05:00 committed by GitHub
parent 4b8fc1551f
commit 81f11fad90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 7 deletions

View File

@@ -1,3 +1,4 @@
import numpy as np
import torch
from modules import images, shared
from modules.devices import device, dtype_vae, torch_gc
@@ -25,13 +26,17 @@ class AnimateDiffI2VLatent:
init_alpha[init_alpha < 0] = 0
if params.last_frame is not None:
last_frame = params.last_frame
if type(last_frame) == str:
from modules.api.api import decode_base64_to_image
last_frame = decode_base64_to_image(last_frame)
# Get last_alpha
last_alpha = [
1 - pow(i, params.latent_power_last) / params.latent_scale_last
for i in range(params.last_frame)
for i in range(params.video_length)
]
last_alpha.reverse()
logger.info(f"Randomizing last_latent according to {init_alpha}.")
logger.info(f"Randomizing last_latent according to {last_alpha}.")
last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[
:, None, None, None
]
@@ -43,13 +48,18 @@ class AnimateDiffI2VLatent:
scaling_factor = 1 / sum_alpha[mask_alpha]
init_alpha[mask_alpha] *= scaling_factor
last_alpha[mask_alpha] *= scaling_factor
init_alpha[0] = 1
init_alpha[-1] = 0
last_alpha[0] = 0
last_alpha[-1] = 1
# Calculate last_latent
last_frame = params.last_frame
if p.resize_mode != 3:
last_frame = images.resize_image(
p.resize_mode, last_frame, p.width, p.height
)[None, ...]
)
last_frame = np.array(last_frame).astype(np.float32) / 255.0
last_frame = np.moveaxis(last_frame, 2, 0)[None, ...]
last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae)
last_latent = images_tensor_to_samples(
last_frame,
@@ -64,11 +74,10 @@ class AnimateDiffI2VLatent:
size=(p.height // opt_f, p.width // opt_f),
mode="bilinear",
)
# Modify init_latent
p.init_latent = (
p.init_latent * init_alpha
+ p.last_latent * last_alpha
+ last_latent * last_alpha
+ p.rng.next() * (1 - init_alpha - last_alpha)
)
else:

View File

@@ -164,7 +164,8 @@ class AnimateDiffUiGroup:
label="Optional latent scale for last frame",
)
self.params.last_frame = gr.Image(
label="[Experiment] Optional last frame. Leave it blank if you do not need one."
label="[Experiment] Optional last frame. Leave it blank if you do not need one.",
type="pil",
)
with gr.Row():
unload = gr.Button(