-
Notifications
You must be signed in to change notification settings - Fork 46
Add stream_index seek mode, read frame index and update metadata #764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
88cacb9
4dfc581
ed3fdec
6030f9e
1361c0d
5bb23c2
6573943
3341dd8
e914160
394ac70
f13433c
8971bd6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -319,6 +319,42 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { | |||||||||||||||||||||||||||||
scannedAllStreams_ = true; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex( | ||||||||||||||||||||||||||||||
int streamIndex, | ||||||||||||||||||||||||||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> customFrameMappings) { | ||||||||||||||||||||||||||||||
auto& all_frames = std::get<0>(customFrameMappings); | ||||||||||||||||||||||||||||||
auto& is_key_frame = std::get<1>(customFrameMappings); | ||||||||||||||||||||||||||||||
auto& duration = std::get<2>(customFrameMappings); | ||||||||||||||||||||||||||||||
TORCH_CHECK( | ||||||||||||||||||||||||||||||
all_frames.size(0) == is_key_frame.size(0) && | ||||||||||||||||||||||||||||||
is_key_frame.size(0) == duration.size(0), | ||||||||||||||||||||||||||||||
"all_frames, is_key_frame, and duration from custom_frame_mappings were not same size."); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>(); | ||||||||||||||||||||||||||||||
streamMetadata.endStreamPtsFromContent = | ||||||||||||||||||||||||||||||
all_frames[-1].item<int64_t>() + duration[-1].item<int64_t>(); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
auto avStream = formatContext_->streams[streamIndex]; | ||||||||||||||||||||||||||||||
streamMetadata.beginStreamPtsSecondsFromContent = | ||||||||||||||||||||||||||||||
*streamMetadata.beginStreamPtsFromContent * av_q2d(avStream->time_base); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
streamMetadata.endStreamPtsSecondsFromContent = | ||||||||||||||||||||||||||||||
*streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably use the |
||||||||||||||||||||||||||||||
streamMetadata.numFramesFromContent = all_frames.size(0); | ||||||||||||||||||||||||||||||
for (int64_t i = 0; i < all_frames.size(0); ++i) { | ||||||||||||||||||||||||||||||
// FrameInfo struct utilizes PTS | ||||||||||||||||||||||||||||||
FrameInfo frameInfo = {all_frames[i].item<int64_t>()}; | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not very familiar with that kind of initialization, but it's not immediately obvious that it's setting the |
||||||||||||||||||||||||||||||
frameInfo.isKeyFrame = (is_key_frame[i].item<bool>() == true); | ||||||||||||||||||||||||||||||
frameInfo.nextPts = (i + 1 < all_frames.size(0)) | ||||||||||||||||||||||||||||||
? all_frames[i + 1].item<int64_t>() | ||||||||||||||||||||||||||||||
: INT64_MAX; | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also set the |
||||||||||||||||||||||||||||||
streamInfos_[streamIndex].allFrames.push_back(frameInfo); | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also make sure to update the |
||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realize reading this that there is an additional design decision we'll have to make: whether we expect the index to be already sorted, or not. In the existing torchcodec/src/torchcodec/_core/SingleStreamDecoder.cpp Lines 285 to 298 in 86e952f
I suspect that frame mappings coming from
I think the simpler, safer choice is to sort in C++ and rely on the same sorting logic that we have in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed on sorting on the C++ side; we can do it quickly on the C++ side, and it's much nicer to users. We should extract the existing logic out to a utility function called only in this cpp file, and call it in both places. I think we sort in the scan function because I think it's possible for the actual packets to be not in PTS order. |
||||||||||||||||||||||||||||||
ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { | ||||||||||||||||||||||||||||||
return containerMetadata_; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
@@ -431,7 +467,9 @@ void SingleStreamDecoder::addStream( | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
void SingleStreamDecoder::addVideoStream( | ||||||||||||||||||||||||||||||
int streamIndex, | ||||||||||||||||||||||||||||||
const VideoStreamOptions& videoStreamOptions) { | ||||||||||||||||||||||||||||||
const VideoStreamOptions& videoStreamOptions, | ||||||||||||||||||||||||||||||
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>> | ||||||||||||||||||||||||||||||
customFrameMappings) { | ||||||||||||||||||||||||||||||
addStream( | ||||||||||||||||||||||||||||||
streamIndex, | ||||||||||||||||||||||||||||||
AVMEDIA_TYPE_VIDEO, | ||||||||||||||||||||||||||||||
|
@@ -456,6 +494,14 @@ void SingleStreamDecoder::addVideoStream( | |||||||||||||||||||||||||||||
streamMetadata.height = streamInfo.codecContext->height; | ||||||||||||||||||||||||||||||
streamMetadata.sampleAspectRatio = | ||||||||||||||||||||||||||||||
streamInfo.codecContext->sample_aspect_ratio; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if (seekMode_ == SeekMode::custom_frame_mappings) { | ||||||||||||||||||||||||||||||
TORCH_CHECK( | ||||||||||||||||||||||||||||||
customFrameMappings.has_value(), | ||||||||||||||||||||||||||||||
"Please provide frame mappings when using custom_frame_mappings seek mode."); | ||||||||||||||||||||||||||||||
readCustomFrameMappingsUpdateMetadataAndIndex( | ||||||||||||||||||||||||||||||
streamIndex, customFrameMappings.value()); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
void SingleStreamDecoder::addAudioStream( | ||||||||||||||||||||||||||||||
|
@@ -1407,6 +1453,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex( | |||||||||||||||||||||||||||||
int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { | ||||||||||||||||||||||||||||||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: { | ||||||||||||||||||||||||||||||
auto frame = std::lower_bound( | ||||||||||||||||||||||||||||||
streamInfo.allFrames.begin(), | ||||||||||||||||||||||||||||||
|
@@ -1434,6 +1481,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { | |||||||||||||||||||||||||||||
int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { | ||||||||||||||||||||||||||||||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: { | ||||||||||||||||||||||||||||||
auto frame = std::upper_bound( | ||||||||||||||||||||||||||||||
streamInfo.allFrames.begin(), | ||||||||||||||||||||||||||||||
|
@@ -1461,6 +1509,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { | |||||||||||||||||||||||||||||
int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { | ||||||||||||||||||||||||||||||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: | ||||||||||||||||||||||||||||||
return streamInfo.allFrames[frameIndex].pts; | ||||||||||||||||||||||||||||||
case SeekMode::approximate: { | ||||||||||||||||||||||||||||||
|
@@ -1485,6 +1534,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { | |||||||||||||||||||||||||||||
std::optional<int64_t> SingleStreamDecoder::getNumFrames( | ||||||||||||||||||||||||||||||
const StreamMetadata& streamMetadata) { | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: | ||||||||||||||||||||||||||||||
return streamMetadata.numFramesFromContent.value(); | ||||||||||||||||||||||||||||||
case SeekMode::approximate: { | ||||||||||||||||||||||||||||||
|
@@ -1498,6 +1548,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames( | |||||||||||||||||||||||||||||
double SingleStreamDecoder::getMinSeconds( | ||||||||||||||||||||||||||||||
const StreamMetadata& streamMetadata) { | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: | ||||||||||||||||||||||||||||||
return streamMetadata.beginStreamPtsSecondsFromContent.value(); | ||||||||||||||||||||||||||||||
case SeekMode::approximate: | ||||||||||||||||||||||||||||||
|
@@ -1510,6 +1561,7 @@ double SingleStreamDecoder::getMinSeconds( | |||||||||||||||||||||||||||||
std::optional<double> SingleStreamDecoder::getMaxSeconds( | ||||||||||||||||||||||||||||||
const StreamMetadata& streamMetadata) { | ||||||||||||||||||||||||||||||
switch (seekMode_) { | ||||||||||||||||||||||||||||||
case SeekMode::custom_frame_mappings: | ||||||||||||||||||||||||||||||
case SeekMode::exact: | ||||||||||||||||||||||||||||||
return streamMetadata.endStreamPtsSecondsFromContent.value(); | ||||||||||||||||||||||||||||||
case SeekMode::approximate: { | ||||||||||||||||||||||||||||||
|
@@ -1645,6 +1697,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { | |||||||||||||||||||||||||||||
return SingleStreamDecoder::SeekMode::exact; | ||||||||||||||||||||||||||||||
} else if (seekMode == "approximate") { | ||||||||||||||||||||||||||||||
return SingleStreamDecoder::SeekMode::approximate; | ||||||||||||||||||||||||||||||
} else if (seekMode == "custom_frame_mappings") { | ||||||||||||||||||||||||||||||
return SingleStreamDecoder::SeekMode::custom_frame_mappings; | ||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||
TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -29,7 +29,7 @@ class SingleStreamDecoder { | |||||||
// CONSTRUCTION API | ||||||||
// -------------------------------------------------------------------------- | ||||||||
|
||||||||
enum class SeekMode { exact, approximate }; | ||||||||
enum class SeekMode { exact, approximate, custom_frame_mappings }; | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know if we'll want to publicly expose a new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. Right now, the torchcodec/src/torchcodec/_core/SingleStreamDecoder.cpp Lines 183 to 185 in d444567
Since the point of this feature is to avoid a file scan while still maintaining seek accuracy, I think it makes sense for it to be a new seek mode on the C++ side. We can figure out what to do with the public API later. |
||||||||
|
||||||||
// Creates a SingleStreamDecoder from the video at videoFilePath. | ||||||||
explicit SingleStreamDecoder( | ||||||||
|
@@ -53,6 +53,13 @@ class SingleStreamDecoder { | |||||||
// the allFrames and keyFrames vectors. | ||||||||
void scanFileAndUpdateMetadataAndIndex(); | ||||||||
|
||||||||
// Reads the user provided frame index and updates each StreamInfo's index, | ||||||||
// i.e. the allFrames and keyFrames vectors, and | ||||||||
// endStreamPtsSecondsFromContent | ||||||||
void readCustomFrameMappingsUpdateMetadataAndIndex( | ||||||||
int streamIndex, | ||||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> customFrameMappings); | ||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just bumping #764 (comment) again, which may have been missed:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To elaborate on defining a new As a point of comparison, we do something similar with batched output. |
||||||||
// Returns the metadata for the container. | ||||||||
ContainerMetadata getContainerMetadata() const; | ||||||||
|
||||||||
|
@@ -66,7 +73,9 @@ class SingleStreamDecoder { | |||||||
|
||||||||
void addVideoStream( | ||||||||
int streamIndex, | ||||||||
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); | ||||||||
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), | ||||||||
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>> | ||||||||
customFrameMappings = std::nullopt); | ||||||||
void addAudioStream( | ||||||||
int streamIndex, | ||||||||
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,23 +31,26 @@ def _get_container_metadata(path, seek_mode): | |
return get_container_metadata(decoder) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"metadata_getter", | ||
( | ||
get_container_metadata_from_header, | ||
functools.partial(_get_container_metadata, seek_mode="approximate"), | ||
functools.partial(_get_container_metadata, seek_mode="exact"), | ||
), | ||
) | ||
def test_get_metadata(metadata_getter): | ||
with_scan = ( | ||
metadata_getter.keywords["seek_mode"] == "exact" | ||
if isinstance(metadata_getter, functools.partial) | ||
else False | ||
@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "custom_frame_mappings"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you may have removed |
||
def test_get_metadata(seek_mode): | ||
from torchcodec._core import add_video_stream | ||
|
||
decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode) | ||
# For custom_frame_mappings seek mode, add a video stream to update metadata | ||
custom_frame_mappings = ( | ||
NASA_VIDEO.get_custom_frame_mappings() | ||
if seek_mode == "custom_frame_mappings" | ||
else None | ||
) | ||
# Add the best video stream (index 3 for NASA_VIDEO) | ||
add_video_stream( | ||
decoder, | ||
stream_index=NASA_VIDEO.default_stream_index, | ||
custom_frame_mappings=custom_frame_mappings, | ||
) | ||
metadata = get_container_metadata(decoder) | ||
|
||
metadata = metadata_getter(NASA_VIDEO.path) | ||
# metadata = metadata_getter(NASA_VIDEO.path) | ||
with_scan = seek_mode == "exact" or seek_mode == "custom_frame_mappings" | ||
|
||
assert len(metadata.streams) == 6 | ||
assert metadata.best_video_stream_index == 3 | ||
|
@@ -82,7 +85,7 @@ def test_get_metadata(metadata_getter): | |
assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 | ||
assert best_video_stream_metadata.bit_rate == 128783 | ||
assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) | ||
assert best_video_stream_metadata.pixel_aspect_ratio is None | ||
assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since |
||
assert best_video_stream_metadata.codec == "h264" | ||
assert best_video_stream_metadata.num_frames_from_content == ( | ||
390 if with_scan else None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TIL that tensors on the C++ side also support negative indices! I'm so used to that being not allowed in C++ arrays and standard containers that I initially thought this was undefined behavior!