263 lines
12 KiB
Python
263 lines
12 KiB
Python
import os
|
|
import time
|
|
import argparse
|
|
import yaml, math
|
|
from tqdm import trange
|
|
import torch
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
import torch.distributed as dist
|
|
from pytorch_lightning import seed_everything
|
|
|
|
from videocrafter.lvdm.samplers.ddim import DDIMSampler
|
|
from videocrafter.lvdm.utils.common_utils import str2bool
|
|
from videocrafter.lvdm.utils.dist_utils import setup_dist, gather_data
|
|
from videocrafter.lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
|
|
from videocrafter.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
def get_parser():
    """Build the command-line argument parser for text-to-video sampling.

    Returns:
        argparse.ArgumentParser: parser covering model paths, device
        selection, sampling hyper-parameters, LoRA options and saving flags.
    """
    parser = argparse.ArgumentParser()
    # basic args
    parser.add_argument("--ckpt_path", type=str, help="model checkpoint path")
    parser.add_argument("--config_path", type=str, help="model config path (a yaml file)")
    parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).")
    parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/")
    # device args
    parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for multi-gpu case)", default=False)
    parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0)
    parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0)
    # sampling args
    parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2)
    parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1)
    parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1)
    parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"])
    parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50)
    parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0)
    parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale")
    parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)")
    parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",)
    parser.add_argument("--num_frames", type=int, default=16, help="number of input frames")
    # lora args
    parser.add_argument("--lora_path", type=str, help="lora checkpoint path")
    parser.add_argument("--inject_lora", action='store_true', default=False, help="",)
    parser.add_argument("--lora_scale", type=float, default=None, help="scale for lora weight")
    parser.add_argument("--lora_trigger_word", type=str, default="", help="",)
    # saving args
    # BUGFIX: `choices` must not be combined with `type=str2bool` -- argparse
    # applies `type` first, so the converted bool could never match the string
    # choices and any explicit `--save_mp4 <value>` was rejected. str2bool
    # itself validates/rejects malformed input, so `choices` is dropped.
    parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files")
    parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",)
    parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",)
    parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",)
    parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",)
    return parser
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
def sample_denoising_batch(model, noise_shape, condition, *args,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False,
                           **kwargs,
                           ):
    """Denoise one batch of latents with either DDPM or DDIM sampling.

    Args:
        model: diffusion model exposing ``p_sample_loop`` (DDPM path).
        noise_shape: shape of the latent noise tensor, batch dim first.
        condition: conditioning (e.g. text embedding) passed to the sampler.
        sample_type: "ddpm" or "ddim".
        sampler: DDIM sampler instance; required when sample_type == "ddim".
        ddim_steps: number of DDIM denoising steps; required for DDIM.
        eta: DDIM eta (0.0 deterministic, 1.0 stochastic); required for DDIM.
        unconditional_guidance_scale: classifier-free guidance scale.
        uc: unconditional (negative) conditioning for CFG.
        denoising_progress: whether to show per-step denoising progress.

    Returns:
        The sampled latents.

    Raises:
        ValueError: if ``sample_type`` is unknown, or a required DDIM
            argument (sampler / ddim_steps / eta) is missing.
    """
    if sample_type == "ddpm":
        samples = model.p_sample_loop(cond=condition, shape=noise_shape,
                                      return_intermediates=False,
                                      verbose=denoising_progress,
                                      **kwargs,
                                      )
    elif sample_type == "ddim":
        # BUGFIX: explicit runtime checks instead of bare `assert`,
        # which is silently stripped when running under `python -O`.
        if sampler is None:
            raise ValueError("sampler is required for ddim sampling")
        if ddim_steps is None:
            raise ValueError("ddim_steps is required for ddim sampling")
        if eta is None:
            raise ValueError("eta is required for ddim sampling")
        samples, _ = sampler.sample(S=ddim_steps,
                                    conditioning=condition,
                                    batch_size=noise_shape[0],
                                    shape=noise_shape[1:],
                                    verbose=denoising_progress,
                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                    unconditional_conditioning=uc,
                                    eta=eta,
                                    **kwargs,
                                    )
    else:
        raise ValueError(f"Unsupported sample_type: {sample_type!r} (expected 'ddpm' or 'ddim')")
    return samples
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
@torch.no_grad()
def sample_text2video(model, prompt, n_prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=7.5,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      num_frames=None,
                      ):
    """Generate videos for a text prompt and return them as one numpy array.

    Encodes the prompt (and the negative prompt when CFG is active),
    denoises ceil(n_samples / batch_size) latent batches, decodes each
    batch frame-wise, and optionally gathers results across DDP ranks.
    """
    # conditioning vectors; the negative one is only needed when CFG is on
    assert(model.cond_stage_model is not None)
    text_emb = get_conditions(prompt, model, batch_size)
    neg_emb = None
    if cfg_scale != 1.0:
        neg_emb = get_conditions(n_prompt, model, batch_size)

    # batch-wise sampling loop
    num_batches = math.ceil(n_samples / batch_size)
    batch_iter = range(num_batches)
    if batch_progress:
        batch_iter = trange(num_batches, desc="Sampling Batches (text-to-video)")

    videos = []
    for _ in batch_iter:
        latent_shape = make_model_input_shape(model, batch_size, T=num_frames)
        latents = sample_denoising_batch(model, latent_shape, text_emb,
                                         sample_type=sample_type,
                                         sampler=sampler,
                                         ddim_steps=ddim_steps,
                                         eta=eta,
                                         unconditional_guidance_scale=cfg_scale,
                                         uc=neg_emb,
                                         denoising_progress=show_denoising_progress,
                                         )
        decoded = model.decode_first_stage(latents, decode_bs=decode_frame_bs, return_cpu=False)

        if ddp and all_gather:
            # collect this batch from every rank before converting to numpy
            gathered = gather_data(decoded, return_np=False)
            videos.extend(torch_to_np(item) for item in gathered)
        else:
            videos.append(torch_to_np(decoded))

    stacked = np.concatenate(videos, axis=0)
    assert(stacked.shape[0] >= n_samples)
    return stacked
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
def save_results(videos, save_dir,
                 save_name="results", save_fps=8, save_mp4=True,
                 save_npz=False, save_mp4_sheet=False, save_jpg=False
                 ):
    """Persist sampled videos in the requested formats.

    Args:
        videos: numpy array of videos, batch dimension first.
        save_dir: destination directory.
        save_name: filename stem used for every output file.
        save_fps: frame rate for mp4 outputs.
        save_mp4: write one mp4 per sample under ``<save_dir>/videos``.
        save_npz: write all samples into one ``.npz`` archive.
        save_mp4_sheet: write all samples into a single grid mp4.
        save_jpg: write all samples into a single image sheet.
    """
    if save_mp4:
        save_subdir = os.path.join(save_dir, "videos")
        os.makedirs(save_subdir, exist_ok=True)
        # one mp4 file per sample
        for idx in range(videos.shape[0]):
            clip = videos[idx:idx + 1, ...]
            out_path = os.path.join(save_subdir, f"{save_name}_{idx:03d}.mp4")
            npz_to_video_grid(clip, out_path, fps=save_fps)
        print(f'Successfully saved videos in {save_subdir}')

    if save_npz:
        save_path = os.path.join(save_dir, f"{save_name}.npz")
        np.savez(save_path, videos)
        print(f'Successfully saved npz in {save_path}')

    if save_mp4_sheet:
        save_path = os.path.join(save_dir, f"{save_name}.mp4")
        npz_to_video_grid(videos, save_path, fps=save_fps)
        print(f'Successfully saved mp4 sheet in {save_path}')

    if save_jpg:
        save_path = os.path.join(save_dir, f"{save_name}.jpg")
        npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1])
        print(f'Successfully saved jpg sheet in {save_path}')
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
def main():
    """
    text-to-video generation

    Parses CLI arguments, loads the diffusion model (optionally with LoRA
    weights), reads the prompt(s), samples videos for each prompt, and
    saves the results in the formats requested on the command line.
    """
    import shutil  # local import: only used to copy the prompt file

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    os.makedirs(opt.save_dir, exist_ok=True)

    # set device
    if opt.ddp:
        setup_dist(opt.local_rank)
        # split the requested sample count across ranks
        opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size())
        gpu_id = None
    else:
        gpu_id = opt.gpu_id
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"

    # set random seed (distinct per rank so DDP ranks don't duplicate samples)
    if opt.seed is not None:
        if opt.ddp:
            seed = opt.local_rank + opt.seed
        else:
            seed = opt.seed
        seed_everything(seed)

    # dump args
    fpath = os.path.join(opt.save_dir, "sampling_args.yaml")
    with open(fpath, 'w') as f:
        yaml.dump(vars(opt), f, default_flow_style=False)

    # load & merge config (unrecognized CLI args override the yaml config)
    config = OmegaConf.load(opt.config_path)
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)
    print("config: \n", config)

    # get model & sampler
    model, _, _ = load_model(config, opt.ckpt_path,
                             inject_lora=opt.inject_lora,
                             lora_scale=opt.lora_scale,
                             lora_path=opt.lora_path
                             )
    ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None

    # prepare prompt: a .txt path means one prompt per non-empty line
    if opt.prompt.endswith(".txt"):
        opt.prompt_file = opt.prompt
        opt.prompt = None
    else:
        opt.prompt_file = None

    if opt.prompt_file is not None:
        prompts, line_idx = [], []
        with open(opt.prompt_file, 'r') as f:
            for idx, line in enumerate(f.readlines()):
                l = line.strip()
                if len(l) != 0:
                    prompts.append(l)
                    line_idx.append(idx)
        # BUGFIX: copy via shutil instead of `os.system(f"cp ...")` --
        # portable (works on Windows) and safe for paths containing
        # spaces or shell metacharacters.
        shutil.copy(opt.prompt_file, opt.save_dir)
    else:
        prompts = [opt.prompt]
        line_idx = [None]

    if opt.inject_lora:
        assert(opt.lora_trigger_word != '')
        prompts = [p + opt.lora_trigger_word for p in prompts]

    # go
    start = time.time()
    for prompt in prompts:
        # sample
        # BUGFIX: the original call omitted the required `n_prompt`
        # (negative prompt) positional argument, shifting every following
        # positional and crashing with a missing-`batch_size` TypeError.
        # Pass an empty negative prompt for plain CFG sampling.
        samples = sample_text2video(model, prompt, "", opt.n_samples, opt.batch_size,
                                    sample_type=opt.sample_type, sampler=ddim_sampler,
                                    ddim_steps=opt.ddim_steps, eta=opt.eta,
                                    cfg_scale=opt.cfg_scale,
                                    decode_frame_bs=opt.decode_frame_bs,
                                    ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress,
                                    num_frames=opt.num_frames,
                                    )
        # save (only rank 0 writes in DDP mode)
        if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp):
            prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
            save_name = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
            if opt.seed is not None:
                save_name = save_name + f"_seed{seed:05d}"
            # BUGFIX: forward the saving flags -- they were parsed by
            # get_parser but previously ignored here, so --save_npz etc.
            # had no effect.
            save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps,
                         save_mp4=opt.save_mp4, save_npz=opt.save_npz,
                         save_mp4_sheet=opt.save_mp4_sheet, save_jpg=opt.save_jpg)
    print("Finish sampling!")
    print(f"Run time = {(time.time() - start):.2f} seconds")

    if opt.ddp:
        dist.destroy_process_group()
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
# Script entry point: run CLI-driven text-to-video sampling.
if __name__ == "__main__":
    main()