Commit

Copy: Re-use existing neg and conj kernel implementations (pytorch#68949)

Summary:
Pull Request resolved: pytorch#68949

This reuses the existing `neg_kernel` and `conj_kernel`
implementations for copy, saving some binary size and compile time.
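
As an illustration (not part of the commit), this is the kind of user-level copy that now hits the reused conj path: copying from a lazily conjugated view forces `copy_` to materialize the conjugation. The shape and dtype below are arbitrary.

  // Hypothetical usage sketch: `b` carries the lazy conjugate bit, so
  // out.copy_(b) must materialize conj(a); with this change that dispatches
  // to the existing conj_kernel instead of a duplicated copy lambda.
  at::Tensor a = at::rand({4}, at::kComplexFloat);
  at::Tensor b = a.conj();             // view with the conjugate bit set
  at::Tensor out = at::empty_like(a);
  out.copy_(b);                        // same dtype, differing conj bits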

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D33064390

Pulled By: anjali411

fbshipit-source-id: eb0ee94ed3db44ae828ea078ba616365f97a7ff5
peterbell10 authored and facebook-github-bot committed Jan 4, 2022
1 parent 95a1952 commit f8f96d4
Showing 6 changed files with 100 additions and 105 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Copy.h
@@ -1,10 +1,10 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {

class Tensor;
struct TensorIterator;

namespace native {
130 changes: 56 additions & 74 deletions aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -1,6 +1,4 @@
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/ATen.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
@@ -10,86 +8,70 @@

namespace at {
namespace native {
inline namespace CPU_CAPABILITY {
void neg_kernel(TensorIteratorBase &iter);
void conj_kernel(TensorIteratorBase &iter);
} // namespace CPU_CAPABILITY

namespace {

static void copy_kernel(TensorIterator& iter, bool non_blocking) {
void direct_copy_kernel(TensorIteratorBase &iter) {
// TODO: we don't actually need separate instantiations per dtype;
// we only need a separate instantiation per dtype size. This would
// probably save us a little bit of code size here
// TODO: not sure if optimizer is able to compile two levels of
// conditionals into a single jump table. We should have a
// single jump table here; might be worth just writing out the
// dispatch statement by hand instead of using AT_DISPATCH
ScalarType dtype = iter.dtype(0);
if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
}
}

void neg_conj_kernel(TensorIteratorBase &iter) {
// fused a = b.neg().conj_physical()
AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cpu", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -conj_impl(a); },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg().conj(); });
});
}

void copy_kernel(TensorIterator& iter, bool non_blocking) {
ScalarType dtype = iter.dtype(0);
if (dtype == iter.dtype(1)) {
// TODO: as the majority of these operations can be done treating
// their datatypes as opaque bit patterns, we don't actually need
// separate instantiations per dtype; we only need a separate
// instantiation per dtype size. This would probably save us a
// little bit of code size here
// TODO: not sure if optimizer is able to compile two levels of
// conditionals into a single jump table. We should have a
// single jump table here; might be worth just writing out the
// dispatch statement by hand instead of using AT_DISPATCH
if (iter.tensor(0).is_neg() == iter.tensor(1).is_neg()) {
if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else if (isComplexType(dtype)) {
// This case should never actually happen since currently there's no way to get a complex tensor
// with negative bit.
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.conj(); });
});
}
// This case should never actually happen since currently there's no way to get a complex tensor
// with negative bit.
if (isComplexType(dtype) &&
(iter.tensor(0).is_conj() != iter.tensor(1).is_conj())) {
conj_kernel(iter);
} else {
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vectorized<scalar_t> a) { return a; });
});
direct_copy_kernel(iter);
}
} else {
if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return -a; });
} else if (isComplexType(dtype)) {
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg(); });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -1 * conj_impl(a); },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg().conj(); });
});
}
if (isComplexType(dtype) &&
(iter.tensor(0).is_conj() != iter.tensor(1).is_conj())) {
neg_conj_kernel(iter);
} else {
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -a; },
[=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg(); });
});
neg_kernel(iter);
}
}
} else {
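
For orientation, here is a condensed restatement (not part of the diff) of the same-dtype dispatch the new CPU copy_kernel performs. It is a sketch under the assumption of the kernel declarations shown above; in the real file, direct_copy_kernel and neg_conj_kernel live in an anonymous namespace while neg_kernel and conj_kernel are the forward-declared CPU_CAPABILITY symbols.

  #include <ATen/native/TensorIterator.h>

  namespace at { namespace native {

  // Declarations condensed for this sketch; see CopyKernel.cpp above for the
  // real placement of each kernel.
  void neg_kernel(TensorIteratorBase&);
  void conj_kernel(TensorIteratorBase&);
  void direct_copy_kernel(TensorIteratorBase&);
  void neg_conj_kernel(TensorIteratorBase&);

  // Same-dtype branch of copy_kernel, restated: choose one of four existing
  // kernels based on how the negative and conjugate bits of output (tensor 0)
  // and input (tensor 1) differ.
  void copy_same_dtype_sketch(TensorIteratorBase& iter) {
    const ScalarType dtype = iter.dtype(0);
    const bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
    const bool conj_differs = isComplexType(dtype) &&
        (iter.tensor(0).is_conj() != iter.tensor(1).is_conj());
    if (same_neg) {
      if (conj_differs) {
        conj_kernel(iter);         // materialize the conjugation only
      } else {
        direct_copy_kernel(iter);  // plain element-wise copy
      }
    } else {
      if (conj_differs) {
        neg_conj_kernel(iter);     // fused -conj(x)
      } else {
        neg_kernel(iter);          // materialize the negation only
      }
    }
  }

  }} // namespace at::native
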
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -210,7 +210,7 @@ static void imag_kernel(TensorIteratorBase& iter) {
}

// NB: Ignores the negative bit on tensors
static void conj_kernel(TensorIteratorBase& iter) {
void conj_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
cpu_kernel_vec(
64 changes: 35 additions & 29 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -1,4 +1,4 @@
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CachingHostAllocator.h>
@@ -13,6 +13,31 @@
namespace at {
namespace native {

void neg_kernel_cuda(TensorIteratorBase &iter);
void conj_kernel_cuda(TensorIteratorBase &iter);

namespace {
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType dtype = iter.dtype(0);
if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
}
}

void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cuda", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); });
});
}
} // namespace (anonymous)

using namespace at::cuda;

// device-to-device copy, does type conversion
@@ -64,36 +89,17 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
copy_stream));
}
} else {
auto dtype = iter.dtype(0);
if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
if (same_neg) {
if (!same_conj && same_type) {
conj_kernel_cuda(iter);
} else {
direct_copy_kernel_cuda(iter);
}
} else {
if (same_neg) {
if (!same_conj && same_type) {
AT_DISPATCH_COMPLEX_TYPES(
dtype, "copy_conj_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(x); });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
}
if (!same_conj && same_type) {
neg_conj_kernel_cuda(iter);
} else {
if (!same_conj && same_type) {
AT_DISPATCH_COMPLEX_TYPES(
dtype, "copy_conj_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return std::conj(-x); });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -x; });
});
}
neg_kernel_cuda(iter);
}
}
}
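
On the CUDA side, a hypothetical same-device usage (assuming a CUDA build and an available device) that exercises the reused path: copying from a conjugate view with matching dtypes goes through copy_device_to_device above and now dispatches to conj_kernel_cuda.

  // Hypothetical sketch: same dtype and device, conj bits differ,
  // so the copy resolves the conjugation via conj_kernel_cuda.
  at::Tensor a = at::rand({4}, at::device(at::kCUDA).dtype(at::kComplexFloat));
  at::Tensor out = at::empty_like(a);
  out.copy_(a.conj());
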
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -2,6 +2,7 @@

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
6 changes: 6 additions & 0 deletions aten/src/ATen/native/cuda/TriangularOps.cu
@@ -6,6 +6,12 @@
#include <ATen/NativeFunctions.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/diag.h>
#endif

#include <ATen/cuda/CUDAApplyUtils.cuh>

namespace at {
