Fix half_tensor.bernoulli_(double) (pytorch#13474)
Summary:
Fixes pytorch#12431
Pull Request resolved: pytorch#13474

Differential Revision: D12897834

Pulled By: SsnL

fbshipit-source-id: 598250fd7b9f1d2509ec0e5012724d7895a62daf
ssnl authored and facebook-github-bot committed Nov 2, 2018
1 parent 61a2d47 commit 2f82a06
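
This commit widens the CUDA dispatch for the scalar-probability bernoulli_ overload so it covers half tensors. A minimal Python sketch of the call the commit fixes, assuming a CUDA-capable build (the exact pre-fix error text varies by version):

    import torch

    # Scalar-p bernoulli_ on a half-precision CUDA tensor. This routes
    # through bernoulli_scalar_cuda_, which previously failed at dispatch
    # with a not-implemented error for Half.
    x = torch.empty(10, dtype=torch.half, device='cuda')
    x.bernoulli_(0.5)

    # Every element should now be exactly 0 or 1.
    assert ((x == 0) | (x == 1)).all()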
Showing 3 changed files with 15 additions and 7 deletions.
aten/src/ATen/native/cuda/Distributions.cu (1 addition, 1 deletion)

@@ -221,7 +221,7 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
 
 Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
   AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_scalar_cuda_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_scalar_cuda_", [&] {
     auto seeds = next_philox_seed(gen, 10);
     bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
   });
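
The tensor-probability overload, bernoulli_tensor_cuda_ (visible in the hunk header above), is untouched; the bug was specific to the scalar-probability path. A hedged sketch distinguishing the two call paths from Python (tensor names are illustrative):

    import torch

    x = torch.empty(8, dtype=torch.half, device='cuda')

    # Tensor-probability overload: routed to bernoulli_tensor_cuda_.
    p = torch.full((8,), 0.5, device='cuda')
    x.bernoulli_(p)

    # Scalar-probability overload: routed to bernoulli_scalar_cuda_,
    # the path whose dispatch this commit widens to include Half.
    x.bernoulli_(0.5)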
test/test_cuda.py (9 additions, 2 deletions)

@@ -1244,8 +1244,15 @@ def test_cat_empty(self):
         _TestTorchMixin._test_cat_empty(self, use_cuda=True)
 
     def test_bernoulli(self):
-        _TestTorchMixin._test_bernoulli(self, torch.double, 'cuda')
-        _TestTorchMixin._test_bernoulli(self, torch.half, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float64, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float16, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.float16, torch.float64, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.float16, torch.float16, 'cuda')
+        # test that it works with integral tensors
+        _TestTorchMixin._test_bernoulli(self, torch.uint8, torch.float64, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.uint8, torch.float16, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.int64, torch.float64, 'cuda')
+        _TestTorchMixin._test_bernoulli(self, torch.int64, torch.float16, 'cuda')
 
     def test_cat_bad_input_sizes(self):
         x = torch.randn(2, 1).cuda()
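
_test_bernoulli now takes the dtype of the sampled tensor (t_dtype) and the dtype of the probabilities (p_dtype) separately, so the CUDA test can sweep a small matrix of combinations, including half outputs against double probabilities. A sketch of the core of what one (t_dtype, p_dtype) pair exercises, mirroring the helper body in the test/test_torch.py hunk below (sizes and names are illustrative, and it assumes mixed t/p dtypes are supported on this path, as the new test matrix suggests):

    import torch

    def check_bernoulli(t_dtype, p_dtype, device):
        # Output tensor of the dtype under test; 2 is deliberately
        # out of range so a no-op would be caught.
        t = torch.empty(10, 10, dtype=t_dtype, device=device)
        t.fill_(2)

        t.bernoulli_(0.5)  # scalar-probability path
        assert ((t == 0) | (t == 1)).all()

        p = torch.rand(10, 10, dtype=p_dtype, device=device)
        t.bernoulli_(p)    # tensor-probability path
        assert ((t == 0) | (t == 1)).all()

    # Requires a CUDA device.
    check_bernoulli(torch.float16, torch.float64, 'cuda')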
test/test_torch.py (5 additions, 4 deletions)

@@ -7489,7 +7489,7 @@ def test_norm_fastpaths(self):
         self.assertEqual(result, expected)
 
     @staticmethod
-    def _test_bernoulli(self, p_dtype, device):
+    def _test_bernoulli(self, t_dtype, p_dtype, device):
         for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
             x = torch.tensor(trivial_p, dtype=p_dtype, device=device)
             self.assertEqual(x.bernoulli().tolist(), trivial_p)
@@ -7511,8 +7511,7 @@ def isBinary(t):
         torch.bernoulli(torch.rand_like(p), out=p)
         self.assertTrue(isBinary(p))
 
-        # test that it works with integral tensors
-        t = torch.empty(10, 10, dtype=torch.uint8, device=device)
+        t = torch.empty(10, 10, dtype=t_dtype, device=device)
 
         t.fill_(2)
         t.bernoulli_(0.5)
@@ -7532,7 +7531,9 @@ def isBinary(t):
         self.assertTrue(isBinary(t))
 
     def test_bernoulli(self):
-        self._test_bernoulli(self, torch.double, 'cpu')
+        self._test_bernoulli(self, torch.float32, torch.float64, 'cpu')
+        # test that it works with integral tensors
+        self._test_bernoulli(self, torch.uint8, torch.float64, 'cpu')
 
     def test_normal(self):
         q = torch.Tensor(100, 100)
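
The isBinary helper named in the hunk headers is defined inside _test_bernoulli and is not part of this diff; a plausible minimal implementation consistent with how the assertions use it (an assumption, not the repository's exact code):

    def isBinary(t):
        # True iff every element is exactly 0 or 1.
        return ((t == 0) | (t == 1)).all()

Note that _test_bernoulli is a @staticmethod that takes self explicitly, which is why test_cuda.py can invoke it through _TestTorchMixin with its own test instance.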
