Fix type promotion for ldexp (pytorch#133519)
According to the documentation, `ldexp` of a half tensor and an int tensor should return a half tensor, and `ldexp` of a double tensor should not overflow for a 64-bit exponent.

Introduce a `_pow2` helper function that skips the scalar-to-float32 promotion pattern when `self` is a reduced-precision float or a double.

Add regression tests to `test_ldexp` and enable it to run on both CPU and GPU
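
For illustration, a minimal standalone sketch of the two documented behaviors the new tests assert (not part of the diff itself):

```python
import torch

# A half mantissa with an integer exponent should stay half, not promote to float32.
m = torch.randn(4, dtype=torch.half)
e = torch.randint(-5, 5, (4,))
assert torch.ldexp(m, e).dtype == torch.half

# A double mantissa with a large exponent should not overflow:
# 2.0 ** 128 is just past the float32 maximum but comfortably finite in float64.
m = torch.tensor([1.0], dtype=torch.float64)
e = torch.tensor([128], dtype=torch.int64)
assert torch.isfinite(torch.ldexp(m, e)).all()
```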

Fixes pytorch#133267

Pull Request resolved: pytorch#133519
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
malfet authored and pytorchmergebot committed Aug 16, 2024
1 parent 3a904d1 commit 1653f77
Showing 2 changed files with 32 additions and 10 deletions.
15 changes: 13 additions & 2 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -1542,12 +1542,23 @@ TORCH_IMPL_FUNC(heaviside_out) (
   heaviside_stub(device_type(), *this);
 }
 
+static inline Tensor _pow2(const Tensor& self, const Tensor& other) {
+  const auto self_dtype = self.scalar_type();
+  // All integral types are promoted to float32
+  if (isIntegralType(self_dtype, true) || self_dtype == kFloat) {
+    return at::pow(2.0, other);
+  }
+  // For double and reduced floating types do regular type promotion
+  return at::full({}, 2.0, self.options()).pow(other);
+}
+
 Tensor& ldexp_out(const Tensor& self, const Tensor& other, Tensor& result) {
-  return at::mul_out(result, self, at::pow(2.0, other));
+  return at::mul_out(result, self, _pow2(self, other));
 }
 
 Tensor ldexp(const Tensor& self, const Tensor& other) {
-  return at::mul(self, at::pow(2.0, other));
+  return at::mul(self, _pow2(self, other));
 }
 
 Tensor& ldexp_(Tensor& self, const Tensor& other) {
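For context on why `_pow2` is needed: with a Python-scalar base, `pow` follows the scalar promotion pattern and yields a float32 result for integral exponents, whereas a 0-dim tensor base built from `self.options()` goes through regular tensor-tensor promotion. A hedged Python illustration of the same difference (not part of this commit):

```python
import torch

e = torch.arange(4, dtype=torch.int32)

# Scalar base + integral exponent promotes to the default dtype (float32),
# which truncated double mantissas and overflowed for large exponents.
print(torch.pow(2.0, e).dtype)  # torch.float32

# A 0-dim tensor base keeps the mantissa's dtype through normal promotion.
print(torch.full((), 2.0, dtype=torch.float64).pow(e).dtype)  # torch.float64
print(torch.full((), 2.0, dtype=torch.half).pow(e).dtype)     # torch.float16
```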
27 changes: 19 additions & 8 deletions test/test_binary_ufuncs.py
@@ -156,7 +156,6 @@ def _helper_reference_numerics(
             numpy_sample = sample.numpy()
             l_numpy = numpy_sample.input
             r_numpy = numpy_sample.args[0]
-
             actual = op(l, r)
             expected = op.ref(l_numpy, r_numpy)
@@ -3407,29 +3406,41 @@ def test_rpow(self, device):
         assert m.dim() == 0, "m is intentionally a scalar"
         self.assertEqual(torch.pow(2, m), 2**m)
 
-    @onlyCPU
     def test_ldexp(self, device):
         # random values
         mantissas = torch.randn(64, device=device)
         exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)
 
         # basic test
-        np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
+        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
         pt_outcome_1 = torch.ldexp(mantissas, exponents)
         pt_outcome_2 = mantissas.ldexp(exponents)
-        self.assertEqual(np_outcome, pt_outcome_1)
-        self.assertEqual(np_outcome, pt_outcome_2)
+        self.assertEqual(np_outcome, pt_outcome_1.cpu())
+        self.assertEqual(np_outcome, pt_outcome_2.cpu())
         mantissas.ldexp_(exponents)
-        self.assertEqual(np_outcome, mantissas)
+        self.assertEqual(np_outcome, mantissas.cpu())
 
         # test bounds
         mantissas = torch.tensor(
             [float("inf"), float("-inf"), float("inf"), float("nan")], device=device
         )
         exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
-        np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
+        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
         pt_outcome = torch.ldexp(mantissas, exponents)
-        self.assertEqual(np_outcome, pt_outcome)
+        self.assertEqual(np_outcome, pt_outcome.cpu())
+
+        # test half dtype behavior
+        mantissas = torch.randn(64, device=device, dtype=torch.half)
+        exponents = torch.randint(-5, 5, (64,), device=device)
+        self.assertEqual(torch.ldexp(mantissas, exponents).dtype, torch.half)
+
+        # test float64 computation
+        mantissas = torch.tensor([1], dtype=torch.float64, device=device)
+        exponents = torch.tensor([128], dtype=torch.int64, device=device)
+        expected = torch.pow(
+            torch.full((1,), 2, device=device, dtype=torch.float64), 128
+        )
+        self.assertEqual(torch.ldexp(mantissas, exponents), expected)
 
     @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_lerp(self, device, dtype):
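As a quick cross-check of the float64 path outside the test suite, a small sketch (assuming a stock torch build, with `math.ldexp` as the scalar reference):

```python
import math
import torch

m = torch.tensor([1.0], dtype=torch.float64)
e = torch.tensor([128], dtype=torch.int64)

# math.ldexp(1.0, 128) == 2.0 ** 128 exactly; powers of two are exact in float64.
assert torch.ldexp(m, e).item() == math.ldexp(1.0, 128)
```

With the `@onlyCPU` decorator removed, the updated test runs on either backend, e.g. via `python test/test_binary_ufuncs.py -k test_ldexp` (assuming the usual PyTorch test harness).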
