Skip to content

Allow option to use the swscale library for color conversion instead of filtergraph #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
.
  • Loading branch information
ahmadsharif1 committed Sep 24, 2024
commit c89959668115dc0fc3c2b40bec4f44f8d160c39c
14 changes: 11 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,12 @@ void VideoDecoder::initializeFilterGraphForStream(
width = *options.width;
height = *options.height;
}
std::snprintf(description, sizeof(description), "scale=%d:%d:sws_flags=bilinear", width, height);
std::snprintf(
description,
sizeof(description),
"scale=%d:%d:sws_flags=bilinear",
width,
height);
AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
ffmpegStatus = avfilter_graph_parse_ptr(
Expand Down Expand Up @@ -836,7 +841,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
tensor = tensor.permute({2, 0, 1});
}
output.frame = tensor;
} else if (streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH ) {
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
} else {
throw std::runtime_error(
Expand Down Expand Up @@ -958,7 +965,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
// in-place on the output tensor's data_ptr.
rawSingleOutput.data = output.frames[i].data_ptr<uint8_t>();
convertFrameToBufferUsingSwsScale(rawSingleOutput);
} else if (stream.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
} else if (
stream.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
// We are using a filter graph to convert the frame to tensor. The
// filter graph returns us an AVFrame allocated by FFMPEG. So we need to
// copy the AVFrame to the output tensor.
Expand Down
9 changes: 4 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ class VideoDecoder {
int streamIndex,
const VideoStreamDecoderOptions& options);
void maybeSeekToBeforeDesiredPts();
RawDecodedOutput getDecodedOutputWithFilter(std::function<bool(int, AVFrame*)>);
RawDecodedOutput getDecodedOutputWithFilter(
std::function<bool(int, AVFrame*)>);
RawDecodedOutput getNextRawDecodedOutputNoDemux();
// Once we create a decoder can update the metadata with the codec context.
// For example, for video streams, we can add the height and width of the
Expand All @@ -359,10 +360,8 @@ class VideoDecoder {
torch::Tensor convertFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* frame);
void convertFrameToBufferUsingSwsScale(
RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(
RawDecodedOutput& rawOutput);
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
Expand Down
17 changes: 13 additions & 4 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include <cstdint>
#include <sstream>
#include <string>
#include "c10/util/Exception.h"
#include "c10/core/SymIntArrayRef.h"
#include "c10/util/Exception.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

namespace facebook::torchcodec {
Expand Down Expand Up @@ -121,7 +121,14 @@ void add_video_stream(
std::optional<int64_t> num_threads,
std::optional<c10::string_view> dimension_order,
std::optional<int64_t> stream_index) {
_add_video_stream(decoder, width, height, num_threads, dimension_order, stream_index, "filtergraph");
_add_video_stream(
decoder,
width,
height,
num_threads,
dimension_order,
stream_index,
"filtergraph");
}

void _add_video_stream(
Expand All @@ -145,9 +152,11 @@ void _add_video_stream(
if (color_conversion_library.has_value()) {
std::string stdColorConversionLibrary{color_conversion_library.value()};
if (stdColorConversionLibrary == "filtergraph") {
options.colorConversionLibrary = VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
options.colorConversionLibrary =
VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
} else if (stdColorConversionLibrary == "swscale") {
options.colorConversionLibrary = VideoDecoder::ColorConversionLibrary::SWSCALE;
options.colorConversionLibrary =
VideoDecoder::ColorConversionLibrary::SWSCALE;
} else {
throw std::runtime_error(
"Invalid color_conversion_library=" + stdColorConversionLibrary +
Expand Down
Loading