189 lines
7.5 KiB
Python
189 lines
7.5 KiB
Python
import os
|
|
import random
|
|
import bisect
|
|
|
|
import pandas as pd
|
|
|
|
import omegaconf
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torchvision import transforms
|
|
from decord import VideoReader, cpu
|
|
import torchvision.transforms._transforms_video as transforms_video
|
|
|
|
class WebVid(Dataset):
    """
    WebVid video-caption dataset.

    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
            ...

    When ``dataname == "loradata"`` videos are instead looked up directly as
    ``<data_dir>/<videoid>.mp4``.

    The metadata CSV is expected to contain a ``videoid`` column, a ``name``
    column (exposed as the caption) and, unless ``dataname == "loradata"``,
    a ``page_dir`` column.
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=[256, 512],  # never mutated; kept as a list for backward compatibility
                 frame_stride=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fps_schedule=None,
                 fs_probs=None,
                 bs_per_gpu=None,
                 trigger_word='',
                 dataname='',
                 ):
        """
        Args:
            meta_path: CSV file read with ``pandas.read_csv``.
            data_dir: Root directory of the dataset.
            subsample: If not None, number of metadata rows to sample
                (deterministically, ``random_state=0``).
            video_length: Number of frames in every returned clip.
            resolution: ``[height, width]`` or a single int (expanded to a
                square). Used as decode size unless ``load_raw_resolution``.
            frame_stride: Fixed int stride, or a list of candidate strides
                chosen per sample (see ``fps_schedule`` / ``fs_probs``).
            spatial_transform: None, "random_crop" or "resize_center_crop".
            crop_resolution: Crop size for "random_crop".
            fps_max: Optional upper bound for the reported clip fps.
            load_raw_resolution: Decode videos at their native resolution.
            fps_schedule: Step boundaries selecting the stride stage; needs
                ``bs_per_gpu``. Presumably ascending, since it is searched
                with ``bisect``.
            fs_probs: Sampling weights, one per candidate stride.
            bs_per_gpu: Samples per step, used to turn the per-dataset
                sample counter into a step count.
            trigger_word: String appended to every caption.
            dataname: "loradata" selects the flat directory layout.

        Raises:
            ValueError: "random_crop" requested without ``crop_resolution``.
            NotImplementedError: unknown ``spatial_transform`` name.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution
        self.frame_stride = frame_stride
        self.fps_max = fps_max
        self.load_raw_resolution = load_raw_resolution
        self.fs_probs = fs_probs
        self.trigger_word = trigger_word
        self.dataname = dataname

        self._load_metadata()

        if spatial_transform is None:
            self.spatial_transform = None
        elif spatial_transform == "random_crop":
            if crop_resolution is None:
                # Fail here instead of inside the first transform call.
                raise ValueError('spatial_transform="random_crop" requires crop_resolution')
            self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution)
        elif spatial_transform == "resize_center_crop":
            assert self.resolution[0] == self.resolution[1]
            # NOTE: the raw `resolution` argument (int or list) is passed on
            # unchanged, exactly as before.
            self.spatial_transform = transforms.Compose([
                transforms.Resize(resolution),
                transforms_video.CenterCropVideo(resolution),
            ])
        else:
            raise NotImplementedError(f'unknown spatial_transform: {spatial_transform!r}')

        self.fps_schedule = fps_schedule
        self.bs_per_gpu = bs_per_gpu
        if self.fps_schedule is not None:
            assert self.bs_per_gpu is not None
        # Always defined so the schedule helper never hits a missing attribute.
        self.counter = 0
        self.stage_idx = 0

    def _load_metadata(self):
        """Read the metadata CSV into ``self.metadata``.

        Optionally subsamples rows, exposes the ``name`` column as
        ``caption`` and drops every row that contains a missing value.
        """
        metadata = pd.read_csv(self.meta_path)
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        metadata['caption'] = metadata['name']
        del metadata['name']
        self.metadata = metadata.dropna()

    def _get_video_path(self, sample):
        """Return ``(full_path, relative_path)`` of the video for a metadata row."""
        filename = str(sample['videoid']) + '.mp4'
        if self.dataname == "loradata":
            rel_video_fp = filename
            full_video_fp = os.path.join(self.data_dir, rel_video_fp)
        else:
            rel_video_fp = os.path.join(sample['page_dir'], filename)
            full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp, rel_video_fp

    def get_fs_based_on_schedule(self, frame_strides, schedule):
        """Pick the stride of the stage the current step falls into.

        ``schedule`` holds the step boundaries between stages, so there must
        be exactly one more stride than boundaries.
        """
        assert len(frame_strides) == len(schedule) + 1  # nstage=len_fps_schedule + 1
        global_step = self.counter // self.bs_per_gpu  # TODO: support resume.
        stage_idx = bisect.bisect(schedule, global_step)
        frame_stride = frame_strides[stage_idx]
        # log stage change
        if stage_idx != self.stage_idx:
            print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}')
            self.stage_idx = stage_idx
        return frame_stride

    def get_fs_based_on_probs(self, frame_strides, probs):
        """Draw one stride using ``probs`` as sampling weights."""
        assert len(frame_strides) == len(probs)
        return random.choices(frame_strides, weights=probs)[0]

    def get_fs_randomly(self, frame_strides):
        """Draw one stride uniformly at random."""
        return random.choice(frame_strides)

    def _pick_frame_stride(self):
        """Resolve the nominal frame stride for the next sample."""
        strides = self.frame_stride
        if isinstance(strides, int):
            return strides
        if isinstance(strides, (list, tuple, omegaconf.listconfig.ListConfig)):
            if self.fps_schedule is not None:
                stride = self.get_fs_based_on_schedule(strides, self.fps_schedule)
            elif self.fs_probs is not None:
                stride = self.get_fs_based_on_probs(strides, self.fs_probs)
            else:
                stride = self.get_fs_randomly(strides)
        else:
            stride = strides
        assert isinstance(stride, int), type(stride)
        return stride

    def _open_video(self, video_path):
        """Open ``video_path`` with decord, resizing on decode unless raw resolution is requested."""
        if self.load_raw_resolution:
            return VideoReader(video_path, ctx=cpu(0))
        return VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])

    def _sample_clip_indices(self, num_frames, frame_stride):
        """Choose ``video_length`` strided frame indices at a random offset.

        If ``frame_stride`` leaves too few frames, the largest stride that
        still fits is used instead. The stride actually used is returned
        rather than stored, so a fallback computed for one video cannot
        affect another.

        Returns:
            ``(frame_indices, stride_used)``
        """
        candidates = range(0, num_frames, frame_stride)
        if len(candidates) < self.video_length:  # recal a max fs
            frame_stride = num_frames // self.video_length
            assert frame_stride != 0
            candidates = range(0, num_frames, frame_stride)
        start = random.randint(0, len(candidates) - self.video_length)
        return list(candidates[start:start + self.video_length]), frame_stride

    def __getitem__(self, index):
        """Return one clip as a dict.

        Keys: ``video`` (float tensor [c,t,h,w] scaled to [-1, 1]),
        ``caption``, ``path``, ``fps`` and ``frame_stride``.

        Unreadable or too-short videos are skipped by moving on to the next
        metadata row (wrapping around).

        Raises:
            RuntimeError: no usable video after one full pass over the
                metadata.
        """
        nominal_stride = self._pick_frame_stride()
        num_samples = len(self.metadata)

        for attempt in range(num_samples):
            sample = self.metadata.iloc[(index + attempt) % num_samples]
            video_path, _ = self._get_video_path(sample)

            # make reader
            try:
                video_reader = self._open_video(video_path)
                num_frames = len(video_reader)
            except Exception:  # not a bare except: Ctrl-C must still interrupt
                print(f"Load video failed! path = {video_path}")
                continue
            if num_frames < self.video_length:
                print(f"video length ({num_frames}) is smaller than target length({self.video_length})")
                continue

            # sample strided frames and select a random clip
            frame_indices, frame_stride = self._sample_clip_indices(num_frames, nominal_stride)
            try:
                frames = video_reader.get_batch(frame_indices)
            except Exception:
                print(f"Get frames failed! path = {video_path}")
                continue
            break
        else:
            raise RuntimeError(
                f'no loadable video with at least {self.video_length} frames '
                f'found after trying {num_samples} samples'
            )

        caption = sample['caption'] + self.trigger_word

        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
        if self.resolution is not None:
            assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
        frames = (frames / 255 - 0.5) * 2  # [0, 255] -> [-1, 1]

        fps_ori = video_reader.get_avg_fps()
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}

        if self.fps_schedule is not None:
            self.counter += 1
        return data

    def __len__(self):
        """Number of metadata rows."""
        return len(self.metadata)