Skip to content

Commit 9fd0b90

Browse files
shi69Atcold
authored andcommitted
Add optional shuffling
1 parent e8b8270 commit 9fd0b90

File tree

1 file changed

+70
-17
lines changed

1 file changed

+70
-17
lines changed

data/VideoFolder.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import torch
33
import torch.utils.data as data
44

5+
from random import shuffle as list_shuffle # for shuffling list
56
from math import ceil
67
from os import listdir
7-
from os.path import isdir, join
8+
from os.path import isdir, join, isfile
89
from itertools import islice
910
from numpy.core.multiarray import concatenate, ndarray
1011
from skvideo.io import FFmpegReader, ffprobe
@@ -68,7 +69,7 @@ def __call__(self, batch: iter) -> torch.Tensor or list(torch.Tensor):
6869

6970

7071
class VideoFolder(data.Dataset):
71-
def __init__(self, root, transform=None, target_transform=None, video_index=False):
72+
def __init__(self, root, transform=None, target_transform=None, video_index=False, shuffle=None):
7273
"""
7374
Initialise a data.Dataset object for concurrent frame fetching from videos in a directory of folders of videos
7475
@@ -79,12 +80,16 @@ def __init__(self, root, transform=None, target_transform=None, video_index=Fals
7980
:param target_transform: label transformation / mapping
8081
:type target_transform: object
8182
:param video_index: if True, the label will be the video index instead of target class
82-
:type bool
83+
:type video_index: bool
84+
:param shuffle: None, 'init' or 'always'
85+
:type shuffle: str
8386
"""
8487
classes, class_to_idx = self._find_classes(root)
85-
videos, frames, frames_per_video = self._make_data_set(root, classes, class_to_idx)
88+
video_paths = self._find_videos(root, classes)
89+
videos, frames, frames_per_video = self._make_data_set(root, video_paths, class_to_idx, shuffle == 'init')
8690

8791
self.root = root
92+
self.video_paths = video_paths
8893
self.videos = videos
8994
self.opened_videos = [[] for _ in videos]
9095
self.frames = frames
@@ -94,8 +99,14 @@ def __init__(self, root, transform=None, target_transform=None, video_index=Fals
9499
self.transform = transform
95100
self.target_transform = target_transform
96101
self.alternative_target = video_index
102+
self.shuffle = shuffle
97103

98104
def __getitem__(self, frame_idx):
105+
if frame_idx == 0:
106+
self.free()
107+
if self.shuffle == 'always':
108+
self._shuffle()
109+
99110
frame_idx %= self.frames # wrap around indexing, if asking too much
100111
video_idx = bisect(self.videos, ((frame_idx,),)) # video to which frame_idx belongs
101112
(last, first), (path, target) = self.videos[video_idx] # get video metadata
@@ -138,11 +149,47 @@ def free(self):
138149
"""
139150
Frees all video files' pointers
140151
"""
152+
print('Resetting data set internal state')
141153
for video in self.opened_videos: # for every opened video
142154
for _ in range(len(video)): # for as many times as pointers
143155
opened_video = video.pop() # pop an item
144156
opened_video[2]._close() # close the file
145157

158+
def _shuffle(self):
159+
"""
160+
Shuffles the video list
161+
by regenerating the sequence to sample sequentially
162+
"""
163+
def _is_video_file(filename_):
164+
return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS)
165+
166+
root = self.root
167+
video_paths = self.video_paths
168+
class_to_idx = self.class_to_idx
169+
list_shuffle(video_paths) # shuffle
170+
171+
videos = list()
172+
frames_per_video = list()
173+
frames_counter = 0
174+
for filename in tqdm(video_paths, ncols=80):
175+
class_ = filename.split('/')[0]
176+
data_path = join(root, filename)
177+
if _is_video_file(data_path):
178+
video_meta = ffprobe(data_path)
179+
start_idx = frames_counter
180+
frames = int(video_meta['video'].get('@nb_frames'))
181+
frames_per_video.append(frames)
182+
frames_counter += frames
183+
item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_]))
184+
videos.append(item)
185+
186+
sleep(0.5) # allows for progress bar completion
187+
# update the attributes with the altered sequence
188+
self.video_paths = video_paths
189+
self.videos = videos
190+
self.frames = frames_counter
191+
self.frames_per_video = frames_per_video
192+
146193
@staticmethod
147194
def _find_classes(data_path):
148195
classes = [d for d in listdir(data_path) if isdir(join(data_path, d))]
@@ -151,25 +198,31 @@ def _find_classes(data_path):
151198
return classes, class_to_idx
152199

153200
@staticmethod
154-
def _make_data_set(data_path, classes, class_to_idx):
201+
def _find_videos(root, classes):
202+
return [join(c, d) for c in classes for d in listdir(join(root, c))]
203+
204+
@staticmethod
205+
def _make_data_set(root, video_paths, class_to_idx, init_shuffle):
155206
def _is_video_file(filename_):
156207
return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS)
157208

209+
if init_shuffle:
210+
list_shuffle(video_paths) # shuffle
211+
158212
videos = list()
159213
frames_per_video = list()
160214
frames_counter = 0
161-
for class_ in tqdm(classes, ncols=80):
162-
class_path = join(data_path, class_)
163-
for filename in listdir(class_path):
164-
if _is_video_file(filename):
165-
video_path = join(class_path, filename)
166-
video_meta = ffprobe(video_path)
167-
start_idx = frames_counter
168-
frames = int(video_meta['video'].get('@nb_frames'))
169-
frames_per_video.append(frames)
170-
frames_counter += frames
171-
item = ((frames_counter - 1, start_idx), (join(class_, filename), class_to_idx[class_]))
172-
videos.append(item)
215+
for filename in tqdm(video_paths, ncols=80):
216+
class_ = filename.split('/')[0]
217+
data_path = join(root, filename)
218+
if _is_video_file(data_path):
219+
video_meta = ffprobe(data_path)
220+
start_idx = frames_counter
221+
frames = int(video_meta['video'].get('@nb_frames'))
222+
frames_per_video.append(frames)
223+
frames_counter += frames
224+
item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_]))
225+
videos.append(item)
173226

174227
sleep(0.5) # allows for progress bar completion
175228
return videos, frames_counter, frames_per_video

0 commit comments

Comments
 (0)