[ATen/CPU] Parallelize HalfToFloat + FloatToHalf operators in PT (pytorch#47777)

Summary:
Pull Request resolved: pytorch#47777

Parallelize the FP32 <-> FP16 conversion ops.
- Use at::parallel_for in ATen instead of parallelizing inside FBGEMM;
- this provides more flexibility, since at::parallel_for follows whichever parallel backend ATen is configured with (e.g. OpenMP or TBB); see the sketch below.
ghstack-source-id: 116499687
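
As a reference for the pattern (not part of the diff itself), here is a minimal sketch of how at::parallel_for chunks the conversion. The fbgemm SIMD kernels used in the actual change are replaced by a scalar loop so the sketch stays self-contained; the function name is illustrative only.
```
// Minimal sketch of the at::parallel_for chunking pattern used by this change.
// Each task converts its own contiguous chunk [begin, end); the real diff calls
// fbgemm::FloatToFloat16_simd on each chunk instead of this scalar loop.
#include <ATen/ATen.h>
#include <ATen/Parallel.h>

void float_to_half_parallel(const float* src, at::Half* dst, int64_t numel) {
  at::parallel_for(
      0, numel, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
          dst[i] = static_cast<at::Half>(src[i]);  // per-element conversion
        }
      });
}
```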

Test Plan:
```
OMP_NUM_THREADS=10 buck test //caffe2/test:torch -- .test_half_tensor.
```
https://our.intern.facebook.com/intern/testinfra/testrun/7036874441928985

```
OMP_NUM_THREADS=10 buck run mode/opt -c pytorch.parallel_backend=tbb //caffe2/benchmarks/operator_benchmark/pt:tensor_to_test -- --iterations 1 --omp_num_threads 10 --warmup_iterations 0
```
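
The 1- vs 2-thread comparison below can also be approximated in-process. A hedged C++ sketch (the timing loop, tensor shape, and output formatting are mine, not part of the benchmark harness; depending on the parallel backend, the thread count may need to be set before the pool is first used):
```
// Hedged sketch: set ATen's intra-op thread count, then trigger the
// float -> half copy that goes through the parallelized fast path.
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <chrono>
#include <cstdio>

int main() {
  at::Tensor src = at::rand({512, 512});   // float32, contiguous, CPU
  for (int threads : {1, 2}) {
    at::set_num_threads(threads);          // intra-op pool (OpenMP/TBB/native)
    auto t0 = std::chrono::steady_clock::now();
    at::Tensor dst = src.to(at::kHalf);    // float -> half conversion
    auto t1 = std::chrono::steady_clock::now();
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
    std::printf("threads=%d float->half: %lld us\n",
                at::get_num_threads(), (long long)us);
  }
  return 0;
}
```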

Benchmark results for 512 x 512 Tensor copy:

- With 1 thread:
```
(base) [jianyuhuang@devbig281.ftw3.facebook.com: ~/fbsource/fbcode/caffe2/caffe2/operators] $ buck run mode/opt -c pytorch.parallel_backend=tbb //caffe2/benchmarks/operator_benchmark/pt:tensor_to_test -- --iterations 1 --omp_num_threads 1 --warmup_iterations 10
Parsing buck files: finished in 1.3 sec
Building: finished in 5.7 sec (100%) 6087/6087 jobs, 0 updated
  Total time: 7.0 sec
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 99.279

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 81.707
```

- With 2 threads:
```
(base) [jianyuhuang@devbig281.ftw3.facebook.com: ~/fbsource/fbcode/caffe2/caffe2/operators] $ buck run mode/opt -c pytorch.parallel_backend=tbb //caffe2/benchmarks/operator_benchmark/pt:tensor_to_test -- --iterations 1 --omp_num_threads 2 --warmup_iterations 10
Parsing buck files: finished in 1.3 sec
Building: finished in 4.4 sec (100%) 6087/6087 jobs, 0 updated
  Total time: 5.7 sec
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: FloatToHalfTensorConversionBenchmark
# Mode: Eager
# Name: FloatToHalfTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 68.162

# Benchmarking PyTorch: HalfToFloatTensorConversionBenchmark
# Mode: Eager
# Name: HalfToFloatTensorConversionBenchmark_M512_N512_cpu
# Input: M: 512, N: 512, device: cpu
Forward Execution Time (us) : 49.245
```
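
For this 512 x 512 case, going from 1 to 2 threads cuts FloatToHalf from 99.279 us to 68.162 us (~1.46x) and HalfToFloat from 81.707 us to 49.245 us (~1.66x).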

Reviewed By: ngimel

Differential Revision: D24676355

fbshipit-source-id: 02bfb893a7b5a60f97c0559d8974c53837755ac2
jianyuh authored and facebook-github-bot committed Nov 15, 2020
1 parent f824854 commit 0e98fdd
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions aten/src/ATen/native/Copy.cpp
@@ -11,6 +11,7 @@
 #include <ATen/metal/Context.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/NamedTensorUtils.h>
+#include <ATen/Parallel.h>
 #include <torch/library.h>

 #ifdef USE_FBGEMM
@@ -111,14 +112,30 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
       ((self.is_contiguous() && src.is_contiguous()) ||
        (self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) {
     if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
-      auto* output_ptr = reinterpret_cast<fbgemm::float16*>(
-          self.data_ptr<at::Half>());
-      fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
+      auto* output_ptr =
+          reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
+      at::parallel_for(
+          0,
+          self.numel(),
+          at::internal::GRAIN_SIZE,
+          [&](int64_t begin, int64_t end) {
+            fbgemm::FloatToFloat16_simd(
+                src.data_ptr<float>() + begin,
+                output_ptr + begin,
+                end - begin);
+          });
     } else {
       auto in_data = reinterpret_cast<fbgemm::float16*>(
           src.data_ptr<at::Half>());
       auto* output_ptr = self.data_ptr<float>();
-      fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
+      at::parallel_for(
+          0,
+          self.numel(),
+          at::internal::GRAIN_SIZE,
+          [&](int64_t begin, int64_t end) {
+            fbgemm::Float16ToFloat_simd(
+                in_data + begin, output_ptr + begin, end - begin);
+          });
     }
     return self;
   }
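
A hedged usage sketch (not part of the change): the parallelized FBGEMM path above is reached only for contiguous (or non-overlapping-and-dense, stride-matching) CPU float <-> half copies when PyTorch is built with FBGEMM; other layouts fall back to the generic copy kernel. The shapes below are illustrative.
```
// Hedged usage sketch of which copies hit the patched fast path.
#include <ATen/ATen.h>

int main() {
  at::Tensor src = at::rand({512, 512});                                   // float32, contiguous, CPU
  at::Tensor dst = at::empty({512, 512}, src.options().dtype(at::kHalf));  // half, contiguous
  dst.copy_(src);                        // contiguous float -> half: FBGEMM fast path above
  at::Tensor col = src.select(1, 0);     // non-contiguous column view
  at::Tensor col_h = col.to(at::kHalf);  // falls back to the generic copy kernel
  return 0;
}
```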
