
Add stream_index seek mode, read frame index and update metadata #764


Open · wants to merge 12 commits into main
56 changes: 55 additions & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -319,6 +319,42 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
scannedAllStreams_ = true;
}

void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
int streamIndex,
std::tuple<at::Tensor, at::Tensor, at::Tensor> customFrameMappings) {
auto& all_frames = std::get<0>(customFrameMappings);
auto& is_key_frame = std::get<1>(customFrameMappings);
auto& duration = std::get<2>(customFrameMappings);
TORCH_CHECK(
all_frames.size(0) == is_key_frame.size(0) &&
is_key_frame.size(0) == duration.size(0),
"all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.");

auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];

streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>();
streamMetadata.endStreamPtsFromContent =
all_frames[-1].item<int64_t>() + duration[-1].item<int64_t>();
Contributor:

TIL that tensors on the C++ side also support negative indices! I'm so used to that being not allowed in C++ arrays and standard containers that I initially thought this was undefined behavior!
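
For illustration, a minimal standalone sketch of that behavior (not part of the PR; assumes only libtorch):

#include <torch/torch.h>

// ATen tensors accept Python-style negative indices in C++, unlike raw
// arrays and standard containers.
int main() {
  torch::Tensor pts = torch::tensor({0, 512, 1024});
  int64_t last = pts[-1].item<int64_t>(); // 1024, same as pts[pts.size(0) - 1]
  TORCH_CHECK(last == 1024);
  return 0;
}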


auto avStream = formatContext_->streams[streamIndex];
streamMetadata.beginStreamPtsSecondsFromContent =
*streamMetadata.beginStreamPtsFromContent * av_q2d(avStream->time_base);

streamMetadata.endStreamPtsSecondsFromContent =
*streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base);

Contributor:

We should probably use the ptsToSeconds() function here, but it looks like we're also not using it in the scanning function. And we can probably better define ptsToSeconds() to use av_q2d(). Let's address that elsewhere; created #770 to track.
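
For reference, the conversion both call sites perform amounts to the following sketch of how ptsToSeconds() could be defined in terms of av_q2d(), per #770 (the codebase's actual signature may differ):

extern "C" {
#include <libavutil/rational.h>
}

// Convert a pts expressed in stream time-base units to seconds.
// av_q2d(r) is simply (double)r.num / (double)r.den.
inline double ptsToSeconds(int64_t pts, AVRational timeBase) {
  return static_cast<double>(pts) * av_q2d(timeBase);
}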

streamMetadata.numFramesFromContent = all_frames.size(0);
for (int64_t i = 0; i < all_frames.size(0); ++i) {
// FrameInfo struct utilizes PTS
FrameInfo frameInfo = {all_frames[i].item<int64_t>()};
Member:

I'm not very familiar with that kind of initialization, but it's not immediately obvious that it's setting the frameInfo.pts field. Let's set the field explicitly instead, for clarity.
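
The suggested rewrite would look something like this (a sketch; it assumes the first FrameInfo field being set positionally above is pts):

FrameInfo frameInfo;
frameInfo.pts = all_frames[i].item<int64_t>();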

frameInfo.isKeyFrame = is_key_frame[i].item<bool>();
frameInfo.nextPts = (i + 1 < all_frames.size(0))
? all_frames[i + 1].item<int64_t>()
: INT64_MAX;
Member:

Let's also set the frameInfo's frameIndex field. I think it should be i?
EDIT: it should actually be set after we are sure the sequence is sorted, see other comment below.

streamInfos_[streamIndex].allFrames.push_back(frameInfo);
Member:

We should also make sure to update the streamInfos_[streamIndex].keyFrames index, if isKeyFrame is true!
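
Sketched, the extra bookkeeping inside the loop would be along these lines (mirroring what scanFileAndUpdateMetadataAndIndex() does for key frames):

// Key frames belong in both indices.
if (frameInfo.isKeyFrame) {
  streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
}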

}
}

Member:

I realize reading this that there is an additional design decision we'll have to make: whether we expect the index to be already sorted, or not.

In the existing scanFileAndUpdateMetadataAndIndex() function we are reading packets in order, and yet we are sorting them afterwards:

// Sort all frames by their pts.
for (auto& [streamIndex, streamInfo] : streamInfos_) {
std::sort(
streamInfo.keyFrames.begin(),
streamInfo.keyFrames.end(),
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
return frameInfo1.pts < frameInfo2.pts;
});
std::sort(
streamInfo.allFrames.begin(),
streamInfo.allFrames.end(),
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
return frameInfo1.pts < frameInfo2.pts;
});

I suspect that frame mappings coming from ffprobe won't be ordered in general. I think we have the following options:

  • expect the input mapping to be sorted - that may not be a great UX
  • sort the mapping in Python - this is kinda what this PR is doing, by sorting the mappings in the tests - but we should remove that and have that logic within the code, not the tests
  • sort the mappings in C++

I think the simpler, safer choice is to sort in C++ and rely on the same sorting logic that we have in scanFileAndUpdateMetadataAndIndex(). Curious what your thoughts are @Dan-Flores @scotts ?

Contributor:

Agreed on sorting on the C++ side; we can do it quickly there, and it's much nicer to users. We should extract the existing logic into a utility function called only in this .cpp file, and call it in both places.

I think we sort in the scan function because it's possible for the actual packets to not be in PTS order.
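
A sketch of that extraction (helper name and placement are ours; it assumes FrameInfo carries a frameIndex field, per the earlier comment, and assigns it only once the order is final):

#include <algorithm>
#include <vector>

namespace {
// File-local helper shared by scanFileAndUpdateMetadataAndIndex() and
// readCustomFrameMappingsUpdateMetadataAndIndex(): sort a stream's frame
// index by pts, then assign each frame's position.
void sortFramesByPts(std::vector<FrameInfo>& frames) {
  std::sort(
      frames.begin(),
      frames.end(),
      [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
        return frameInfo1.pts < frameInfo2.pts;
      });
  for (int64_t i = 0; i < static_cast<int64_t>(frames.size()); ++i) {
    frames[i].frameIndex = i;
  }
}
} // namespace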

ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
return containerMetadata_;
}
@@ -431,7 +467,9 @@ void SingleStreamDecoder::addStream(

void SingleStreamDecoder::addVideoStream(
int streamIndex,
const VideoStreamOptions& videoStreamOptions) {
const VideoStreamOptions& videoStreamOptions,
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
customFrameMappings) {
addStream(
streamIndex,
AVMEDIA_TYPE_VIDEO,
@@ -456,6 +494,14 @@
streamMetadata.height = streamInfo.codecContext->height;
streamMetadata.sampleAspectRatio =
streamInfo.codecContext->sample_aspect_ratio;

if (seekMode_ == SeekMode::custom_frame_mappings) {
TORCH_CHECK(
customFrameMappings.has_value(),
"Please provide frame mappings when using custom_frame_mappings seek mode.");
readCustomFrameMappingsUpdateMetadataAndIndex(
streamIndex, customFrameMappings.value());
}
}

void SingleStreamDecoder::addAudioStream(
@@ -1407,6 +1453,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
auto& streamInfo = streamInfos_[activeStreamIndex_];
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact: {
auto frame = std::lower_bound(
streamInfo.allFrames.begin(),
@@ -1434,6 +1481,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
auto& streamInfo = streamInfos_[activeStreamIndex_];
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact: {
auto frame = std::upper_bound(
streamInfo.allFrames.begin(),
@@ -1461,6 +1509,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
auto& streamInfo = streamInfos_[activeStreamIndex_];
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamInfo.allFrames[frameIndex].pts;
case SeekMode::approximate: {
@@ -1485,6 +1534,7 @@
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.numFramesFromContent.value();
case SeekMode::approximate: {
@@ -1498,6 +1548,7 @@
double SingleStreamDecoder::getMinSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
@@ -1510,6 +1561,7 @@
std::optional<double> SingleStreamDecoder::getMaxSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.endStreamPtsSecondsFromContent.value();
case SeekMode::approximate: {
@@ -1645,6 +1697,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
return SingleStreamDecoder::SeekMode::exact;
} else if (seekMode == "approximate") {
return SingleStreamDecoder::SeekMode::approximate;
} else if (seekMode == "custom_frame_mappings") {
return SingleStreamDecoder::SeekMode::custom_frame_mappings;
} else {
TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
}
13 changes: 11 additions & 2 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -29,7 +29,7 @@ class SingleStreamDecoder {
// CONSTRUCTION API
// --------------------------------------------------------------------------

enum class SeekMode { exact, approximate };
enum class SeekMode { exact, approximate, custom_frame_mappings };
Member:

I don't know if we'll want to publicly expose a new seek_mode in Python, but I think for the C++ side this is a reasonable approach. @scotts curious if you have any opinion on this?

Contributor:

Agreed. Right now, the SingleStreamDecoder's seekMode_ is what we use to determine if we should do a file scan or not:

if (seekMode_ == SeekMode::exact) {
scanFileAndUpdateMetadataAndIndex();
}

Since the point of this feature is to avoid a file scan while still maintaining seek accuracy, I think it makes sense for it to be a new seek mode on the C++ side. We can figure out what to do with the public API later.
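
Under that approach the constructor-side dispatch stays unchanged; our reading of the intended control flow, sketched (the new mode builds its index later, when addVideoStream() receives the mappings):

if (seekMode_ == SeekMode::exact) {
  scanFileAndUpdateMetadataAndIndex();
}
// SeekMode::custom_frame_mappings intentionally does not scan here:
// readCustomFrameMappingsUpdateMetadataAndIndex() runs in addVideoStream().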


// Creates a SingleStreamDecoder from the video at videoFilePath.
explicit SingleStreamDecoder(
@@ -53,6 +53,13 @@
// the allFrames and keyFrames vectors.
void scanFileAndUpdateMetadataAndIndex();

// Reads the user-provided frame mappings and updates the given stream's
// index, i.e. the allFrames and keyFrames vectors, as well as metadata
// such as endStreamPtsSecondsFromContent.
void readCustomFrameMappingsUpdateMetadataAndIndex(
int streamIndex,
std::tuple<at::Tensor, at::Tensor, at::Tensor> customFrameMappings);

Member:

Just bumping #764 (comment) again, which may have been missed:

  • we should document the expected length, dtype, and associated semantics of each tensor.
  • I'd also recommend relying on a new FrameMappings struct instead of a 3-tuple.

Contributor:

To elaborate on defining a new FrameMappings struct: I think of the code in custom_ops.cpp as the bridge layer between the C++ logic and the Python logic. So that's the layer where we would turn a tuple of tensors into a struct. That way, in the C++ logic, we're (as much as possible) operating on proper types with meaningful field names.

As a point of comparison, we do something similar with batched output. SingleStreamDecoder returns a FrameBatchOutput struct, and the code in custom_ops.cpp turns that into a tuple of tensors.
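
A sketch of that struct, folding in the documentation the earlier comment asks for (field names and exact dtypes are our assumptions, inferred from the current 3-tuple):

// Built in custom_ops.cpp from the (all_frames, is_key_frame, duration)
// tuple of 1D tensors, all of length N (one entry per frame):
struct FrameMappings {
  at::Tensor allFrames;  // int64, N: pts of each frame, in time-base units
  at::Tensor isKeyFrame; // bool, N: whether the frame at that position is a key frame
  at::Tensor duration;   // int64, N: duration of each frame, in time-base units
};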

// Returns the metadata for the container.
ContainerMetadata getContainerMetadata() const;

@@ -66,7 +73,9 @@

void addVideoStream(
int streamIndex,
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(),
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
customFrameMappings = std::nullopt);
void addAudioStream(
int streamIndex,
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
17 changes: 11 additions & 6 deletions src/torchcodec/_core/custom_ops.cpp
@@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
m.def(
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()");
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
m.def(
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
m.def(
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -223,6 +223,8 @@ void _add_video_stream(
std::optional<std::string_view> dimension_order = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt,
std::optional<std::string_view> device = std::nullopt,
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
custom_frame_mappings = std::nullopt,
std::optional<std::string_view> color_conversion_library = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.width = width;
@@ -253,9 +255,9 @@
if (device.has_value()) {
videoStreamOptions.device = createTorchDevice(std::string(device.value()));
}

auto videoDecoder = unwrapTensorToGetDecoder(decoder);
videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions);
videoDecoder->addVideoStream(
stream_index.value_or(-1), videoStreamOptions, custom_frame_mappings);
}

// Add a new video stream at `stream_index` using the provided options.
@@ -266,15 +268,18 @@ void add_video_stream(
std::optional<int64_t> num_threads = std::nullopt,
std::optional<std::string_view> dimension_order = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt,
std::optional<std::string_view> device = std::nullopt) {
std::optional<std::string_view> device = std::nullopt,
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
custom_frame_mappings = std::nullopt) {
_add_video_stream(
decoder,
width,
height,
num_threads,
dimension_order,
stream_index,
device);
device,
custom_frame_mappings);
}

void add_audio_stream(
6 changes: 6 additions & 0 deletions src/torchcodec/_core/ops.py
@@ -205,6 +205,9 @@ def _add_video_stream_abstract(
dimension_order: Optional[str] = None,
stream_index: Optional[int] = None,
device: Optional[str] = None,
custom_frame_mappings: Optional[
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
] = None,
color_conversion_library: Optional[str] = None,
) -> None:
return
@@ -220,6 +223,9 @@ def add_video_stream_abstract(
dimension_order: Optional[str] = None,
stream_index: Optional[int] = None,
device: Optional[str] = None,
custom_frame_mappings: Optional[
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
] = None,
) -> None:
return

35 changes: 19 additions & 16 deletions test/test_metadata.py
@@ -31,23 +31,26 @@ def _get_container_metadata(path, seek_mode):
return get_container_metadata(decoder)


@pytest.mark.parametrize(
"metadata_getter",
(
get_container_metadata_from_header,
functools.partial(_get_container_metadata, seek_mode="approximate"),
functools.partial(_get_container_metadata, seek_mode="exact"),
),
)
def test_get_metadata(metadata_getter):
with_scan = (
metadata_getter.keywords["seek_mode"] == "exact"
if isinstance(metadata_getter, functools.partial)
else False
@pytest.mark.parametrize("seek_mode", ["approximate", "exact", "custom_frame_mappings"])
Member:

I think you may have removed get_container_metadata_from_header from the previous parametrization.

def test_get_metadata(seek_mode):
from torchcodec._core import add_video_stream

decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode)
# For custom_frame_mappings seek mode, add a video stream to update metadata
custom_frame_mappings = (
NASA_VIDEO.get_custom_frame_mappings()
if seek_mode == "custom_frame_mappings"
else None
)
# Add the best video stream (index 3 for NASA_VIDEO)
add_video_stream(
decoder,
stream_index=NASA_VIDEO.default_stream_index,
custom_frame_mappings=custom_frame_mappings,
)
metadata = get_container_metadata(decoder)

metadata = metadata_getter(NASA_VIDEO.path)
with_scan = seek_mode in ("exact", "custom_frame_mappings")

assert len(metadata.streams) == 6
assert metadata.best_video_stream_index == 3
@@ -82,7 +85,7 @@ def test_get_metadata(metadata_getter):
assert best_video_stream_metadata.begin_stream_seconds_from_header == 0
assert best_video_stream_metadata.bit_rate == 128783
assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
assert best_video_stream_metadata.pixel_aspect_ratio is None
assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1)
Contributor Author:

Since add_video_stream is always being called, this value is set instead of None.

assert best_video_stream_metadata.codec == "h264"
assert best_video_stream_metadata.num_frames_from_content == (
390 if with_scan else None
63 changes: 63 additions & 0 deletions test/test_ops.py
@@ -448,6 +448,69 @@ def test_frame_pts_equality(self):
)
assert pts_is_equal

def test_seek_mode_custom_frame_mappings_fails(self):
decoder = create_from_file(
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
)
with pytest.raises(
RuntimeError,
match="Please provide frame mappings when using custom_frame_mappings seek mode.",
):
add_video_stream(decoder, stream_index=0, custom_frame_mappings=None)

decoder = create_from_file(
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
)
different_lengths = (
torch.tensor([1, 2, 3]),
torch.tensor([1, 2]),
torch.tensor([1, 2, 3]),
)
with pytest.raises(
RuntimeError,
match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.",
):
add_video_stream(
decoder, stream_index=0, custom_frame_mappings=different_lengths
)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_seek_mode_custom_frame_mappings(self, device):
stream_index = 3  # custom_frame_mappings seek mode requires an explicit stream index
decoder = create_from_file(
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
)
add_video_stream(
decoder,
device=device,
stream_index=stream_index,
custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings(
stream_index=stream_index
),
)

frame0, _, _ = get_next_frame(decoder)
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(
0, stream_index=stream_index
)
assert_frames_equal(frame0, reference_frame0.to(device))

frame6, _, _ = get_frame_at_pts(decoder, 6.006)
reference_frame6 = NASA_VIDEO.get_frame_data_by_index(
INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index
)
assert_frames_equal(frame6, reference_frame6.to(device))

frame6, _, _ = get_frame_at_index(decoder, frame_index=180)
reference_frame6 = NASA_VIDEO.get_frame_data_by_index(
INDEX_OF_FRAME_AT_6_SECONDS, stream_index=stream_index
)
assert_frames_equal(frame6, reference_frame6.to(device))

ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9)
bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9)
assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device))

@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
def test_color_conversion_library(self, color_conversion_library):
decoder = create_from_file(str(NASA_VIDEO.path))