
Commit 96c8fdd

Fix/enable sparse tests on ROCm
Fix and enable the sparse fp16/bf16 tests for ROCm 7.0/7.1.
1 parent f1ad49a commit 96c8fdd
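
As a quick illustration of what this enables, the snippet below is a minimal smoke test (not part of the commit): it exercises sparse-sparse matmul in half precision and assumes a ROCm 7.0+ build of PyTorch with a visible GPU (bfloat16 additionally needs ROCm 7.1+).

import torch

# Minimal smoke test (not from this commit): sparse @ sparse in half precision,
# which these changes enable on ROCm 7.0+ (bfloat16 needs ROCm 7.1+).
if torch.cuda.is_available() and torch.version.hip:
    a = torch.randn(8, 16, device="cuda", dtype=torch.half).to_sparse()
    b = torch.randn(16, 4, device="cuda", dtype=torch.half).to_sparse()
    c = torch.sparse.mm(a, b)      # returns a sparse COO tensor
    print(c.to_dense().shape)      # torch.Size([8, 4])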

File tree

7 files changed: 92 additions, 42 deletions


aten/src/ATen/cuda/CUDASparseDescriptors.cpp

Lines changed: 0 additions & 5 deletions
@@ -75,12 +75,7 @@ cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch
   auto leading_dimension =
       is_row_major ? input_strides[ndim - 2] : input_strides[ndim - 1];

-#if !defined(USE_ROCM)
   auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
-#else
-  TORCH_INTERNAL_ASSERT(is_column_major, "Expected column major input.");
-  auto order = CUSPARSE_ORDER_COL;
-#endif

   auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
   // NOLINTNEXTLINE(*const-cast)
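
This hunk drops the ROCm-only column-major assertion, since hipSPARSE now accepts row-major dense-matrix descriptors as well. For reference, a rough Python sketch of the same stride-based order decision (the helper name is made up for this illustration):

import torch

def dense_order(t: torch.Tensor) -> str:
    # Elements adjacent within a row -> row-major; adjacent within a column -> column-major.
    row_major = t.stride(-1) == 1
    col_major = t.stride(-2) == 1
    assert row_major or col_major, "expected a row- or column-major tensor"
    return "CUSPARSE_ORDER_ROW" if row_major else "CUSPARSE_ORDER_COL"

print(dense_order(torch.empty(4, 3)))      # CUSPARSE_ORDER_ROW
print(dense_order(torch.empty(4, 3).t()))  # CUSPARSE_ORDER_COL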

aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp

Lines changed: 1 addition & 1 deletion
@@ -615,7 +615,7 @@ void spmm(

   // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this
   // silently returning incorrect results
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && (ROCM_VERSION < 60300)
   auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
       mat1.crow_indices().to(kInt),
       mat1.col_indices().to(kInt),
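
This narrows the 32-bit-index fallback to ROCm builds older than 6.3, presumably because newer hipSPARSE handles 64-bit indices directly. A rough Python equivalent of that fallback on a CSR tensor (the helper name is illustrative, not from the diff):

import torch

def csr_with_int32_indices(t: torch.Tensor) -> torch.Tensor:
    # Rebuild the CSR tensor with int32 row/column indices.
    return torch.sparse_csr_tensor(
        t.crow_indices().to(torch.int32),
        t.col_indices().to(torch.int32),
        t.values(),
        size=t.shape,
    )

x = torch.eye(3).to_sparse_csr()
print(csr_with_int32_indices(x).crow_indices().dtype)  # torch.int32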

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp

Lines changed: 8 additions & 0 deletions
@@ -63,6 +63,14 @@ const char* cusparseGetErrorString(cusparseStatus_t status) {
     case CUSPARSE_STATUS_ZERO_PIVOT:
       return "an entry of the matrix is either structural zero or numerical zero (singular block)";

+#if defined(USE_ROCM)
+    case CUSPARSE_STATUS_NOT_SUPPORTED:
+      return "operation is not supported";
+
+    case CUSPARSE_STATUS_INSUFFICIENT_RESOURCES:
+      return "Resources are insufficient";
+#endif // defined(USE_ROCM)
+
     default:
       return "unknown error";
   }

aten/src/ATen/native/sparse/cuda/SparseMatMul.cu

Lines changed: 43 additions & 8 deletions
@@ -40,12 +40,24 @@
 #include <thrust/iterator/discard_iterator.h>


-#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
+#if defined(__CUDACC__) && ((CUSPARSE_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
 #define IS_CUSPARSE11_AVAILABLE() 1
 #else
 #define IS_CUSPARSE11_AVAILABLE() 0
 #endif

+#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
+#define HIPSPARSE_FP16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_SUPPORT 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
+#define HIPSPARSE_FP16_BF16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_BF16_SUPPORT 0
+#endif
+
 #if IS_CUSPARSE11_AVAILABLE()
 #include <library_types.h>
 #endif

@@ -207,13 +219,24 @@ struct CusparseMatrixMultiplyOp {

   CusparseMatrixMultiplyOp() {
     static_assert(
-        std::is_same_v<c10::Half, scalar_t> ||
-        std::is_same_v<c10::BFloat16, scalar_t> ||
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        std::is_same_v<c10::Half, scalar_t> ||
+#endif
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        std::is_same_v<c10::BFloat16, scalar_t> ||
+#endif
         std::is_same_v<float, scalar_t> ||
         std::is_same_v<double, scalar_t> ||
         std::is_same_v<c10::complex<float>, scalar_t> ||
         std::is_same_v<c10::complex<double>, scalar_t>,
-        "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
+        "cusparseSpGEMM only supports data type of "
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        "half, "
+#endif
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        "bfloat16, "
+#endif
+        "float, double and complex float, double.");
     // SpGEMM Computation
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
   }

@@ -268,11 +291,13 @@ struct CusparseMatrixMultiplyOp {

     // If a specific GPU model does not provide native support for a given data type,
     // the routine returns CUSPARSE_STATUS_ARCH_MISMATCH error
+#if !defined(USE_ROCM)
     cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
       "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
     TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
       "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");
+#endif

     // ask bufferSize1 bytes for external memory
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(

@@ -811,10 +836,20 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

 #if IS_CUSPARSE11_AVAILABLE()
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
-    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
-  });
-#else
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#elif HIPSPARSE_FP16_SUPPORT
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#else
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#endif
+#else // not IS_CUSPARSE11_AVAILABLE()
   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
   });
test/test_sparse.py

Lines changed: 10 additions & 5 deletions
@@ -8,7 +8,7 @@
 import random
 import unittest
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
+from torch.testing._internal.common_utils import TestCase, run_tests, do_test_dtypes, \
     load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
     parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \

@@ -68,6 +68,12 @@ def _op_supports_any_sparse(op):
 ) or (not IS_WINDOWS and not TEST_WITH_ROCM)

 HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
+HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
+
+SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
+SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)

 def all_sparse_layouts(test_name='layout', include_strided=False):
     return parametrize(test_name, [

@@ -3608,13 +3614,12 @@ def test_log_softmax_zero_nnz(self, device, dtype):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)

     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
-    @skipIfRocm
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
-    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [],
-                                      *[torch.bfloat16] if SM80OrLater else [],
+    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
                                       torch.complex64,
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
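
The new module-level flags parse the ROCm version out of torch.version.hip and fold it together with the existing CUDA compute-capability checks. A standalone sketch of the ROCm half of that logic (the real flags also require SM53OrLater/SM80OrLater on CUDA builds, omitted here; test_sparse.py already has `version` from `packaging` in scope):

import torch
from packaging import version

def hip_at_least(minimum: str) -> bool:
    # torch.version.hip is a version string on ROCm builds and None otherwise;
    # the split drops any "-" suffix before comparing.
    hip = torch.version.hip
    return bool(hip) and version.parse(hip.split("-")[0]) >= version.parse(minimum)

HIPSPARSE_FP16_SUPPORTED_EXAMPLE = hip_at_least("7.0")
HIPSPARSE_BF16_SUPPORTED_EXAMPLE = hip_at_least("7.1")
print(HIPSPARSE_FP16_SUPPORTED_EXAMPLE, HIPSPARSE_BF16_SUPPORTED_EXAMPLE)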

test/test_sparse_csr.py

Lines changed: 17 additions & 22 deletions
@@ -12,8 +12,8 @@
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
-     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
-     suppress_warnings)
+     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
+     skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
      precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,

@@ -26,7 +26,8 @@
     all_types_and_complex, floating_and_complex_types_and)
 from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
 from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
+    SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
 import operator

 if TEST_SCIPY:

@@ -1545,9 +1546,10 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device
             run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)

     @onlyCUDA
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @skipCUDAIfNoSparseGeneric
-    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @dtypes(*floating_and_complex_types_and(*[torch.half] if HIPSPARSE_FP16_SUPPORTED else [],
+                                            *[torch.bfloat16] if HIPSPARSE_BF16_SUPPORTED else []))
     def test_bmm(self, device, dtype):
         def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
             b = b.mH if (op_b and a.shape == b.shape) else b

@@ -1834,7 +1836,7 @@ def run_test(a, b, upper, transpose, unitriangular, op_out):
                 run_test(a, b, upper, unitriangular, transpose, op_out)

     @skipCPUIfNoMklSparse
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(torch.double)
     def test_mm(self, device, dtype):
         def test_shape(di, dj, dk, nnz0=None, nnz1=None):

@@ -1954,8 +1956,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):

     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_and_complex_types_and(
-        *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
-        *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
+        *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+        *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else []))
     @precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
     def test_sparse_addmm(self, device, dtype):
         def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):

@@ -1984,18 +1986,15 @@ def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
         test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))

     @skipCPUIfNoMklSparse
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(*floating_and_complex_types())
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @sparse_compressed_nonblock_layouts()
-    @skipCUDAIf(
-        not _check_cusparse_spgemm_available(),
-        "cuSparse Generic API SpGEMM is not available"
-    )
     def test_addmm_all_sparse_csr(self, device, dtype, layout):
         M = torch.randn(10, 25, device=device).to(dtype)
         m1 = torch.randn(10, 50, device=device).to(dtype)

@@ -2066,16 +2065,12 @@ def maybe_transpose(cond, m):
     @skipCPUIfNoMklSparse
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128]
-                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
-                                      else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
-        if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0):
-            self.skipTest("Skipped on ROCm")
         M = torch.randn(n, m, device=device).to(dtype)
         m1 = torch.randn(n, k, device=device).to(dtype)
         m2 = torch.randn(k, m, device=device).to(dtype)

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 13 additions & 1 deletion
@@ -8175,11 +8175,15 @@
     ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)),
     ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
-    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),

@@ -8228,6 +8232,14 @@
         "CUSPARSE_STATUS_ZERO_PIVOT",
         ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL),
     ),
+    (
+        "CUSPARSE_STATUS_NOT_SUPPORTED",
+        ("HIPSPARSE_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
+    (
+        "CUSPARSE_STATUS_INSUFFICIENT_RESOURCES",
+        ("HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
     (
         "CUSPARSE_OPERATION_TRANSPOSE",
         ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL),
