forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HalfToFloat + FloatToHalf operators to PyTorch (pytorch#45092)
Summary:
Pull Request resolved: pytorch#45092

Adding two operators:
1. at::float_to_half -> converts an FP32 tensor to an FP16 tensor.
2. at::half_to_float -> converts an FP16 tensor to an FP32 tensor.

These operators internally use the kernel provided by FBGEMM. Both C2 and PT will use the same FBGEMM kernel underneath.

Test Plan:
buck test //caffe2/test:torch -- .*test_half_tensor.*

Run the benchmark locally using
```
buck run //caffe2/benchmarks/operator_benchmark/pt:tensor_to_test
```

AI Bench results are pending. I do not expect them to finish soon, as there is a large queue with jobs that have been pending for 2+ days.

Benchmark for a 512x512 tensor with the FBGEMM implementation:
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 1246.332

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 1734.304
```

Benchmark for a 512x512 tensor on trunk with no FBGEMM integration:
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 169045.724

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 152382.494
```

Reviewed By: ngimel
Differential Revision: D23824869
fbshipit-source-id: ef044459b6c8c6e5ddded72080204c6a0ab4582c
- Loading branch information
1 parent
497cd25
commit 163adb9
Showing
3 changed files
with
100 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import operator_benchmark as op_bench | ||
import torch | ||
|
||
# Parameter grid tagged 'short': small M/N values on both the CPU and CUDA
# devices, handed to op_bench.cross_product_configs.
tensor_conversion_short_configs = op_bench.cross_product_configs(
    M=(8, 16, 32),
    N=(16, 64, 128),
    device=["cpu", "cuda"],
    tags=["short"],
)
|
||
# Parameter grid tagged 'long': larger M/N values on both the CPU and CUDA
# devices, handed to op_bench.cross_product_configs.
tensor_conversion_long_configs = op_bench.cross_product_configs(
    M=(64, 128, 256, 512),
    N=(256, 512, 1024, 2048),
    device=["cpu", "cuda"],
    tags=["long"],
)
|
||
class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that converts a ``torch.float`` tensor to ``torch.half``."""

    def init(self, M, N, device):
        """Create the random ``M`` x ``N`` float source tensor on ``device``."""
        source = torch.rand(
            M,
            N,
            dtype=torch.float,
            device=device,
            requires_grad=False,
        )
        self.input = source

    def forward(self):
        """Return ``self.input`` converted to ``torch.half`` via ``Tensor.to``."""
        converted = self.input.to(torch.half)
        return converted
|
||
class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that converts a ``torch.half`` tensor to ``torch.float``."""

    def init(self, M, N, device):
        """Create the random ``M`` x ``N`` half source tensor on ``device``."""
        source = torch.rand(
            M,
            N,
            dtype=torch.half,
            device=device,
            requires_grad=False,
        )
        self.input = source

    def forward(self):
        """Return ``self.input`` converted to ``torch.float`` via ``Tensor.to``."""
        converted = self.input.to(torch.float)
        return converted
|
||
|
||
# Register every benchmark class with the short grid first and then the long
# grid, in the same order as four explicit generate_pt_test calls would.
for _benchmark_cls in (
    FloatToHalfTensorConversionBenchmark,
    HalfToFloatTensorConversionBenchmark,
):
    for _configs in (
        tensor_conversion_short_configs,
        tensor_conversion_long_configs,
    ):
        op_bench.generate_pt_test(_configs, _benchmark_cls)
|
||
# Start the operator benchmark runner only when executed as a script.
if __name__ == '__main__':
    op_bench.benchmark_runner.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters