Commit cddabb2

ssnl authored and soumith committed
Use cuda device object instead of .cuda (pytorch#231)
1 parent 68a1f59 commit cddabb2

1 file changed (+9 −7)

advanced_source/cpp_extension.rst (+9 −7)
@@ -519,23 +519,25 @@ can *also* run on GPU, and individual operations will correspondingly dispatch
 to GPU-optimized implementations. For certain operations like matrix multiply
 (like ``mm`` or ``addmm``), this is a big win. Let's take a look at how much
 performance we gain from running our C++ code with CUDA tensors. No changes to
-our implementation are required, we simply need to move our tensors to GPU
-memory with ``.cuda()`` from Python::
+our implementation are required; we simply need to put our tensors in GPU
+memory from Python, either by adding a ``device=cuda_device`` argument at
+creation time or by using ``.to(cuda_device)`` after creation::
 
   import torch
 
   assert torch.cuda.is_available()
+  cuda_device = torch.device("cuda")  # device object representing GPU
 
   batch_size = 16
   input_features = 32
   state_size = 128
 
-  # Note the .cuda() calls here
-  X = torch.randn(batch_size, input_features).cuda()
-  h = torch.randn(batch_size, state_size).cuda()
-  C = torch.randn(batch_size, state_size).cuda()
+  # Note the device=cuda_device arguments here
+  X = torch.randn(batch_size, input_features, device=cuda_device)
+  h = torch.randn(batch_size, state_size, device=cuda_device)
+  C = torch.randn(batch_size, state_size, device=cuda_device)
 
-  rnn = LLTM(input_features, state_size).cuda()
+  rnn = LLTM(input_features, state_size).to(cuda_device)
 
   forward = 0
   backward = 0
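
For context, here is a minimal, self-contained sketch (not part of the commit) of the device-object pattern the diff adopts. The shapes mirror the tutorial's ``batch_size``, ``input_features``, and ``state_size`` values; the tutorial's ``LLTM`` module is omitted since it is defined elsewhere in the file.

    import torch

    # Sketch of the pattern the commit adopts (illustrative, not taken
    # verbatim from the tutorial): create one torch.device and reuse it.
    assert torch.cuda.is_available()
    cuda_device = torch.device("cuda")

    # Placement at creation time via the device= keyword argument:
    X = torch.randn(16, 32, device=cuda_device)   # batch_size x input_features

    # Placement after creation via .to(); the same call also moves
    # nn.Module parameters, which is how the diff moves the LLTM module:
    h = torch.randn(16, 128).to(cuda_device)      # batch_size x state_size

    print(X.device, h.device)  # both report the CUDA device, e.g. cuda:0

One practical difference between the two forms: ``torch.randn(..., device=cuda_device)`` allocates the tensor directly on the GPU, whereas ``torch.randn(...).cuda()`` first allocates it on the CPU and then copies it over.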
