Skip to content

Commit

Permalink
Support overwriting PTS in StreamWriter (#3135)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #3135

Differential Revision: D43724273

Pulled By: mthrok

fbshipit-source-id: 72c3579220a94df383ebd16f92762565aa66124f
  • Loading branch information
mthrok committed Mar 4, 2023
1 parent db4898f commit 4be2876
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 25 deletions.
32 changes: 32 additions & 0 deletions test/torchaudio_unittest/io/stream_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,35 @@ def test_audio_pts_increment(self):
num_samples += chunk.size(0)
print(chunk.pts, expected)
assert abs(chunk.pts - expected) < 1e-10

def test_video_pts_overwrite(self):
    """Can overwrite PTS

    Writes every frame with an explicit PTS, reads the file back and checks
    that the decoded chunks carry the timestamps that were supplied.
    """

    ext = "mp4"
    num_frames = 256
    filename = f"test.{ext}"
    frame_rate = 10
    width, height = 96, 128

    # Write data
    dst = self.get_dst(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)

    video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
    # Use twice the nominal frame spacing so the result differs from the
    # default behaviour, where PTS advances by one frame per written frame.
    reference_pts = [2 * i / frame_rate for i in range(num_frames)]
    with writer.open():
        for i, pts in enumerate(reference_pts):
            # One frame per call so that each frame gets its own PTS.
            writer.write_video_chunk(0, video[i : i + 1], pts)

    # check
    if self.test_fileobj:
        dst.flush()

    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_video_stream(1)
    pts = [chunk.pts for (chunk,) in reader.stream()]
    assert len(pts) == len(reference_pts)

    for val, ref in zip(pts, reference_pts):
        # The PTS is stored as an integer number of time-base ticks and turned
        # back into a float on read, so exact float equality is fragile.
        # Use the same tolerance as the audio PTS test.
        assert abs(val - ref) < 1e-10
20 changes: 16 additions & 4 deletions torchaudio/csrc/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,14 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("add_video_stream", &StreamWriter::add_video_stream)
.def("dump_format", &StreamWriter::dump_format)
.def("open", &StreamWriter::open)
.def("write_audio_chunk", &StreamWriter::write_audio_chunk)
.def("write_video_chunk", &StreamWriter::write_video_chunk)
.def(
"write_audio_chunk",
py::overload_cast<int, const torch::Tensor&, const c10::optional<double>&>(
&StreamWriter::write_audio_chunk))
.def(
"write_video_chunk",
py::overload_cast<int, const torch::Tensor&, const c10::optional<double>&>(
&StreamWriter::write_video_chunk))
.def("flush", &StreamWriter::flush)
.def("close", &StreamWriter::close);
py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj", py::module_local())
Expand All @@ -51,8 +57,14 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("add_video_stream", &StreamWriterFileObj::add_video_stream)
.def("dump_format", &StreamWriterFileObj::dump_format)
.def("open", &StreamWriterFileObj::open)
.def("write_audio_chunk", &StreamWriterFileObj::write_audio_chunk)
.def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
.def(
"write_audio_chunk",
py::overload_cast<int, const torch::Tensor&, const c10::optional<double>&>(
&StreamWriterFileObj::write_audio_chunk))
.def(
"write_video_chunk",
py::overload_cast<int, const torch::Tensor&, const c10::optional<double>&>(
&StreamWriterFileObj::write_video_chunk))
.def("flush", &StreamWriterFileObj::flush)
.def("close", &StreamWriterFileObj::close);
py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
Expand Down
15 changes: 9 additions & 6 deletions torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ AudioOutputStream::AudioOutputStream(
converter(buffer, buffer->nb_samples),
codec_ctx(std::move(codec_ctx_)) {}

/// Encode a chunk of audio, optionally overriding the presentation timestamp.
/// @param waveform Audio tensor handed to the converter.
/// @param pts If set, the presentation timestamp (in seconds) assigned to the
/// first frame of this chunk. If not set, the timestamp carried over from the
/// previous call is used.
void AudioOutputStream::write_chunk(
    const torch::Tensor& waveform,
    const c10::optional<double>& pts) {
  AVRational sr_tb{1, codec_ctx->sample_rate};
  AVRational codec_tb = codec_ctx->time_base;
  if (pts) {
    // Seconds -> codec time-base ticks. The cast truncates toward zero.
    buffer->pts =
        static_cast<int64_t>(pts.value() * codec_tb.den / codec_tb.num);
  }
  // NOTE(review): the converter is constructed from `buffer`, so the frames it
  // yields are presumably that same AVFrame -- confirm.
  for (const auto& buffer_ : converter.convert(waveform)) {
    process_frame(buffer_);
    // Advance the timestamp by the number of samples just encoded, rescaled
    // from the sample-rate time base to the codec time base.
    buffer_->pts += av_rescale_q(buffer_->nb_samples, sr_tb, codec_tb);
  }
}

Expand Down
3 changes: 2 additions & 1 deletion torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ struct AudioOutputStream : OutputStream {
AVSampleFormat src_fmt,
AVCodecContextPtr&& codec_ctx);

void write_chunk(const torch::Tensor& waveform) override;
void write_chunk(const torch::Tensor& frames, const c10::optional<double>& pts = {}) override;

~AudioOutputStream() override = default;
};

Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/ffmpeg/stream_writer/output_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct OutputStream {
AVCodecContext* codec_ctx,
FilterGraph&& filter);

virtual void write_chunk(const torch::Tensor& input) = 0;
virtual void write_chunk(const torch::Tensor& frames, const c10::optional<double>& pts = {}) = 0;
void process_frame(AVFrame* src);
void flush();
virtual ~OutputStream() = default;
Expand Down
18 changes: 16 additions & 2 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,13 +583,27 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) {
}

/// Write audio data without overriding the presentation timestamp.
/// @param i Stream index.
/// @param waveform Waveform tensor.
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
  // Forward an *empty* optional. A numeric literal such as -1 converts to an
  // engaged c10::optional<double>, which would overwrite the PTS with a
  // negative timestamp on every call.
  write_audio_chunk(i, waveform, c10::nullopt);
}

/// Write audio data, optionally overriding the presentation timestamp.
/// @param i Stream index. Must refer to an audio stream.
/// @param waveform Waveform tensor.
/// @param pts If set, forwarded to the stream as the presentation timestamp.
void StreamWriter::write_audio_chunk(
    int i,
    const torch::Tensor& waveform,
    const c10::optional<double>& pts) {
  validate_stream(i, AVMEDIA_TYPE_AUDIO);
  // Write exactly once. A preceding pts-less write_chunk(waveform) call would
  // encode the same chunk twice.
  streams[i]->write_chunk(waveform, pts);
}

/// Write video data without overriding the presentation timestamp.
/// @param i Stream index.
/// @param frames Video/image tensor.
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
  // Forward an *empty* optional. A numeric literal such as -1 converts to an
  // engaged c10::optional<double>, which would overwrite the PTS with a
  // negative timestamp on every call.
  write_video_chunk(i, frames, c10::nullopt);
}

/// Write video data, optionally overriding the presentation timestamp.
/// @param i Stream index. Must refer to a video stream.
/// @param frames Video/image tensor.
/// @param pts If set, forwarded to the stream as the presentation timestamp.
void StreamWriter::write_video_chunk(
    int i,
    const torch::Tensor& frames,
    const c10::optional<double>& pts) {
  validate_stream(i, AVMEDIA_TYPE_VIDEO);
  // Write exactly once. A preceding pts-less write_chunk(frames) call would
  // encode the same frames twice.
  streams[i]->write_chunk(frames, pts);
}

void StreamWriter::flush() {
Expand Down
6 changes: 4 additions & 2 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,16 @@ class StreamWriter {
/// @param i Stream index.
/// @param frames Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
void write_audio_chunk(int i, const torch::Tensor& frames);
/// Write audio data, optionally overriding the presentation timestamp.
/// @param i Stream index.
/// @param frames Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
/// @param pts If set, overwrite the presentation timestamp of the chunk.
void write_audio_chunk(int i, const torch::Tensor& frames, const c10::optional<double>& pts);
/// Write video data
/// @param i Stream index.
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``.
void write_video_chunk(int i, const torch::Tensor& frames);
/// Write video data, optionally overriding the presentation timestamp.
/// @param i Stream index.
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``.
/// @param pts If set, overwrite the presentation timestamp of the chunk.
void write_video_chunk(int i, const torch::Tensor& frames, const c10::optional<double>& pts);
/// Flush the frames from encoders and write the frames to the destination.
void flush();

Expand Down
13 changes: 9 additions & 4 deletions torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,15 @@ VideoOutputStream::VideoOutputStream(
converter(buffer),
codec_ctx(std::move(codec_ctx_)) {}

/// Encode a chunk of video frames, optionally overriding the presentation
/// timestamp.
/// @param frames Video tensor handed to the converter.
/// @param pts If set, the presentation timestamp (in seconds) assigned to the
/// first frame of this chunk. If not set, the timestamp carried over from the
/// previous call is used.
void VideoOutputStream::write_chunk(
    const torch::Tensor& frames,
    const c10::optional<double>& pts) {
  AVRational codec_tb = codec_ctx->time_base;
  if (pts) {
    double val = pts.value();
    // Seconds -> codec time-base ticks. The cast truncates toward zero.
    buffer->pts = static_cast<int64_t>(val * codec_tb.den / codec_tb.num);
  }
  // NOTE(review): the converter is constructed from `buffer`, so the frames it
  // yields are presumably that same AVFrame -- confirm.
  for (const auto& buffer_ : converter.convert(frames)) {
    process_frame(buffer_);
    // Each subsequent frame is one tick later.
    buffer_->pts += 1;
  }
}

Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct VideoOutputStream : OutputStream {
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx);

void write_chunk(const torch::Tensor& frames) override;
void write_chunk(const torch::Tensor& frames, const c10::optional<double>& pts = {}) override;

~VideoOutputStream() override = default;
};
Expand Down
21 changes: 17 additions & 4 deletions torchaudio/io/_stream_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,23 @@ def close(self):
self._s.close()
self._is_open = False

def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
    """Write audio data

    Args:
        i (int): Stream index.
        chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
            The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method.
        pts (float or None, optional): If provided, overwrite the presentation timestamp.

            .. note::

               The value of pts is converted to integer value expressed in basis of
               sample rate. Therefore, it is truncated to the nearest value of
               ``n / sample_rate``.
    """
    # Forward to the backend exactly once; ``pts=None`` leaves the timestamp untouched.
    self._s.write_audio_chunk(i, chunk, pts)

def write_video_chunk(self, i: int, chunk: torch.Tensor):
def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
"""Write video/image data
Args:
Expand All @@ -295,8 +301,15 @@ def write_video_chunk(self, i: int, chunk: torch.Tensor):
The ``dtype`` must be ``torch.uint8``.
The shape (height, width and the number of channels) must match
what was configured when calling :py:meth:`add_video_stream`
pts (float, optional or None): If provided, overwrite the presentation timestamp.
.. note::
The value of pts is converted to integer value expressed in basis of
frame rate. Therefore, it is truncated to the nearest value of
``n / frame_rate``.
"""
self._s.write_video_chunk(i, chunk)
self._s.write_video_chunk(i, chunk, pts)

def flush(self):
"""Flush the frames from encoders and write the frames to the destination."""
Expand Down

0 comments on commit 4be2876

Please sign in to comment.