2
2
import torch
3
3
import torch .utils .data as data
4
4
5
+ from random import shuffle as list_shuffle # for shuffling list
5
6
from math import ceil
6
7
from os import listdir
7
- from os .path import isdir , join
8
+ from os .path import isdir , join , isfile
8
9
from itertools import islice
9
10
from numpy .core .multiarray import concatenate , ndarray
10
11
from skvideo .io import FFmpegReader , ffprobe
@@ -68,7 +69,7 @@ def __call__(self, batch: iter) -> torch.Tensor or list(torch.Tensor):
68
69
69
70
70
71
class VideoFolder (data .Dataset ):
71
- def __init__ (self , root , transform = None , target_transform = None , video_index = False ):
72
+ def __init__ (self , root , transform = None , target_transform = None , video_index = False , shuffle = None ):
72
73
"""
73
74
Initialise a data.Dataset object for concurrent frame fetching from videos in a directory of folders of videos
74
75
@@ -79,12 +80,16 @@ def __init__(self, root, transform=None, target_transform=None, video_index=Fals
79
80
:param target_transform: label transformation / mapping
80
81
:type target_transform: object
81
82
:param video_index: if True, the label will be the video index instead of target class
82
- :type bool
83
+ :type video_index: bool
84
+ :param shuffle: None, 'init' or 'always'
85
+ :type shuffle: str
83
86
"""
84
87
classes , class_to_idx = self ._find_classes (root )
85
- videos , frames , frames_per_video = self ._make_data_set (root , classes , class_to_idx )
88
+ video_paths = self ._find_videos (root , classes )
89
+ videos , frames , frames_per_video = self ._make_data_set (root , video_paths , class_to_idx , shuffle == 'init' )
86
90
87
91
self .root = root
92
+ self .video_paths = video_paths
88
93
self .videos = videos
89
94
self .opened_videos = [[] for _ in videos ]
90
95
self .frames = frames
@@ -94,8 +99,14 @@ def __init__(self, root, transform=None, target_transform=None, video_index=Fals
94
99
self .transform = transform
95
100
self .target_transform = target_transform
96
101
self .alternative_target = video_index
102
+ self .shuffle = shuffle
97
103
98
104
def __getitem__ (self , frame_idx ):
105
+ if frame_idx == 0 :
106
+ self .free ()
107
+ if self .shuffle == 'always' :
108
+ self ._shuffle ()
109
+
99
110
frame_idx %= self .frames # wrap around indexing, if asking too much
100
111
video_idx = bisect (self .videos , ((frame_idx ,),)) # video to which frame_idx belongs
101
112
(last , first ), (path , target ) = self .videos [video_idx ] # get video metadata
@@ -138,11 +149,47 @@ def free(self):
138
149
"""
139
150
Frees all video files' pointers
140
151
"""
152
+ print ('Resetting data set internal state' )
141
153
for video in self .opened_videos : # for every opened video
142
154
for _ in range (len (video )): # for as many times as pointers
143
155
opened_video = video .pop () # pop an item
144
156
opened_video [2 ]._close () # close the file
145
157
158
+ def _shuffle (self ):
159
+ """
160
+ Shuffles the video list
161
+ by regenerating the sequence to sample sequentially
162
+ """
163
+ def _is_video_file (filename_ ):
164
+ return any (filename_ .endswith (extension ) for extension in VIDEO_EXTENSIONS )
165
+
166
+ root = self .root
167
+ video_paths = self .video_paths
168
+ class_to_idx = self .class_to_idx
169
+ list_shuffle (video_paths ) # shuffle
170
+
171
+ videos = list ()
172
+ frames_per_video = list ()
173
+ frames_counter = 0
174
+ for filename in tqdm (video_paths , ncols = 80 ):
175
+ class_ = filename .split ('/' )[0 ]
176
+ data_path = join (root , filename )
177
+ if _is_video_file (data_path ):
178
+ video_meta = ffprobe (data_path )
179
+ start_idx = frames_counter
180
+ frames = int (video_meta ['video' ].get ('@nb_frames' ))
181
+ frames_per_video .append (frames )
182
+ frames_counter += frames
183
+ item = ((frames_counter - 1 , start_idx ), (filename , class_to_idx [class_ ]))
184
+ videos .append (item )
185
+
186
+ sleep (0.5 ) # allows for progress bar completion
187
+ # update the attributes with the altered sequence
188
+ self .video_paths = video_paths
189
+ self .videos = videos
190
+ self .frames = frames_counter
191
+ self .frames_per_video = frames_per_video
192
+
146
193
@staticmethod
147
194
def _find_classes (data_path ):
148
195
classes = [d for d in listdir (data_path ) if isdir (join (data_path , d ))]
@@ -151,25 +198,31 @@ def _find_classes(data_path):
151
198
return classes , class_to_idx
152
199
153
200
@staticmethod
154
- def _make_data_set (data_path , classes , class_to_idx ):
201
+ def _find_videos (root , classes ):
202
+ return [join (c , d ) for c in classes for d in listdir (join (root , c ))]
203
+
204
+ @staticmethod
205
+ def _make_data_set (root , video_paths , class_to_idx , init_shuffle ):
155
206
def _is_video_file (filename_ ):
156
207
return any (filename_ .endswith (extension ) for extension in VIDEO_EXTENSIONS )
157
208
209
+ if init_shuffle :
210
+ list_shuffle (video_paths ) # shuffle
211
+
158
212
videos = list ()
159
213
frames_per_video = list ()
160
214
frames_counter = 0
161
- for class_ in tqdm (classes , ncols = 80 ):
162
- class_path = join (data_path , class_ )
163
- for filename in listdir (class_path ):
164
- if _is_video_file (filename ):
165
- video_path = join (class_path , filename )
166
- video_meta = ffprobe (video_path )
167
- start_idx = frames_counter
168
- frames = int (video_meta ['video' ].get ('@nb_frames' ))
169
- frames_per_video .append (frames )
170
- frames_counter += frames
171
- item = ((frames_counter - 1 , start_idx ), (join (class_ , filename ), class_to_idx [class_ ]))
172
- videos .append (item )
215
+ for filename in tqdm (video_paths , ncols = 80 ):
216
+ class_ = filename .split ('/' )[0 ]
217
+ data_path = join (root , filename )
218
+ if _is_video_file (data_path ):
219
+ video_meta = ffprobe (data_path )
220
+ start_idx = frames_counter
221
+ frames = int (video_meta ['video' ].get ('@nb_frames' ))
222
+ frames_per_video .append (frames )
223
+ frames_counter += frames
224
+ item = ((frames_counter - 1 , start_idx ), (filename , class_to_idx [class_ ]))
225
+ videos .append (item )
173
226
174
227
sleep (0.5 ) # allows for progress bar completion
175
228
return videos , frames_counter , frames_per_video
0 commit comments