From 4be2876d434d01e2a18a86a70c49fd0b172e1733 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 3 Mar 2023 12:36:28 -0800 Subject: [PATCH] Support overwriting PTS in StreamWriter (#3135) Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3135 Differential Revision: D43724273 Pulled By: mthrok fbshipit-source-id: 72c3579220a94df383ebd16f92762565aa66124f --- .../io/stream_writer_test.py | 32 +++++++++++++++++++ torchaudio/csrc/ffmpeg/pybind/pybind.cpp | 20 +++++++++--- .../stream_writer/audio_output_stream.cpp | 15 +++++---- .../stream_writer/audio_output_stream.h | 3 +- .../csrc/ffmpeg/stream_writer/output_stream.h | 2 +- .../ffmpeg/stream_writer/stream_writer.cpp | 18 +++++++++-- .../csrc/ffmpeg/stream_writer/stream_writer.h | 6 ++-- .../stream_writer/video_output_stream.cpp | 13 +++++--- .../stream_writer/video_output_stream.h | 2 +- torchaudio/io/_stream_writer.py | 21 +++++++++--- 10 files changed, 107 insertions(+), 25 deletions(-) diff --git a/test/torchaudio_unittest/io/stream_writer_test.py b/test/torchaudio_unittest/io/stream_writer_test.py index da600c4324c..b1ff5916b35 100644 --- a/test/torchaudio_unittest/io/stream_writer_test.py +++ b/test/torchaudio_unittest/io/stream_writer_test.py @@ -419,3 +419,35 @@ def test_audio_pts_increment(self): num_samples += chunk.size(0) print(chunk.pts, expected) assert abs(chunk.pts - expected) < 1e-10 + + def test_video_pts_overwrite(self): + """Can overwrite PTS""" + + ext = "mp4" + num_frames = 256 + filename = f"test.{ext}" + frame_rate = 10 + width, height = 96, 128 + + # Write data + dst = self.get_dst(filename) + writer = torchaudio.io.StreamWriter(dst=dst, format=ext) + writer.add_video_stream(frame_rate=frame_rate, width=width, height=height) + + video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8) + reference_pts = [2 * i / frame_rate for i in range(num_frames)] + with writer.open(): + for i, pts in enumerate(reference_pts): + writer.write_video_chunk(0, video[i:i+1], pts) + + # check + if self.test_fileobj: + dst.flush() + + reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) + reader.add_video_stream(1) + pts = [chunk.pts for (chunk,) in reader.stream()] + assert len(pts) == len(reference_pts) + + for val, ref in zip(pts, reference_pts): + assert val == ref diff --git a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp index 75fa33cab65..d120634beca 100644 --- a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp +++ b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp @@ -40,8 +40,14 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { .def("add_video_stream", &StreamWriter::add_video_stream) .def("dump_format", &StreamWriter::dump_format) .def("open", &StreamWriter::open) - .def("write_audio_chunk", &StreamWriter::write_audio_chunk) - .def("write_video_chunk", &StreamWriter::write_video_chunk) + .def( + "write_audio_chunk", + py::overload_cast&>( + &StreamWriter::write_audio_chunk)) + .def( + "write_video_chunk", + py::overload_cast&>( + &StreamWriter::write_video_chunk)) .def("flush", &StreamWriter::flush) .def("close", &StreamWriter::close); py::class_(m, "StreamWriterFileObj", py::module_local()) @@ -51,8 +57,14 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { .def("add_video_stream", &StreamWriterFileObj::add_video_stream) .def("dump_format", &StreamWriterFileObj::dump_format) .def("open", &StreamWriterFileObj::open) - .def("write_audio_chunk", &StreamWriterFileObj::write_audio_chunk) - .def("write_video_chunk", &StreamWriterFileObj::write_video_chunk) + .def( + "write_audio_chunk", + py::overload_cast&>( + &StreamWriterFileObj::write_audio_chunk)) + .def( + "write_video_chunk", + py::overload_cast&>( + &StreamWriterFileObj::write_video_chunk)) .def("flush", &StreamWriterFileObj::flush) .def("close", &StreamWriterFileObj::close); py::class_(m, "OutputStreamInfo", py::module_local()) diff --git a/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.cpp b/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.cpp index a17ecb0feb5..7a1bb80414f 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.cpp +++ b/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.cpp @@ -65,12 +65,15 @@ AudioOutputStream::AudioOutputStream( converter(buffer, buffer->nb_samples), codec_ctx(std::move(codec_ctx_)) {} -void AudioOutputStream::write_chunk(const torch::Tensor& waveform) { - AVRational time_base{1, codec_ctx->sample_rate}; - for (const auto& frame : converter.convert(waveform)) { - process_frame(frame); - frame->pts += - av_rescale_q(frame->nb_samples, time_base, codec_ctx->time_base); +void AudioOutputStream::write_chunk(const torch::Tensor& waveform, const c10::optional& pts) { + AVRational sr_tb{1, codec_ctx->sample_rate}; + AVRational codec_tb = codec_ctx->time_base; + if (pts) { + buffer->pts = static_cast(pts.value() * codec_tb.den / codec_tb.num); + } + for (const auto& buffer_ : converter.convert(waveform)) { + process_frame(buffer_); + buffer_->pts += av_rescale_q(buffer_->nb_samples, sr_tb, codec_tb); } } diff --git a/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h b/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h index 1813d35ea04..c3377233754 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h +++ b/torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h @@ -14,7 +14,8 @@ struct AudioOutputStream : OutputStream { AVSampleFormat src_fmt, AVCodecContextPtr&& codec_ctx); - void write_chunk(const torch::Tensor& waveform) override; + void write_chunk(const torch::Tensor& frames, const c10::optional& pts = {}) override; + ~AudioOutputStream() override = default; }; diff --git a/torchaudio/csrc/ffmpeg/stream_writer/output_stream.h b/torchaudio/csrc/ffmpeg/stream_writer/output_stream.h index 3495810584e..b76e847eae2 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/output_stream.h +++ b/torchaudio/csrc/ffmpeg/stream_writer/output_stream.h @@ -22,7 +22,7 @@ struct OutputStream { AVCodecContext* codec_ctx, FilterGraph&& filter); - virtual void write_chunk(const torch::Tensor& input) = 0; + virtual void write_chunk(const torch::Tensor& frames, const c10::optional& pts = {}) = 0; void process_frame(AVFrame* src); void flush(); virtual ~OutputStream() = default; diff --git a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp index aed3275f783..5369ca42e5b 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp +++ b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp @@ -583,13 +583,27 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) { } void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) { + write_audio_chunk(i, waveform, -1); +} + +void StreamWriter::write_audio_chunk( + int i, + const torch::Tensor& waveform, + const c10::optional& pts) { validate_stream(i, AVMEDIA_TYPE_AUDIO); - streams[i]->write_chunk(waveform); + streams[i]->write_chunk(waveform, pts); } void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) { + write_video_chunk(i, frames, -1); +} + +void StreamWriter::write_video_chunk( + int i, + const torch::Tensor& frames, + const c10::optional& pts) { validate_stream(i, AVMEDIA_TYPE_VIDEO); - streams[i]->write_chunk(frames); + streams[i]->write_chunk(frames, pts); } void StreamWriter::flush() { diff --git a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h index 9e493e44df7..6c9e4e74deb 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h +++ b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h @@ -161,14 +161,16 @@ class StreamWriter { /// @param i Stream index. /// @param chunk Waveform tensor. Shape: ``(frame, channel)``. /// The ``dtype`` must match what was passed to ``add_audio_stream()`` method. - void write_audio_chunk(int i, const torch::Tensor& chunk); + void write_audio_chunk(int i, const torch::Tensor& frames); + void write_audio_chunk(int i, const torch::Tensor& frames, const c10::optional& pts); /// Write video data /// @param i Stream index. /// @param chunk Video/image tensor. Shape: ``(time, channel, height, /// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height, /// width and the number of channels)`` must match what was configured when /// calling ``add_video_stream()``. - void write_video_chunk(int i, const torch::Tensor& chunk); + void write_video_chunk(int i, const torch::Tensor& frames); + void write_video_chunk(int i, const torch::Tensor& frames, const c10::optional& pts); /// Flush the frames from encoders and write the frames to the destination. void flush(); diff --git a/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.cpp b/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.cpp index 49c0fe44d1a..50210e0301b 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.cpp +++ b/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.cpp @@ -68,10 +68,15 @@ VideoOutputStream::VideoOutputStream( converter(buffer), codec_ctx(std::move(codec_ctx_)) {} -void VideoOutputStream::write_chunk(const torch::Tensor& frames) { - for (const auto& frame : converter.convert(frames)) { - process_frame(frame); - frame->pts += 1; +void VideoOutputStream::write_chunk(const torch::Tensor& frames, const c10::optional& pts) { + AVRational codec_tb = codec_ctx->time_base; + if (pts) { + double val = pts.value(); + buffer->pts = static_cast(val * codec_tb.den / codec_tb.num); + } + for (const auto& buffer_ : converter.convert(frames)) { + process_frame(buffer_); + buffer_->pts += 1; } } diff --git a/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h b/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h index 45730251906..26265e84ddf 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h +++ b/torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h @@ -14,7 +14,7 @@ struct VideoOutputStream : OutputStream { AVPixelFormat src_fmt, AVCodecContextPtr&& codec_ctx); - void write_chunk(const torch::Tensor& frames) override; + void write_chunk(const torch::Tensor& frames, const c10::optional& pts = {}) override; ~VideoOutputStream() override = default; }; diff --git a/torchaudio/io/_stream_writer.py b/torchaudio/io/_stream_writer.py index eb644cb9110..dea48c123e1 100644 --- a/torchaudio/io/_stream_writer.py +++ b/torchaudio/io/_stream_writer.py @@ -275,17 +275,23 @@ def close(self): self._s.close() self._is_open = False - def write_audio_chunk(self, i: int, chunk: torch.Tensor): + def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None): """Write audio data Args: i (int): Stream index. chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`. The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method. + pts (float, optional, or None): If provided, overwrite the presentation timestamp. + .. note:: + + The value of pts is converted to integer value expressed in basis of + sample rate. Therefore, it is truncated to the nearest value of + ``n / sample_rate``. """ - self._s.write_audio_chunk(i, chunk) + self._s.write_audio_chunk(i, chunk, pts) - def write_video_chunk(self, i: int, chunk: torch.Tensor): + def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None): """Write video/image data Args: @@ -295,8 +301,15 @@ def write_video_chunk(self, i: int, chunk: torch.Tensor): The ``dtype`` must be ``torch.uint8``. The shape (height, width and the number of channels) must match what was configured when calling :py:meth:`add_video_stream` + pts (float, optional or None): If provided, overwrite the presentation timestamp. + + .. note:: + + The value of pts is converted to integer value expressed in basis of + frame rate. Therefore, it is truncated to the nearest value of + ``n / frame_rate``. """ - self._s.write_video_chunk(i, chunk) + self._s.write_video_chunk(i, chunk, pts) def flush(self): """Flush the frames from encoders and write the frames to the destination."""