
Allow num_frames and duration to be absent in C++ decoder #708

Merged
merged 2 commits on Jun 3, 2025
91 changes: 60 additions & 31 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -602,16 +602,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
const auto& streamInfo = streamInfos_[activeStreamIndex_];
int64_t numFrames = getNumFrames(streamMetadata);
TORCH_CHECK(
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
TORCH_CHECK(
stop <= numFrames,
"Range stop, " + std::to_string(stop) +
", is more than the number of frames, " + std::to_string(numFrames));
TORCH_CHECK(
step > 0, "Step must be greater than 0; is " + std::to_string(step));

// Note that if we do not have the number of frames available in our metadata,
// then we assume that the upper part of the range is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
if (numFrames.has_value()) {
TORCH_CHECK(
stop <= numFrames.value(),
"Range stop, " + std::to_string(stop) +
", is more than the number of frames, " +
std::to_string(numFrames.value()));
}

int64_t numOutputFrames = std::ceil((stop - start) / double(step));
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
FrameBatchOutput frameBatchOutput(
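The guard just above (skip the upper-bound check when the frame count is unknown) is the pattern this PR applies throughout the file. As a rough, standalone illustration of that pattern only (not code from this diff; checkUpperBound is a hypothetical name and std::runtime_error stands in for TORCH_CHECK):

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical helper: enforce an upper bound only when the bound is known.
void checkUpperBound(
    int64_t value,
    const std::optional<int64_t>& bound,
    const std::string& what) {
  if (bound.has_value() && value > bound.value()) {
    throw std::runtime_error(
        what + " is " + std::to_string(value) + "; must be at most " +
        std::to_string(bound.value()));
  }
}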
@@ -676,7 +682,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
containerMetadata_.allStreamMetadata[activeStreamIndex_];

double minSeconds = getMinSeconds(streamMetadata);
double maxSeconds = getMaxSeconds(streamMetadata);
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);

// The frame played at timestamp t and the one played at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
@@ -687,10 +693,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
for (size_t i = 0; i < timestamps.size(); ++i) {
auto frameSeconds = timestamps[i];
TORCH_CHECK(
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
frameSeconds >= minSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");
"; must be greater than or equal to " + std::to_string(minSeconds) +
".");

// Note that if we can't determine the maximum number of seconds from the
// metadata, then we assume the frame's pts is valid.
if (maxSeconds.has_value()) {
TORCH_CHECK(
frameSeconds < maxSeconds.value(),
"frame pts is " + std::to_string(frameSeconds) +
"; must be less than " + std::to_string(maxSeconds.value()) +
".");
}

frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
}
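Since the bound here is a double, an equivalent way to express the relaxed check is to substitute positive infinity when the duration is unknown; the diff instead keeps separate TORCH_CHECKs so each failure mode gets its own message. A minimal sketch of that alternative (illustrative only; the function name is hypothetical):

#include <limits>
#include <optional>

// Sketch only: an absent duration behaves like "no upper limit".
bool ptsWithinKnownBounds(
    double frameSeconds,
    double minSeconds,
    const std::optional<double>& maxSeconds) {
  double upper = maxSeconds.value_or(std::numeric_limits<double>::infinity());
  return frameSeconds >= minSeconds && frameSeconds < upper;
}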
@@ -737,17 +753,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
}

double minSeconds = getMinSeconds(streamMetadata);
double maxSeconds = getMaxSeconds(streamMetadata);
TORCH_CHECK(
startSeconds >= minSeconds && startSeconds < maxSeconds,
startSeconds >= minSeconds,
"Start seconds is " + std::to_string(startSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");
TORCH_CHECK(
stopSeconds <= maxSeconds,
"Stop seconds (" + std::to_string(stopSeconds) +
"; must be less than or equal to " + std::to_string(maxSeconds) +
").");
"; must be greater than or equal to " + std::to_string(minSeconds) +
".");

// Note that if we can't determine the maximum seconds from the metadata, then
// we assume the upper range is valid.
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
if (maxSeconds.has_value()) {
TORCH_CHECK(
startSeconds < maxSeconds.value(),
"Start seconds is " + std::to_string(startSeconds) +
"; must be less than " + std::to_string(maxSeconds.value()) + ".");
TORCH_CHECK(
stopSeconds <= maxSeconds.value(),
"Stop seconds (" + std::to_string(stopSeconds) +
"; must be less than or equal to " +
std::to_string(maxSeconds.value()) + ").");
}

// Note that we look at nextPts for a frame, and not its pts or duration.
// Our abstract player displays frames starting at the pts for that frame
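One detail worth noting in this hunk: when the duration is known, the start bound stays strict (startSeconds < maxSeconds) while the stop bound is inclusive (stopSeconds <= maxSeconds), which appears to match the half-open [start, stop) semantics of the requested range. A compact sketch of just that pair of conditions (illustrative only; the function name is hypothetical):

#include <optional>

// Sketch only: the start of the requested window must lie strictly before the
// end of the stream, while the stop may coincide with it.
bool rangeWithinKnownDuration(
    double startSeconds,
    double stopSeconds,
    const std::optional<double>& maxSeconds) {
  if (!maxSeconds.has_value()) {
    return true;  // duration unknown: accept and let decoding surface errors
  }
  return startSeconds < maxSeconds.value() && stopSeconds <= maxSeconds.value();
}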
@@ -1456,16 +1481,13 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------

int64_t SingleStreamDecoder::getNumFrames(
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.numFramesFromScan.value();
case SeekMode::approximate: {
TORCH_CHECK(
streamMetadata.numFrames.has_value(),
"Cannot use approximate mode since we couldn't find the number of frames from the metadata.");
return streamMetadata.numFrames.value();
return streamMetadata.numFrames;
}
default:
throw std::runtime_error("Unknown SeekMode");
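With this change, approximate mode simply forwards streamMetadata.numFrames, which may be empty when the container does not report a frame count. A small caller-side sketch of what that means in practice (hypothetical stand-ins, standard library only):

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for a container that reports no frame count.
std::optional<int64_t> numFramesFromMetadata() {
  return std::nullopt;
}

int main() {
  std::optional<int64_t> numFrames = numFramesFromMetadata();
  if (numFrames.has_value()) {
    std::cout << "frame count: " << numFrames.value() << "\n";
  } else {
    // Upper-bound validation is skipped; an out-of-range request surfaces as a
    // decode error later instead of failing fast here.
    std::cout << "frame count unknown; skipping upper-bound validation\n";
  }
  return 0;
}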
@@ -1484,16 +1506,13 @@ double SingleStreamDecoder::getMinSeconds(
}
}

double SingleStreamDecoder::getMaxSeconds(
std::optional<double> SingleStreamDecoder::getMaxSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.maxPtsSecondsFromScan.value();
case SeekMode::approximate: {
TORCH_CHECK(
streamMetadata.durationSeconds.has_value(),
"Cannot use approximate mode since we couldn't find the duration from the metadata.");
return streamMetadata.durationSeconds.value();
return streamMetadata.durationSeconds;
}
default:
throw std::runtime_error("Unknown SeekMode");
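Same shape as getNumFrames: exact mode still relies on the scanned value, while approximate mode forwards whatever the container header provided, which may be nothing. A condensed, self-contained sketch of that switch over a simplified metadata struct (names are stand-ins, not the real class):

#include <optional>
#include <stdexcept>

// Simplified stand-ins for the real metadata and seek-mode types.
struct Metadata {
  std::optional<double> maxPtsSecondsFromScan;  // filled in by a full scan
  std::optional<double> durationSeconds;        // taken from the container header
};
enum class SeekMode { exact, approximate };

std::optional<double> maxSeconds(const Metadata& md, SeekMode mode) {
  switch (mode) {
    case SeekMode::exact:
      return md.maxPtsSecondsFromScan.value();  // throws if no scan was done
    case SeekMode::approximate:
      return md.durationSeconds;  // may be std::nullopt
  }
  throw std::runtime_error("Unknown SeekMode");
}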
@@ -1539,12 +1558,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
void SingleStreamDecoder::validateFrameIndex(
const StreamMetadata& streamMetadata,
int64_t frameIndex) {
int64_t numFrames = getNumFrames(streamMetadata);
TORCH_CHECK(
frameIndex >= 0 && frameIndex < numFrames,
frameIndex >= 0,
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
" numFrames=" + std::to_string(numFrames));
"; must be greater than or equal to 0");

// Note that if we do not have the number of frames available in our metadata,
// then we assume that the frameIndex is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
if (numFrames.has_value()) {
TORCH_CHECK(
frameIndex < numFrames.value(),
"Invalid frame index=" + std::to_string(frameIndex) +
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
"; must be less than " + std::to_string(numFrames.value()));
}
}

// --------------------------------------------------------------------------
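The net effect on validateFrameIndex: a negative index still fails fast, but a too-large index can only be rejected when the frame count is known; otherwise it passes validation and any error surfaces later during decoding. A behavioral sketch (hypothetical simplified function; assert stands in for TORCH_CHECK):

#include <cassert>
#include <cstdint>
#include <optional>

// Simplified sketch of the validation contract after this PR.
void validateIndex(int64_t frameIndex, const std::optional<int64_t>& numFrames) {
  assert(frameIndex >= 0);  // always enforceable
  if (numFrames.has_value()) {
    assert(frameIndex < numFrames.value());  // only when the count is known
  }
}

int main() {
  validateIndex(5, std::optional<int64_t>(10));  // ok: within known bounds
  validateIndex(5, std::nullopt);                // ok: count unknown, accepted
  // validateIndex(50, std::optional<int64_t>(10));  // would trip the assert
  return 0;
}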
4 changes: 2 additions & 2 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -304,9 +304,9 @@ class SingleStreamDecoder {
// index. Note that this index may be truncated for some files.
int getBestStreamIndex(AVMediaType mediaType);

int64_t getNumFrames(const StreamMetadata& streamMetadata);
std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
double getMinSeconds(const StreamMetadata& streamMetadata);
double getMaxSeconds(const StreamMetadata& streamMetadata);
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);

// --------------------------------------------------------------------------
// VALIDATION UTILS
4 changes: 2 additions & 2 deletions test/test_decoders.py
@@ -597,10 +597,10 @@ def test_get_frames_played_at(self, device, seek_mode):
def test_get_frames_played_at_fails(self, device, seek_mode):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)

with pytest.raises(RuntimeError, match="must be in range"):
with pytest.raises(RuntimeError, match="must be greater than or equal to"):
decoder.get_frames_played_at([-1])

with pytest.raises(RuntimeError, match="must be in range"):
with pytest.raises(RuntimeError, match="must be less than"):
decoder.get_frames_played_at([14])

with pytest.raises(RuntimeError, match="Expected a value of type"):