diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 86df5801c73c..e3e0b283ac0a 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -31,6 +31,8 @@
 #include "./sort_op.h"
 #include "./util/tensor_util-inl.h"
 #include "./util/tensor_util-inl.cuh"
+#include "./indexing_op.h"
+
 namespace mxnet {
 namespace op {
@@ -289,51 +291,6 @@ struct DotCsrTransDnsDnsWarpBlockKernel {
   }
 };
 
-/*!
- * \brief GPU warp kernel of dot(csr.T, dns) = rsp
- * Parallelization by columns: 1 warp computes one lhs column for one rhs column
- */
-struct DotCsrTransDnsRspWarpKernel {
-  /*!
-   * \brief
-   * \param tid global thread id
-   * \param out output rsp matrix data
-   * \param row_flg_sum_out inclusive prefix sum array over 0/1 marked row flag array
-   * \param data_l csr matrix data
-   * \param indptr_l csr matrix row index pointer
-   * \param col_idx_l csr matrix column indices
-   * \param data_r dns matrix data
-   * \param num_cols_r dns matrix number of columns
-   */
-  template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid,
-                                             DType* out,
-                                             const nnvm::dim_t* row_flg_sum_out,
-                                             const DType* data_l,
-                                             const IType* indptr_l,
-                                             const CType* col_idx_l,
-                                             const DType* data_r,
-                                             const nnvm::dim_t num_cols_r) {
-    using nnvm::dim_t;
-    const dim_t warp_id = tid / 32;           // global warp id
-    const dim_t lane = tid & (32-1);          // local thread id within warp
-    const dim_t icol = warp_id / num_cols_r;  // lhs column that this warp computes
-    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp computes
-
-    // Compute range of nnz elements in this column
-    const dim_t low = static_cast<dim_t>(indptr_l[icol]);
-    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
-
-    // Iterate through the nnz elements in this column
-    for (dim_t j = low+lane; j < high; j+=32) {
-      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
-      const dim_t rsp_row = row_flg_sum_out[irow]-1;
-      const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
-      atomicAdd(static_cast<DType*>(&(out[rsp_row*num_cols_r+kcol])), val);
-    }
-  }
-};
-
 /*!
  * \brief GPU Kernel of dot(csr.T, rsp1) = rsp2
  * Parallelization by rows: 1 thread/row
@@ -671,8 +628,60 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
   });
 }
 
+struct DotCsrTransDnsRspKernel {
+  /*!
+   * \brief
+   * \param tid global thread id
+   * \param out output rsp matrix data
+   * \param lookup_table lookup table from row in lhs to row in dst
+   * \param sorted_indices csr matrix column indices in sorted order
+   * \param nnz number of non-zeros in csr matrix
+   * \param original_idx original indices to the unsorted csr column indices
+   * \param rhs dns rhs data
+   * \param val_array csr matrix data
+   * \param idx_array csr matrix row indices
+   * \param row_length length of a row in the output rsp matrix
+   */
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int thread_id,
+                                             DType* out,
+                                             const IType* lookup_table,
+                                             const IType* sorted_indices,
+                                             const nnvm::dim_t nnz,
+                                             const IType* original_idx,
+                                             const DType* rhs,
+                                             const DType* val_array,
+                                             const IType* idx_array,
+                                             const nnvm::dim_t row_length) {
+    int tid = thread_id / row_length;
+    const nnvm::dim_t offset = thread_id % row_length;
+    if (tid == 0 || sorted_indices[tid - 1] != sorted_indices[tid]) {
+      DType acc = 0;
+      const IType src_row_idx = sorted_indices[tid];
+      const IType dst_row_idx = lookup_table[src_row_idx];
+      const IType out_offset = dst_row_idx * row_length + offset;
+      do {
+        const IType idx = original_idx[tid];
+        const DType val = val_array[idx];
+        const DType col_idx = idx_array[idx];
+        const IType rhs_offset = col_idx * row_length + offset;
+        acc += rhs[rhs_offset] * val;
+        tid++;
+      } while (tid < nnz && sorted_indices[tid - 1] == sorted_indices[tid]);
+      out[out_offset] = acc;
+    }
+  }
+};
+
+// Returns integer log2(a) rounded up
+inline int log2i(size_t a) {
+  int k = 1;
+  while (a >>= 1) k++;
+  return k;
+}
+
 /*!
- * \brief GPU Impl of dot(csr, dns) = rsp and dot(csr.T, dns) = rsp
+ * \brief GPU Impl of dot(csr.T, dns) = rsp
  */
 inline void DotCsrDnsRspImpl(const OpContext& ctx,
                              const gpu& gpu_dev,
@@ -692,92 +701,114 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   }
 
   using mshadow::Shape1;
+  using mshadow::Tensor;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
+  using namespace csr;
   const TBlob data_l = lhs.data();
-  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
-  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob indptr_l = lhs.aux_data(kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(kIdx);
   const TBlob& data_r = rhs;
+  size_t nnz = lhs.aux_data(kIdx).Size();
   const dim_t num_rows_l = lhs.shape()[0];
   const dim_t num_cols_l = lhs.shape()[1];
   const dim_t num_cols_r = rhs.shape_[1];
-  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
-  dim_t num_threads;
-  // TODO: remove kernel dependency on warpSize=32
-  if (threads_per_warp != 32) {
-    LOG(FATAL) << "DotCsrDnsRspImpl GPU kernels expect warpSize=32";
-  }
-
+  CHECK_EQ(ret->aux_type(rowsparse::kIdx), col_idx_l.type_flag_)
+    << "Mismatch indices dtype detected";
   MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
         if (trans_lhs) {
-          // Compute number of non-zero rows (nnr) of output matrix
-          // - alloc temp storage for row_flg array and for cub's prefix sum
-          // - mark non-zero columns of csr matrix in row_flg
-          // - compute inclusive prefix sum over marked array
-          // - copy last value (nnr_out) from device to host
-          dim_t* row_flg_out = NULL;
-          void* d_temp_storage = NULL;
-          size_t temp_storage_bytes = 0;
-          cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                        temp_storage_bytes,
-                                        row_flg_out,
-                                        row_flg_out,
-                                        num_cols_l,
-                                        mshadow::Stream<gpu>::GetStream(s));
-          mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
-              .get_space_typed<gpu, 1, char>(Shape1(num_cols_l * sizeof(dim_t) +
-                                             temp_storage_bytes), s);
-          row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
-          d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
-          num_threads = num_cols_l;
-          Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
-          num_threads = num_rows_l * threads_per_warp;
-          Kernel::Launch(s, num_threads,
-              row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
-              num_rows_l, num_cols_l);
-          cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                        temp_storage_bytes,
-                                        row_flg_out,
-                                        row_flg_out,
-                                        num_cols_l,
-                                        mshadow::Stream<gpu>::GetStream(s));
-          dim_t nnr_out = 0;
-          CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
-                               cudaMemcpyDeviceToHost));
-          if (0 == nnr_out) {
-            FillZerosRspImpl(s, *ret);
-            return;
-          }
+          IType* col_idx_l_ptr = col_idx_l.dptr<IType>();
+          // temporary memory layout
+          size_t* nnr_ptr = nullptr;
+          IType* original_idx_ptr = nullptr;
+          IType* row_idx_ptr = nullptr;
+          IType* col_idx_copy_ptr = nullptr;
+          IType* lookup_table_ptr = nullptr;
+          char* temp_storage_ptr = nullptr;
 
-          // Allocate output matrix space
-          ret->CheckAndAlloc({Shape1(nnr_out)});
-          const TBlob data_out_blob = ret->data();
-          const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
-          MSHADOW_IDX_TYPE_SWITCH(row_idx_out_blob.type_flag_, RType, {  // row idx type
-            DType* data_out = data_out_blob.dptr<DType>();
-            RType* row_idx_out = row_idx_out_blob.dptr<RType>();
-            num_threads = nnr_out * num_cols_r;
-            Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
-            num_threads = nnr_out;
-            Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
+          // estimate temp space for unique.
+          const size_t nnr_bytes = sizeof(size_t);
+          size_t unique_temp_bytes = 0;
+          size_t *null_ptr = nullptr;
+          size_t *null_dptr = nullptr;
+          cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+          cub::DeviceSelect::Unique(NULL, unique_temp_bytes, null_dptr, null_dptr,
+                                    null_ptr, nnz, stream);
+          // the temp storage for sort and unique
+          size_t original_idx_bytes = nnz * sizeof(IType);
+          size_t row_idx_bytes = nnz * sizeof(IType);
+          size_t col_idx_copy_bytes = nnz * sizeof(IType);
+          size_t lookup_table_bytes = num_cols_l * sizeof(IType);
+          size_t sort_temp_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
+          size_t total_temp_bytes = std::max(sort_temp_bytes, unique_temp_bytes);
+
+          // layout: original_idx, col_idx_copy, temp_storage
+          size_t total_workspace_bytes = nnr_bytes + original_idx_bytes + row_idx_bytes +
+                                         col_idx_copy_bytes +
+                                         lookup_table_bytes + total_temp_bytes;
+          // request temp space
+          Tensor<gpu, 1, char> workspace = ctx.requested[0]
+              .get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
+          // update individual temp space ptrs
+          nnr_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
+          original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes);
+          row_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                 original_idx_bytes);
+          col_idx_copy_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                      original_idx_bytes + row_idx_bytes);
+          lookup_table_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                      original_idx_bytes + row_idx_bytes +
+                                                      col_idx_copy_bytes);
+          temp_storage_ptr = workspace.dptr_ + nnr_bytes + original_idx_bytes +
+                             row_idx_bytes + col_idx_copy_bytes + lookup_table_bytes;
 
-            // Fill row_idx array of output matrix, using the row_flg values
-            num_threads = num_cols_l;
-            Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
-                row_idx_out, row_flg_out, num_cols_l);
+          // Fill original_idx
+          Kernel<range_fwd, gpu>::Launch(
+            s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
+          // Make a copy of col_idx_l
+          Kernel<op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+            s, nnz, col_idx_copy_ptr, col_idx_l_ptr);
 
-            // Perform matrix-matrix multiply
-            num_threads = threads_per_warp * num_rows_l * num_cols_r;
-            Kernel<DotCsrTransDnsRspWarpKernel, gpu>::Launch(s, num_threads,
-                data_out, row_flg_out,
-                data_l.dptr<DType>(), indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(),
-                data_r.dptr<DType>(), num_cols_r);
-          });
+          // Construct the tensors needed for SortByKey
+          Tensor<gpu, 1, IType> col_idx_copy(col_idx_copy_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(total_temp_bytes), s);
+
+          int num_bits = log2i(num_cols_l - 1);
+          SortByKey(col_idx_copy, original_idx, true, &temp_storage, 0, num_bits);
+
+          // over-allocate aux indices
+          ret->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnz));
+          // compute unique indices
+          IType* ret_idx_ptr = ret->aux_data(rowsparse::kIdx).dptr<IType>();
+          cub::DeviceSelect::Unique(temp_storage_ptr, unique_temp_bytes, col_idx_copy_ptr, ret_idx_ptr,
+                                    nnr_ptr, nnz, stream);
+          // retrieve num non-zero rows
+          size_t nnr = 0;
+          CUDA_CALL(cudaMemcpy(&nnr, nnr_ptr, nnr_bytes, cudaMemcpyDeviceToHost));
+          // allocate data
+          ret->CheckAndAllocData(mshadow::Shape2(nnz, num_cols_r));
+          // generate lookup table
+          Kernel<MarkLookupTable, gpu>::Launch(s, nnr, lookup_table_ptr, ret_idx_ptr);
+
+          // Scatter csr indptr to row id
+          Kernel::Launch(
+            s, num_rows_l, indptr_l.dptr<IType>(), row_idx_ptr, num_rows_l);
+
+          Kernel<DotCsrTransDnsRspKernel, gpu>::Launch(s, nnz * num_cols_r,
+                                                       ret->data().dptr<DType>(),
+                                                       lookup_table_ptr, col_idx_copy_ptr, nnz,
+                                                       original_idx_ptr, data_r.dptr<DType>(),
+                                                       data_l.dptr<DType>(),
+                                                       row_idx_ptr, num_cols_r);
+
+          // reshape aux data
+          ret->set_aux_shape(rowsparse::kIdx, Shape1(nnr));
         } else {
           LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns) = rsp yet.";
         }
@@ -786,6 +817,7 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   });
 }
 
+
 /*!
  * \brief GPU Impl of dot(csr, rsp1) = rsp2 and dot(csr.T, rsp1) = rsp2
  * TODO: Optimize for GPU; this is a baseline implementation providing
@@ -990,12 +1022,6 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx,
   });
 }
 
-// Returns integer log2(a) rounded up
-inline int log2i(size_t a) {
-  int k = 1;
-  while (a >>= 1) k++;
-  return k;
-}
 
 /*
  * \brief GPU Impl of dot(dns, csr) = csr
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 5cdf5060aec4..8d4fefb45c83 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -26,6 +26,7 @@
 
 #include "./indexing_op.h"
 #include "./util/tensor_util-inl.cuh"
+#include "./util/tensor_util-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -115,20 +116,6 @@ struct AddTakeGradRspDeterministicKernel {
   }
 };
 
-/*
- * \brief the kernel to generate a lookup table for positions of row ids
- * \param i thread id
- * \param out output table
- * \param data the input row id in sorted order
- */
-struct mark_lookup_table {
-  template<typename IType, typename DType>
-  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
-    out[static_cast<nnvm::dim_t>(data[i])] = i;
-  }
-};
-
-
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -252,7 +239,7 @@ void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx,
   output.set_aux_shape(kIdx, Shape1(nnr));
 
   // generate lookup table
-  Kernel<mark_lookup_table, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);
+  Kernel<MarkLookupTable, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);
 
   // accumulate gradients
   DType* grad_data = output.data().dptr<DType>();
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 0f6506640a48..ef0779b190ca 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -42,7 +42,6 @@
 #include "./util/tensor_util-inl.h"
 #include "../mxnet_op.h"
 #include "./sort_op.h"
-#include "./dot-inl.h"
 #include "./init_op.h"
 #include "./matrix_op-inl.h"
 #include "../../engine/openmp.h"
diff --git a/src/operator/tensor/util/tensor_util-inl.h b/src/operator/tensor/util/tensor_util-inl.h
index 45b12730318a..24602f22922e 100644
--- a/src/operator/tensor/util/tensor_util-inl.h
+++ b/src/operator/tensor/util/tensor_util-inl.h
@@ -76,6 +76,20 @@ struct FillRspRowIdxKernel {
   }
 };
 
+/*
+ * \brief the kernel to generate a lookup table for positions of row ids
+ * \param i thread id
+ * \param out output table
+ * \param data the input row id in sorted order
+ */
+struct MarkLookupTable {
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
+    out[static_cast<nnvm::dim_t>(data[i])] = i;
+  }
+};
+
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 16b52f60ceb9..111c0b7b166c 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1336,10 +1336,9 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
 
-
 @with_seed()
 def test_sparse_dot_determinism():
-    def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b):
+    def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b, forward_stype):
         lhs_row = rnd.randint(50, 100)
         lhs_col = rnd.randint(50, 100)
         if transpose_a:
@@ -1352,18 +1351,17 @@ def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpo
             rhs_shape = (rnd.randint(50, 100), lhs_col)
         else:
             rhs_shape = (lhs_col, rnd.randint(50, 100))
-        if default_context() == mx.cpu():
-            forward_stype = 'csr'
-        else:
-            forward_stype = 'default'
         lhs_shape = (lhs_row, lhs_col)
         lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density)
         rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
         res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype)
         res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype)
         assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0)
-    test_dot_determinism('default', 'csr', 1.0, 0.1, False, False)
-    test_dot_determinism('default', 'csr', 1.0, 0.1, False, True)
+
+    test_dot_determinism('csr', 'default', 0.1, 1.0, True, False, 'row_sparse')
+    forward_stype = 'csr' if default_context() == mx.cpu() else 'default'
+    test_dot_determinism('default', 'csr', 1.0, 0.1, False, False, forward_stype)
+    test_dot_determinism('default', 'csr', 1.0, 0.1, False, True, forward_stype)
 
 
 @with_seed()
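Note (not part of the patch): the removed DotCsrTransDnsRspWarpKernel accumulated partial products with atomicAdd, so the floating-point summation order, and therefore the result, could vary between runs; the replacement sorts the csr column indices and lets a single thread accumulate each output element, which is what makes dot(csr.T, dns) = rsp reproducible on GPU. Below is a minimal Python sketch of how the new path can be exercised. The shapes, densities and the mx.gpu(0) context are illustrative assumptions; mx.nd.sparse.dot with transpose_a and forward_stype are the APIs used by the updated unit test.

# Illustrative sketch only, not part of the patch.
import mxnet as mx
import numpy as np

ctx = mx.gpu(0)  # assumed GPU context; the new kernel only affects the GPU path
lhs = (mx.nd.random.uniform(shape=(64, 50), ctx=ctx) > 0.9).tostype('csr')  # sparse csr lhs
rhs = mx.nd.random.uniform(shape=(64, 40), ctx=ctx)                         # dense rhs

# dot(csr.T, dns) = rsp, run twice with identical inputs
out1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True, forward_stype='row_sparse')
out2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True, forward_stype='row_sparse')

# Without atomicAdd the two runs are expected to match exactly
# (the unit test above checks this with rtol=0.0, atol=0.0).
assert np.array_equal(out1.asnumpy(), out2.asnumpy())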