Merged

Commits (23)
97f2966
modify processPacket to support fast seek
bjuncek Nov 17, 2020
9c4e030
add fastSeek to ProcessPacket decoder definition
bjuncek Nov 17, 2020
2325598
add fastseek flag to DecoderParametersStruct
bjuncek Nov 17, 2020
e6543c3
add fastseek flag to the process packet call
bjuncek Nov 17, 2020
783cb74
no default params in C++ implementation
bjuncek Nov 17, 2020
8af53d4
enable flag in C++ implementation
bjuncek Nov 17, 2020
14c22c7
make order of parameters more normal
bjuncek Nov 17, 2020
3760271
register new seek with python api
bjuncek Nov 17, 2020
ef6e129
[somewhat broken] test suite for keyframes using pyav
bjuncek Dec 1, 2020
0adb856
revert " changes
bjuncek Sep 2, 2021
87e32e2
add type annotations to init
bjuncek Sep 2, 2021
e337654
Merge remote-tracking branch 'upstream/main' into bkorbar/videoapi/se…
bjuncek Sep 23, 2021
74a4902
Adding tests
bjuncek Sep 23, 2021
6408e8c
linter
bjuncek Sep 23, 2021
7e44d68
Flake doesn't show up :|
bjuncek Sep 24, 2021
5a0a93a
Change from unitest to pytest syntax
bjuncek Sep 24, 2021
130dde4
Merge branch 'main' into bkorbar/videoapi/seek_rebase
bjuncek Sep 24, 2021
7877777
Merge branch 'main' into bkorbar/videoapi/seek_rebase
prabhat00155 Sep 27, 2021
f5298f9
add return type
bjuncek Oct 13, 2021
aaca9bd
Resolved merge conflicts
prabhat00155 Nov 3, 2021
0bd0c70
Merge branch 'main' into bkorbar/videoapi/seek_rebase
prabhat00155 Nov 3, 2021
23fdbf0
Merge branch 'main' into bkorbar/videoapi/seek_rebase
prabhat00155 Nov 3, 2021
283a395
Merge branch 'main' into bkorbar/videoapi/seek_rebase
prabhat00155 Nov 3, 2021
57 changes: 48 additions & 9 deletions test/test_video.py
@@ -244,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
video_frames = torch.empty(0)
frames = []
video_pts = []
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
if frame['pts'] < s:
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object):
if frame["pts"] < s:
continue
frames.append(frame['data'])
video_pts.append(frame['pts'])
frames.append(frame["data"])
video_pts.append(frame["pts"])
if len(frames) > 0:
video_frames = torch.stack(frames, 0)

@@ -257,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
audio_frames = torch.empty(0)
frames = []
audio_pts = []
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
if frame['pts'] < s:
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object):
if frame["pts"] < s:
continue
frames.append(frame['data'])
audio_pts.append(frame['pts'])
frames.append(frame["data"])
audio_pts.append(frame["pts"])
if len(frames) > 0:
audio_frames = torch.stack(frames, 0)

@@ -294,7 +294,7 @@ def test_read_video_tensor(self):
reader = VideoReader(full_path, "video")
frames = []
for frame in reader:
frames.append(frame['data'])
frames.append(frame["data"])
new_api = torch.stack(frames, 0)
self.assertEqual(tv_result.size(), new_api.size())

@@ -402,6 +402,45 @@ def test_video_reading_fn(self):
).item()
self.assertEqual(is_same, True)

@unittest.skipIf(av is None, "PyAV unavailable")
def test_keyframe_reading(self):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)

av_reader = av.open(full_path)
# reduce streams to only keyframes
av_stream = av_reader.streams.video[0]
av_stream.codec_context.skip_frame = "NONKEY"

av_keyframes = []
vr_keyframes = []
if av_reader.streams.video:

# get all keyframes using pyav. Then, seek randomly into video reader
Reviewer comment (Contributor): This comment should go before the if statement as it describes the entire chunk of code from line 417 onwards, not just the if block.

# and assert that all the returned values are in AV_KEYFRAMES

for av_frame in av_reader.decode(av_stream):
av_keyframes.append(float(av_frame.pts * av_frame.time_base))

if len(av_keyframes) > 1:
video_reader = VideoReader(full_path, "video")
for i in range(1, len(av_keyframes)):
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
data = next(video_reader.seek(seek_val, True))
vr_keyframes.append(data["pts"])

data = next(video_reader.seek(config.duration, True))
vr_keyframes.append(data["pts"])

self.assertTrue(len(av_keyframes) == len(vr_keyframes))
# NOTE: this video gets different keyframe with different
# loaders (0.333 pyav, 0.666 for us)
Reviewer comment (Contributor): Do we know the reason for this?

if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
for i in range(len(av_keyframes)):
self.assertAlmostEqual(
av_keyframes[i], vr_keyframes[i], delta=0.001
)


if __name__ == "__main__":
unittest.main()
19 changes: 14 additions & 5 deletions torchvision/csrc/io/decoder/decoder.cpp
@@ -527,9 +527,9 @@ int Decoder::getFrame(size_t workingTimeInMs) {
bool gotFrame = false;
bool hasMsg = false;
// packet either got consumed completely or not at all
if ((result = processPacket(stream, &avPacket, &gotFrame, &hasMsg)) < 0) {
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " processPacket failed with code=" << result;
if ((result = processPacket(
stream, &avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
LOG(ERROR) << "processPacket failed with code: " << result;
break;
}

@@ -606,7 +606,8 @@ int Decoder::processPacket(
Stream* stream,
AVPacket* packet,
bool* gotFrame,
bool* hasMsg) {
bool* hasMsg,
bool fastSeek) {
// decode package
int result;
DecoderOutputMessage msg;
@@ -619,7 +620,15 @@
bool endInRange =
params_.endOffset <= 0 || msg.header.pts <= params_.endOffset;
inRange_.set(stream->getIndex(), endInRange);
if (endInRange && msg.header.pts >= params_.startOffset) {
// if fastseek is enabled, we're returning the first
// frame that we decode after (potential) seek.
// By default, we perform accurate seek to the closest
// following frame
bool startCondition = true;
if (!fastSeek) {
startCondition = msg.header.pts >= params_.startOffset;
}
if (endInRange && startCondition) {
*hasMsg = true;
push(std::move(msg));
}
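For intuition, the branch added above can be mirrored at the Python level with PyAV (which the new test already depends on). The helper below is only an illustrative sketch, not part of this PR: with fast seek, the first frame decoded after the container-level seek is returned, i.e. the keyframe at or before the target; with the default accurate seek, decoding continues until a frame with pts at or after the target is found.

```python
import av  # PyAV


def seek_pts(path: str, t: float, fast: bool = False) -> float:
    """Timestamp (seconds) of the frame a seek to `t` would yield.

    fast=True  -> first frame decoded after the seek, i.e. the keyframe at or
                  before `t` (what the fastSeek branch selects).
    fast=False -> keep decoding until pts >= t (accurate seek, the default).
    """
    container = av.open(path)
    stream = container.streams.video[0]
    # container.seek() lands on a keyframe at or before the requested offset
    container.seek(int(t / stream.time_base), stream=stream, backward=True)
    for frame in container.decode(stream):
        pts = float(frame.pts * frame.time_base)
        if fast or pts >= t:
            return pts
    raise EOFError("no frame at or after the requested timestamp")
```

For a stream with keyframes at 0 s and 2 s, `seek_pts(path, 1.0, fast=True)` would report 0.0, while `seek_pts(path, 1.0)` would report the first pts at or above 1.0 — the same contrast that `test_keyframe_reading` above checks through the public API.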
3 changes: 2 additions & 1 deletion torchvision/csrc/io/decoder/decoder.h
@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder {
Stream* stream,
AVPacket* packet,
bool* gotFrame,
bool* hasMsg);
bool* hasMsg,
bool fastSeek = false);
void flushStreams();
void cleanUp();

2 changes: 2 additions & 0 deletions torchvision/csrc/io/decoder/defs.h
@@ -190,6 +190,8 @@ struct DecoderParameters {
bool listen{false};
// don't copy frame body, only header
bool headerOnly{false};
// enable fast seek (seek only to keyframes)
bool fastSeek{false};
// interrupt init method on timeout
bool preventStaleness{true};
// seek tolerated accuracy (us)
7 changes: 6 additions & 1 deletion torchvision/csrc/io/video/Video.cpp
@@ -98,13 +98,15 @@ void Video::_getDecoderParams(
int64_t getPtsOnly,
std::string stream,
long stream_id = -1,
bool fastSeek = true,
bool all_streams = false,
double seekFrameMarginUs = 10) {
int64_t videoStartUs = int64_t(videoStartS * 1e6);

params.timeoutMs = decoderTimeoutMs;
params.startOffset = videoStartUs;
params.seekAccuracy = seekFrameMarginUs;
params.fastSeek = fastSeek;
params.headerOnly = false;

params.preventStaleness = false; // not sure what this is about
@@ -161,6 +163,7 @@ Video::Video(std::string videoPath, std::string stream) {
0, // headerOnly
get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2
false, // fastseek: we're using the default param here
true // read all streams
);

@@ -232,6 +235,7 @@ bool Video::setCurrentStream(std::string stream = "video") {
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
false, // fastseek param set to false by default (changed in seek)
false // read all streams
);

@@ -248,14 +252,15 @@ c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
return streamsMetadata;
}

void Video::Seek(double ts) {
void Video::Seek(double ts, bool fastSeek = false) {
// initialize the class variables used for seeking and return
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
fastSeek, // fastseek
false // read all streams
);

3 changes: 2 additions & 1 deletion torchvision/csrc/io/video/Video.h
@@ -27,7 +27,7 @@ struct Video : torch::CustomClassHolder {
std::tuple<std::string, int64_t> getCurrentStream() const;
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
getStreamMetadata() const;
void Seek(double ts);
void Seek(double ts, bool fastSeek);
bool setCurrentStream(std::string stream);
std::tuple<torch::Tensor, double> Next();

@@ -45,6 +45,7 @@
int64_t getPtsOnly,
std::string stream,
long stream_id,
bool fastSeek,
bool all_streams,
double seekFrameMarginUs); // this needs to be improved

5 changes: 3 additions & 2 deletions torchvision/io/__init__.py
@@ -123,19 +123,20 @@ def __next__(self):
def __iter__(self):
return self

def seek(self, time_s: float):
def seek(self, time_s: float, keyframes_only=False):
"""Seek within current stream.

Args:
time_s (float): seek time in seconds
keyframes_only (bool): if True, seek only to keyframes

.. note::
Current implementation is the so-called precise seek. This
means following seek, call to :mod:`next()` will return the
frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``.
"""
self._c.seek(time_s)
self._c.seek(time_s, keyframes_only)
return self

def get_metadata(self):
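At the Python API level, a minimal usage sketch of the new argument (the file name and timestamps below are placeholders, not taken from this PR):

```python
import torchvision

reader = torchvision.io.VideoReader("video.mp4", "video")

# Precise seek (default): next() yields the frame with the exact timestamp,
# or the first frame whose pts is larger than 2.0.
precise = next(reader.seek(2.0))

# Keyframe seek: next() yields the first frame decoded after the seek,
# i.e. the nearest keyframe at or before 2.0.
fast = next(reader.seek(2.0, keyframes_only=True))

print(precise["pts"], fast["pts"])  # fast["pts"] <= precise["pts"]
```

Because `seek` returns `self`, the call chains directly into `next()`, exactly as the updated test does.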