Skip to content

Commit

Permalink
Support overwriting PTS in StreamWriter (pytorch#3135)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#3135

Differential Revision: D43724273

Pulled By: mthrok

fbshipit-source-id: f89f3d15a065fe5b3a5ef150e34089e8cbcbc948
  • Loading branch information
mthrok authored and facebook-github-bot committed Mar 6, 2023
1 parent 1c2d182 commit 8277823
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 12 deletions.
32 changes: 32 additions & 0 deletions test/torchaudio_unittest/io/stream_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,35 @@ def test_audio_pts_increment(self):
num_samples += chunk.size(0)
print(chunk.pts, expected)
assert abs(chunk.pts - expected) < 1e-10

def test_video_pts_overwrite(self):
"""Can overwrite PTS"""

ext = "mp4"
num_frames = 256
filename = f"test.{ext}"
frame_rate = 10
width, height = 96, 128

# Write data
dst = self.get_dst(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)

video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
reference_pts = [2 * i / frame_rate for i in range(num_frames)]
with writer.open():
for i, pts in enumerate(reference_pts):
writer.write_video_chunk(0, video[i : i + 1], pts)

# check
if self.test_fileobj:
dst.flush()

reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(1)
pts = [chunk.pts for (chunk,) in reader.stream()]
assert len(pts) == len(reference_pts)

for val, ref in zip(pts, reference_pts):
assert val == ref
9 changes: 8 additions & 1 deletion torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,10 @@ EncodeProcess::EncodeProcess(
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}

void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
void EncodeProcess::process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
TORCH_CHECK(
codec_ctx->codec_type == type,
"Attempted to write ",
Expand All @@ -521,6 +524,10 @@ void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
" stream.");

AVRational codec_tb = codec_ctx->time_base;
if (pts) {
src_frame->pts =
static_cast<int64_t>(pts.value() * codec_tb.den / codec_tb.num);
}
for (const auto& frame : converter.convert(tensor)) {
process_frame(frame);
if (type == AVMEDIA_TYPE_VIDEO) {
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/csrc/ffmpeg/stream_writer/encode_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class EncodeProcess {
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel);

void process(AVMediaType type, const torch::Tensor& tensor);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);

void process_frame(AVFrame* src);

Expand Down
14 changes: 10 additions & 4 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,30 @@ void StreamWriter::close() {
}
}

void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
void StreamWriter::write_audio_chunk(
int i,
const torch::Tensor& waveform,
const c10::optional<double>& pts) {
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
}

void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
void StreamWriter::write_video_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts) {
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
}

void StreamWriter::flush() {
Expand Down
10 changes: 8 additions & 2 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,20 @@ class StreamWriter {
/// @param i Stream index.
/// @param chunk Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
void write_audio_chunk(int i, const torch::Tensor& chunk);
void write_audio_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
/// Write video data
/// @param i Stream index.
/// @param chunk Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``.
void write_video_chunk(int i, const torch::Tensor& chunk);
void write_video_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
/// Flush the frames from encoders and write the frames to the destination.
void flush();
};
Expand Down
21 changes: 17 additions & 4 deletions torchaudio/io/_stream_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,23 @@ def close(self):
self._s.close()
self._is_open = False

def write_audio_chunk(self, i: int, chunk: torch.Tensor):
def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
"""Write audio data
Args:
i (int): Stream index.
chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method.
pts (float, optional, or None): If provided, overwrite the presentation timestamp.
.. note::
The value of pts is converted to integer value expressed in basis of
sample rate. Therefore, it is truncated to the nearest value of
``n / sample_rate``.
"""
self._s.write_audio_chunk(i, chunk)
self._s.write_audio_chunk(i, chunk, pts)

def write_video_chunk(self, i: int, chunk: torch.Tensor):
def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
"""Write video/image data
Args:
Expand All @@ -295,8 +301,15 @@ def write_video_chunk(self, i: int, chunk: torch.Tensor):
The ``dtype`` must be ``torch.uint8``.
The shape (height, width and the number of channels) must match
what was configured when calling :py:meth:`add_video_stream`
pts (float, optional or None): If provided, overwrite the presentation timestamp.
.. note::
The value of pts is converted to integer value expressed in basis of
frame rate. Therefore, it is truncated to the nearest value of
``n / frame_rate``.
"""
self._s.write_video_chunk(i, chunk)
self._s.write_video_chunk(i, chunk, pts)

def flush(self):
"""Flush the frames from encoders and write the frames to the destination."""
Expand Down

0 comments on commit 8277823

Please sign in to comment.