Add HalfToFloat + FloatToHalf operators to PyTorch (pytorch#45092)
Summary:
Pull Request resolved: pytorch#45092

Adding two operators
1. at::float_to_half -> Converts FP32 tensor to FP16 tensor
2. at::half_to_float -> Converts FP16 tensor to FP32 tensor.

These operators internally use the conversion kernels provided by FBGEMM; both Caffe2 (C2) and PyTorch (PT) use the same FBGEMM kernel underneath.
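For reference, these are the ordinary FP32->FP16 and FP16->FP32 tensor conversions; a minimal, illustrative sketch using the public `Tensor.to` path (the same path the benchmark below exercises):

```
import torch

# Minimal sketch of the conversions in question, via the public tensor API.
# A contiguous FP32 <-> FP16 copy on CPU is the case the FBGEMM kernels cover.
x = torch.randn(512, 512, dtype=torch.float32)  # contiguous FP32 tensor on CPU
h = x.to(torch.half)                            # FP32 -> FP16 conversion
y = h.to(torch.float)                           # FP16 -> FP32 conversion
print((x - y).abs().max())                      # round-trip error, bounded by FP16 precision
```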

Test Plan:
buck test //caffe2/test:torch -- .*test_half_tensor.*

Run benchmark locally using

```
buck run //caffe2/benchmarks/operator_benchmark/pt:tensor_to_test
```

AI Bench results are pending; I do not expect them to finish soon, since we have a large queue with jobs pending for 2+ days.

Benchmark for a 512x512 tensor with the FBGEMM implementation:

```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 1246.332

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 1734.304
```

Benchmark for a 512x512 tensor on trunk, with no FBGEMM integration:

```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 169045.724

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 152382.494
```
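For this shape, the FBGEMM path is faster by over two orders of magnitude: about 136x for FP32->FP16 (1246 us vs. 169046 us) and about 88x for FP16->FP32 (1734 us vs. 152382 us).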

Reviewed By: ngimel

Differential Revision: D23824869

fbshipit-source-id: ef044459b6c8c6e5ddded72080204c6a0ab4582c
Radhakrishnan Venkataramani authored and facebook-github-bot committed Nov 10, 2020
1 parent 497cd25 commit 163adb9
Showing 3 changed files with 100 additions and 14 deletions.
30 changes: 30 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -13,6 +13,11 @@
#include <ATen/NamedTensorUtils.h>
#include <torch/library.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmConvert.h>
#endif

namespace {

using namespace at;
@@ -94,6 +99,31 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
  TORCH_CHECK(self.defined(), "self is undefined");
  TORCH_CHECK(src.defined(), "src is undefined");

  // FBGEMM kernel support exists only for the following cases:
  // 1. Both the source and destination tensors are contiguous (or are
  //    non-overlapping and dense with matching strides).
  // 2. Both the source and destination tensors are on the CPU device.
  // 3. The dtype conversion is FP32->FP16 or FP16->FP32.
#ifdef USE_FBGEMM
  if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) ||
       (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) &&
      (self.device().is_cpu() && src.device().is_cpu()) &&
      !self.is_sparse() && !src.is_sparse() &&
      ((self.is_contiguous() && src.is_contiguous()) ||
       (self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) {
    if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
      auto* output_ptr = reinterpret_cast<fbgemm::float16*>(
          self.data_ptr<at::Half>());
      fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
    } else {
      auto in_data = reinterpret_cast<fbgemm::float16*>(
          src.data_ptr<at::Half>());
      auto* output_ptr = self.data_ptr<float>();
      fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
    }
    return self;
  }
#endif

  if (self.is_sparse() && src.is_sparse()) {
    return at::copy_sparse_to_sparse_(self, src, non_blocking);
  } else if (self.is_sparse() || src.is_sparse()) {
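To make the eligibility conditions above concrete, here is a small illustrative sketch (not part of the diff; the kernel routing itself happens inside `copy_impl`) of inputs that do and do not satisfy them:

```
import torch

# Illustrative only: which copies satisfy the conditions listed in the comment
# above (CPU device, FP32 <-> FP16, contiguous or dense with matching strides).
src = torch.randn(512, 512)                    # contiguous FP32 tensor on CPU
dst = torch.empty(512, 512, dtype=torch.half)  # contiguous FP16 destination
dst.copy_(src)                                 # FP32 -> FP16: satisfies the conditions

sliced = src[:, ::2]                           # non-contiguous, non-dense view
sliced.to(torch.half)                          # contiguity condition not met; generic copy path expected
```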
39 changes: 39 additions & 0 deletions benchmarks/operator_benchmark/pt/tensor_to_test.py
@@ -0,0 +1,39 @@
import operator_benchmark as op_bench
import torch

tensor_conversion_short_configs = op_bench.cross_product_configs(
    M=(8, 16, 32,),
    N=(16, 64, 128,),
    device=['cpu', 'cuda'],
    tags=['short'],
)

tensor_conversion_long_configs = op_bench.cross_product_configs(
    M=(64, 128, 256, 512,),
    N=(256, 512, 1024, 2048,),
    device=['cpu', 'cuda'],
    tags=['long'],
)

class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
        self.input = torch.rand(M, N, device=device, requires_grad=False, dtype=torch.float)

    def forward(self):
        return self.input.to(torch.half)

class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
        self.input = torch.rand(M, N, device=device, requires_grad=False, dtype=torch.half)

    def forward(self):
        return self.input.to(torch.float)


op_bench.generate_pt_test(tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark)
op_bench.generate_pt_test(tensor_conversion_long_configs, FloatToHalfTensorConversionBenchmark)
op_bench.generate_pt_test(tensor_conversion_short_configs, HalfToFloatTensorConversionBenchmark)
op_bench.generate_pt_test(tensor_conversion_long_configs, HalfToFloatTensorConversionBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
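Each (M, N, device) combination expands into its own benchmark entry, named in the `FloatToHalfTensorConversionBenchmark_M512_N512_cpu` style that appears in the benchmark output above.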
45 changes: 31 additions & 14 deletions test/test_torch.py
@@ -2951,20 +2951,34 @@ def test_parsing_intlist(self):
                          lambda: torch.tensor().new_zeros((5, 5), 0))

    def test_half_tensor(self):
        x = torch.randn(5, 5).float()
        y = torch.randn(5, 5).float()
        xh, yh = x.half(), y.half()

        self.assertEqual(x.half().float(), x, atol=1e-3, rtol=0)
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        z = torch.Tensor(5, 5)
        self.assertEqual(z.copy_(xh), x, atol=1e-3, rtol=0)
        # contiguous tensor
        # non-contiguous tensor
        # dense non-overlapping tensor
        # non-dense non-overlapping sliced tensor
        # non-dense overlapping equal strides
        for device in devices:
            tset = (
                torch.randn(4, 3, 2, device=device, dtype=torch.float).contiguous(),
                torch.randn(4, 3, 2, device=device, dtype=torch.float).transpose(0, 1),
                torch.randn(4, 3, 2, device=device, dtype=torch.float),
                torch.randn(4, 3, 2, device=device, dtype=torch.float)[:, :, ::2],
                torch.empty_strided(
                    (4, 2, 3), (10, 3, 3), device=device, dtype=torch.float
                ).copy_(torch.rand((4, 2, 3), dtype=torch.float, device=device)),
            )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(xh, f)
            f.seek(0)
            xh2 = torch.load(f)
            self.assertEqual(xh.float(), xh2.float())
            for x in tset:
                self.assertEqual(x.half().float(), x, atol=1e-3, rtol=0)
                xh = x.half()
                with tempfile.NamedTemporaryFile() as f:
                    torch.save(xh, f)
                    f.seek(0)
                    xh2 = torch.load(f)
                    self.assertEqual(xh.float(), xh2.float())

    def test_from_buffer(self):
        a = bytearray([1, 2, 3, 4])
@@ -17347,8 +17361,11 @@ def _test_copysign_numpy(a, b):
            # Use double copysign to verify the correctness of 0.0 and -0.0, since
            # 0.0 == -0.0 is always True for self.assertEqual. So, we use 1 as the
            # magnitude to verify the sign between torch and numpy results, elementwise.
            self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
                             torch.copysign(torch.tensor(1.0), expected))
            # Special case: NaN conversion between FP32 and FP16 is not bitwise
            # equivalent, so it would not pass this assertion.
            if a.dtype != torch.float16 and b.dtype != torch.float16:
                self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
                                 torch.copysign(torch.tensor(1.0), expected))

        # Compare Result with NumPy
        # Type promotion
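As a side note on the skipped assertion above, here is a small illustrative probe (not part of the diff) of whether a NaN's sign survives the FP32 -> FP16 -> FP32 round trip; the test skips its sign check for float16 precisely because this is not guaranteed:

```
import numpy as np
import torch

# Illustrative probe: does copysign see the same sign for a NaN before and
# after an FP32 -> FP16 -> FP32 round trip?
bits = np.array([0x7FC00000, 0xFFC00000], dtype=np.uint32)  # +NaN and -NaN bit patterns
nans = torch.from_numpy(bits.view(np.float32))
ones = torch.ones_like(nans)
print(torch.copysign(ones, nans))                 # signs taken from the original NaNs
print(torch.copysign(ones, nans.half().float()))  # signs after the half round trip
```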
