-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Fast seek implementation #3179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast seek implementation #3179
Changes from 9 commits
97f2966
9c4e030
2325598
e6543c3
783cb74
8af53d4
14c22c7
3760271
ef6e129
0adb856
87e32e2
e337654
74a4902
6408e8c
7e44d68
5a0a93a
130dde4
7877777
f5298f9
aaca9bd
0bd0c70
23fdbf0
283a395
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -244,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None): | |
video_frames = torch.empty(0) | ||
frames = [] | ||
video_pts = [] | ||
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object): | ||
if frame['pts'] < s: | ||
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object): | ||
if frame["pts"] < s: | ||
continue | ||
frames.append(frame['data']) | ||
video_pts.append(frame['pts']) | ||
frames.append(frame["data"]) | ||
video_pts.append(frame["pts"]) | ||
if len(frames) > 0: | ||
video_frames = torch.stack(frames, 0) | ||
|
||
|
@@ -257,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None): | |
audio_frames = torch.empty(0) | ||
frames = [] | ||
audio_pts = [] | ||
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object): | ||
if frame['pts'] < s: | ||
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object): | ||
if frame["pts"] < s: | ||
continue | ||
frames.append(frame['data']) | ||
audio_pts.append(frame['pts']) | ||
frames.append(frame["data"]) | ||
audio_pts.append(frame["pts"]) | ||
if len(frames) > 0: | ||
audio_frames = torch.stack(frames, 0) | ||
|
||
|
@@ -294,7 +294,7 @@ def test_read_video_tensor(self): | |
reader = VideoReader(full_path, "video") | ||
frames = [] | ||
for frame in reader: | ||
frames.append(frame['data']) | ||
frames.append(frame["data"]) | ||
new_api = torch.stack(frames, 0) | ||
self.assertEqual(tv_result.size(), new_api.size()) | ||
|
||
|
@@ -402,6 +402,45 @@ def test_video_reading_fn(self): | |
).item() | ||
self.assertEqual(is_same, True) | ||
|
||
@unittest.skipIf(av is None, "PyAV unavailable") | ||
def test_keyframe_reading(self): | ||
for test_video, config in test_videos.items(): | ||
full_path = os.path.join(VIDEO_DIR, test_video) | ||
|
||
av_reader = av.open(full_path) | ||
# reduce streams to only keyframes | ||
av_stream = av_reader.streams.video[0] | ||
av_stream.codec_context.skip_frame = "NONKEY" | ||
|
||
av_keyframes = [] | ||
vr_keyframes = [] | ||
if av_reader.streams.video: | ||
|
||
# get all keyframes using pyav. Then, seek randomly into video reader | ||
|
||
# and assert that all the returned values are in AV_KEYFRAMES | ||
|
||
for av_frame in av_reader.decode(av_stream): | ||
av_keyframes.append(float(av_frame.pts * av_frame.time_base)) | ||
|
||
if len(av_keyframes) > 1: | ||
video_reader = VideoReader(full_path, "video") | ||
for i in range(1, len(av_keyframes)): | ||
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2 | ||
data = next(video_reader.seek(seek_val, True)) | ||
vr_keyframes.append(data["pts"]) | ||
|
||
data = next(video_reader.seek(config.duration, True)) | ||
vr_keyframes.append(data["pts"]) | ||
|
||
self.assertTrue(len(av_keyframes) == len(vr_keyframes)) | ||
# NOTE: this video gets different keyframe with different | ||
# loaders (0.333 pyav, 0.666 for us) | ||
|
||
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi": | ||
for i in range(len(av_keyframes)): | ||
self.assertAlmostEqual( | ||
av_keyframes[i], vr_keyframes[i], delta=0.001 | ||
) | ||
|
||
|
||
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
Uh oh!
There was an error while loading. Please reload this page.