Revert "remove fp16 support from cpu linalg functions"
This reverts commit de18c28.

Reverted pytorch#75647 on behalf of https://github.com/suo
pytorchmergebot committed Apr 13, 2022
1 parent de18c28 commit 495c5ae
Showing 6 changed files with 39 additions and 26 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Blas.cpp
@@ -165,7 +165,7 @@ Tensor dot(const Tensor &self, const Tensor &other){
return r;
}

- return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
+ return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
Tensor result = at::empty({}, self.options());
result.fill_(dot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
return result;
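
What the re-added Half case in the dot dispatch means at the Python level, as a minimal sketch (assuming a build that includes this revert; not part of the diff):

import torch

# CPU float16 dot product now dispatches through the re-added Half case.
a = torch.ones(8, dtype=torch.float16)
b = torch.ones(8, dtype=torch.float16)
print(torch.dot(a, b))  # expected: tensor(8., dtype=torch.float16)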
21 changes: 18 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1223,7 +1223,7 @@ static void addmm_impl_cpu_(
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());

// Apply BLAS routine
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
at::native::cpublas::gemm(
@@ -1428,6 +1428,20 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
// is_bmm_out: true for bmm_out, false for baddbmm_
// self_or_result is "self" for baddbmm_ and "result" for bmm_out
Tensor& self_or_result = const_cast<Tensor&>(self_or_result_);
+ CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm");
+
+ auto checkOnCPU = [](const Tensor& t, CheckedFrom c) {
+   TORCH_CHECK(
+       !t.is_cuda(),
+       "Expect tensor to have CPU backend, but got tensor with ",
+       toString(t.options().backend()),
+       " Backend (while checking arguments for ",
+       c);
+ };
+
+ checkOnCPU(self_or_result, c);
+ checkOnCPU(batch1, c);
+ checkOnCPU(batch2, c);

const auto batch1_sizes = batch1.sizes();
const auto batch2_sizes = batch2.sizes();
@@ -1464,15 +1478,16 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens

if (contraction_size * res_rows * res_cols < 400) {
if (is_bmm_out) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "bmm", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] {
baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
});
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
});
}
} else if (at::hasMKL() && ((
+ self_or_result.scalar_type() != kHalf &&
self_or_result.scalar_type() != kBFloat16 &&
at::native::is_floating_point(self_or_result)) ||
at::native::is_complex(self_or_result))
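
A minimal sketch of the CPU paths touched above (assuming this revert is applied; not part of the diff): a small half-precision batch goes through the baddbmm_cpu_kernel dispatch re-enabled for kHalf, while the added kHalf check keeps half inputs off the MKL branch.

import torch

# contraction_size * res_rows * res_cols = 4 * 3 * 5 = 60 < 400, so this takes
# the small-problem kernel whose Half dispatch is restored above.
b1 = torch.rand(2, 3, 4).half()
b2 = torch.rand(2, 4, 5).half()
out = torch.bmm(b1, b2)
base = torch.zeros(2, 3, 5, dtype=torch.float16)
res = torch.baddbmm(base, b1, b2, beta=1.0, alpha=1.0)
print(out.dtype, res.shape)  # torch.float16 torch.Size([2, 3, 5])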
4 changes: 1 addition & 3 deletions aten/src/ATen/test/basic.cpp
@@ -41,9 +41,7 @@ void TestOnesAndDot(DeprecatedTypeProperties& type) {
Tensor b = ones({3, 4}, type);
ASSERT_EQ_RESOLVED((b + b).sum().item<double>(), 24);
ASSERT_EQ_RESOLVED(b.numel(), 12);
- if (type.backend() != Backend::CPU || type.scalarType() != kHalf) {
-   ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
- }
+ ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
}

void TestSort(DeprecatedTypeProperties& type) {
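
The now-unconditional C++ assertion has a direct Python analog; a rough sketch, not taken from the test suite:

import torch

# A (3, 4) tensor of ones, flattened and dotted with itself, sums twelve ones;
# with this revert the same holds for CPU half tensors.
b = torch.ones(3, 4, dtype=torch.float16)
assert b.view(-1).dot(b.view(-1)).item() == 12.0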
9 changes: 5 additions & 4 deletions test/test_linalg.py
@@ -5831,7 +5831,7 @@ def maybe_transpose(cond, m):
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_and_complex_types_and(
*[torch.bfloat16] if TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) else []))
- @dtypes(*floating_and_complex_types_and(torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addmm(self, device, dtype):
self._test_addmm_impl(torch.addmm, None, device, dtype)
@@ -6043,8 +6043,9 @@ def test_strided_mm_bmm(self, device, dtype):
self.compare_with_numpy(torch_fn, np_fn, sx[0])

@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
+ @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
@tf32_on_and_off(0.05)
def test_bmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6156,7 +6157,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):

@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6229,7 +6230,7 @@ def generate_tensor():

@precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
@tf32_on_and_off(0.05)
def test_baddbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
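
The precisionOverride entries above allow a loose 0.05 tolerance for half and bfloat16. A standalone check in the same spirit (a sketch, not taken from the test suite) compares a CPU half matmul against a float32 reference:

import torch

# Low-precision CPU result vs. float32 reference, using the same 0.05 tolerance
# the tests above apply to half/bfloat16.
a = torch.rand(16, 16)
b = torch.rand(16, 16)
ref = a @ b
low = (a.half() @ b.half()).float()
torch.testing.assert_close(low, ref, atol=0.05, rtol=0.05)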
1 change: 0 additions & 1 deletion test/test_ops.py
@@ -110,7 +110,6 @@ def unsupported(dtype):
# NOTE: some ops will fail in forward if their inputs
# require grad but they don't support computing the gradient
# in that type! This is a bug in the op!
- print("dtype", dtype, e)
unsupported(dtype)

# Short-circuits testing this dtype -- it doesn't work
28 changes: 14 additions & 14 deletions torch/testing/_internal/common_methods_invocations.py
@@ -8695,7 +8695,7 @@ def ref_pairwise_distance(input1, input2):
# This addmm OpInfo is for when alpha and beta are not both equal to 1.
# alpha=beta=1 is tested in the following opinfo, because that special case will
# trigger addmm being decomposed by a jit pass.
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
assert_autodiffed=True,
@@ -8707,7 +8707,7 @@ def ref_pairwise_distance(input1, input2):
OpInfo('addmm',
# When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
variant_test_name='decomposed',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -8736,7 +8736,7 @@ def ref_pairwise_distance(input1, input2):
ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
np.multiply(np.asarray(alpha, dtype=batch1.dtype),
np.sum(np.matmul(batch1, batch2), axis=0))),
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8759,7 +8759,7 @@ def ref_pairwise_distance(input1, input2):
),
sample_inputs_func=sample_inputs_addbmm),
OpInfo('baddbmm',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
*[torch.bfloat16] if CUDA11OrLater or TEST_WITH_ROCM else []),
backward_dtypesIfCUDA=floating_types_and(torch.float16,
@@ -8776,7 +8776,7 @@ def ref_pairwise_distance(input1, input2):
'TestMathBits', 'test_conj_view', device_type='cuda')],
sample_inputs_func=sample_inputs_baddbmm),
OpInfo('dot',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -8785,15 +8785,15 @@ def ref_pairwise_distance(input1, input2):
supports_fwgrad_bwgrad=True,
),
OpInfo('vdot',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
sample_inputs_func=sample_inputs_dot_vdot,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
OpInfo('bmm',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM)else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
@@ -9330,7 +9330,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
)),
OpInfo('cov',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
@@ -10312,7 +10312,7 @@ def ref_pairwise_distance(input1, input2):
OpInfo('linalg.multi_dot',
# Need this lambda because gradcheck does not work with TensorList inputs
aten_name='linalg_multi_dot',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
supports_inplace_autograd=False,
@@ -11842,7 +11842,7 @@ def ref_pairwise_distance(input1, input2):
aten_name='linear',
supports_autograd=True,
sample_inputs_func=sample_inputs_linear,
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
@@ -12387,7 +12387,7 @@ def ref_pairwise_distance(input1, input2):
supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::relu6"]),
OpInfo('mm',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -13431,7 +13431,7 @@ def ref_pairwise_distance(input1, input2):
# we need this lambda because SampleInput expects tensor input as the first argument
# TODO(@heitorschueroff) update SampleInput to handle such cases
op=lambda tensors, equation: torch.einsum(equation, tensors),
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
@@ -14867,7 +14867,7 @@ def ref_pairwise_distance(input1, input2):
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_kron),
OpInfo('inner',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -14876,7 +14876,7 @@ def ref_pairwise_distance(input1, input2):
sample_inputs_func=sample_inputs_inner,
),
OpInfo('tensordot',
- dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
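
As a rough end-to-end check of the widened OpInfo dtypes (a sketch assuming a build with this revert), a few of the ops listed above accept float16 on CPU again:

import torch

# einsum, tensordot and inner route through the CPU mm/bmm/dot kernels whose
# Half dispatch is restored by this revert.
x = torch.rand(3, 4).half()
y = torch.rand(4, 5).half()
print(torch.einsum('ij,jk->ik', x, y).dtype)  # torch.float16
print(torch.tensordot(x, y, dims=1).shape)    # torch.Size([3, 5])
print(torch.inner(x[0], y[:, 0]).dtype)       # torch.float16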
