[WIP] Video dataset functionalities #1

Open
wants to merge 6 commits into base: video-reader
78 changes: 78 additions & 0 deletions test/test_datasets_video_utils.py
@@ -0,0 +1,78 @@
import contextlib
import os
import torch
import unittest

from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold

from common_utils import get_tmp_dir


@contextlib.contextmanager
def get_list_of_videos(num_videos=5):
with get_tmp_dir() as tmp_dir:
names = []
for i in range(num_videos):
data = torch.randint(0, 255, (5 * (i + 1), 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmp_dir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=5)

yield names


class Tester(unittest.TestCase):

def test_unfold(self):
a = torch.arange(7)

r = unfold(a, 3, 3, 1)
expected = torch.tensor([
[0, 1, 2],
[3, 4, 5],
])
self.assertTrue(r.equal(expected))

r = unfold(a, 3, 2, 1)
expected = torch.tensor([
[0, 1, 2],
[2, 3, 4],
[4, 5, 6]
])
self.assertTrue(r.equal(expected))

r = unfold(a, 3, 2, 2)
expected = torch.tensor([
[0, 2, 4],
[2, 4, 6],
])
self.assertTrue(r.equal(expected))


def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5)
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)

video_clips = VideoClips(video_list, 6, 6)
self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)

video_clips = VideoClips(video_list, 6, 1)
self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)


if __name__ == '__main__':
unittest.main()
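A note on the expected counts asserted above: get_list_of_videos writes videos with 5 * (i + 1) frames, i.e. 5, 10 and 15 frames for the three videos. With dilation 1, a video with n frames yields max(0, (n - clip_length) // step + 1) clips (this is the size computation in unfold), so for clip_length=6, step=1 the three videos contribute 0, 5 and 10 clips, matching the 0 + (10 - 6 + 1) + (15 - 6 + 1) in the last assertion.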
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
@@ -19,6 +19,7 @@
from .sbd import SBDataset
from .vision import VisionDataset
from .usps import USPS
from .kinetics import KineticsVideo

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -28,4 +29,4 @@
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
'USPS')
'USPS', 'KineticsVideo')
26 changes: 26 additions & 0 deletions torchvision/datasets/kinetics.py
@@ -0,0 +1,26 @@
from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .vision import VisionDataset


class KineticsVideo(VisionDataset):
def __init__(self, root, frames_per_clip, step_between_clips=1):
super(KineticsVideo, self).__init__(root)
extensions = ('avi',)

classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
        self.classes = classes
        self.class_to_idx = class_to_idx
        # make_dataset returns (path, class_index) tuples; VideoClips only needs the paths
        video_list = [path for path, _ in self.samples]
        self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)

def __len__(self):
return self.video_clips.num_clips()

def __getitem__(self, idx):
video, audio, info, video_idx = self.video_clips.get_clip(idx)
label = self.samples[video_idx][1]

        return video, audio, label

Owner Author: need to add transforms yet
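For context, a minimal usage sketch of the new dataset (illustrative only: the root path is hypothetical, and KineticsVideo expects a root/<class_name>/<video>.avi layout because make_dataset is called with the ('avi',) extension filter):

from torch.utils.data import DataLoader
from torchvision.datasets import KineticsVideo

# Hypothetical dataset root with one sub-directory per action class.
dataset = KineticsVideo("/data/kinetics400", frames_per_clip=16, step_between_clips=16)
video, audio, label = dataset[0]  # video is a (16, H, W, C) uint8 tensor

# Batching assumes every clip tensor has the same shape (same resolution and
# audio length across videos); otherwise the default collate_fn cannot stack them.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)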
110 changes: 110 additions & 0 deletions torchvision/datasets/video_utils.py
@@ -0,0 +1,110 @@
import bisect
import torch
from torchvision.io import read_video_timestamps, read_video


Owner Author: this uses stride tricks to compute all the possible clips in the video, with potential steps between clips and dilation (steps between frames)

def unfold(tensor, size, step, dilation):
"""
    Similar to tensor.unfold, but with support for dilation,
    and specialized for 1-d tensors.
"""
assert tensor.dim() == 1
o_stride = tensor.stride(0)
numel = tensor.numel()
new_stride = (step * o_stride, dilation * o_stride)
new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
if new_size[0] < 1:
new_size = (0, size)
return torch.as_strided(tensor, new_size, new_stride)
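# Example of the windows this produces (mirroring test_datasets_video_utils.py):
#   unfold(torch.arange(7), 3, 2, 1) -> [[0, 1, 2], [2, 3, 4], [4, 5, 6]]
#   unfold(torch.arange(7), 3, 2, 2) -> [[0, 2, 4], [2, 4, 6]]
# `step` is the offset between the starts of consecutive windows and
# `dilation` is the offset between elements within a window.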


class VideoClips(object):
"""
Given a list of video files, computes all consecutive subvideos of size
`clip_length_in_frames`, where the distance between each subvideo in the
same video is defined by `frames_between_clips`.

Creating this instance the first time is time-consuming, as it needs to
decode all the videos in `video_paths`. It is recommended that you
cache the results after instantiation of the class.

Recreating the clips for different clip lengths is fast, and can be done
with the `compute_clips` method.
"""
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1):
self.video_paths = video_paths
self._compute_frame_pts()
self.compute_clips(clip_length_in_frames, frames_between_clips)

def _compute_frame_pts(self):
self.video_pts = []
        # TODO: maybe parallelize this
for video_file in self.video_paths:
clips = read_video_timestamps(video_file)
self.video_pts.append(torch.as_tensor(clips))

def compute_clips(self, num_frames, step, dilation=1):
"""
Compute all consecutive sequences of clips from video_pts.

Arguments:
num_frames (int): number of frames for the clip
step (int): distance between two clips
dilation (int): distance between two consecutive frames
in a clip
"""
self.num_frames = num_frames
self.step = step
self.dilation = dilation
self.clips = []
for video_pts in self.video_pts:
clips = unfold(video_pts, num_frames, step, dilation)
self.clips.append(clips)
        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

def __len__(self):
return self.num_clips()

def num_videos(self):
return len(self.video_paths)

def num_clips(self):
"""
Number of subclips that are available in the video list.
"""
return self.cumulative_sizes[-1]

def get_clip_location(self, idx):
"""
Converts a flattened representation of the indices into a video_idx, clip_idx
representation.
"""
video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if video_idx == 0:
clip_idx = idx
else:
clip_idx = idx - self.cumulative_sizes[video_idx - 1]
return video_idx, clip_idx

def get_clip(self, idx):
"""
Gets a subclip from a list of videos.

Arguments:
idx (int): index of the subclip. Must be between 0 and num_clips().

Returns:
video (Tensor)
audio (Tensor)
info (Dict)
video_idx (int): index of the video in `video_paths`
"""
video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]
video, audio, info = read_video(video_path, clip_pts[0].item(), clip_pts[-1].item())
video = video[::self.dilation]
# TODO change video_fps in info?
assert len(video) == self.num_frames
return video, audio, info, video_idx
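Finally, a minimal usage sketch of the VideoClips API defined in this file (the file paths and parameter values below are illustrative only):

from torchvision.datasets.video_utils import VideoClips

# Hypothetical list of video files; building the index decodes all timestamps once,
# which is the slow step mentioned in the class docstring.
video_paths = ["clips/a.mp4", "clips/b.mp4"]
video_clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=8)
print(video_clips.num_videos(), video_clips.num_clips())

# Re-slice the same videos into different clips without re-reading timestamps.
video_clips.compute_clips(num_frames=8, step=4, dilation=2)

# One clip: a (num_frames, H, W, C) uint8 video tensor, the audio samples,
# the info dict returned by read_video, and the index of the source video.
video, audio, info, video_idx = video_clips.get_clip(0)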