Skip to content

Commit 288b971

Browse files
ahmadsharif1facebook-github-bot
authored andcommitted
[torchcodec] Allow sampler to use GPU decoding (meta-pytorch#136)
Summary: Pull Request resolved: meta-pytorch#136 Reviewed By: scotts Differential Revision: D60402239
1 parent cdedfb4 commit 288b971

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class VideoTooShortException(Exception):
3131
@dataclass
3232
class DecoderArgs:
3333
num_threads: int = 0
34+
device: torch.device = torch.device("cpu")
3435

3536

3637
@dataclass
@@ -163,6 +164,7 @@ def forward(self, video_data: Tensor) -> Union[List[Any]]:
163164
width=target_width,
164165
height=target_height,
165166
num_threads=self.decoder_args.num_threads,
167+
device_string=str(self.decoder_args.device),
166168
)
167169

168170
clips: List[Any] = []

test/samplers/test_video_clip_sampler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch
66
from torchcodec._samplers import (
7+
DecoderArgs,
78
IndexBasedSamplerArgs,
89
TimeBasedSamplerArgs,
910
VideoArgs,
@@ -30,11 +31,16 @@
3031
),
3132
],
3233
)
33-
def test_sampler(sampler_args):
34+
@pytest.mark.parametrize(("device"), [torch.device("cpu"), torch.device("cuda:0")])
35+
def test_sampler(sampler_args, device):
36+
if device.type == "cuda" and not torch.cuda.is_available():
37+
pytest.skip("GPU not available")
38+
3439
torch.manual_seed(0)
3540
desired_width, desired_height = 320, 240
3641
video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height)
37-
sampler = VideoClipSampler(video_args, sampler_args)
42+
decoder_args = DecoderArgs(device=device)
43+
sampler = VideoClipSampler(video_args, sampler_args, decoder_args)
3844
clips = sampler(NASA_VIDEO.to_tensor())
3945
assert_tensor_equal(len(clips), sampler_args.clips_per_video)
4046
clip = clips[0]

0 commit comments

Comments
 (0)