
Commit

[MXNET-363] fix race condition in sparse dot(csr.T, dense) on gpu (apache#10713)

* compiles

* fix bug

* remove unused code

* refactor

* update test

* change dim_t to IType

* remove unused header

* remove extra headers
eric-haibin-lin authored Apr 28, 2018
1 parent 82c33f6 commit 97daed8
Showing 5 changed files with 169 additions and 145 deletions.
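
What changed, in short: the old dot(csr.T, dns) = rsp path marked the non-zero csr columns in a row_flg array, ran cub::DeviceScan::InclusiveSum over it, and let DotCsrTransDnsRspWarpKernel accumulate into the output with atomicAdd. The new path sorts a copy of the csr column indices with SortByKey (carrying their original positions), extracts the distinct indices with cub::DeviceSelect::Unique to get the output row ids, builds a column-id-to-output-row map with MarkLookupTable, expands indptr into per-nonzero row ids with CsrRowScatterKernel, and then lets DotCsrTransDnsRspKernel accumulate each run of equal column indices inside a single thread, so no atomics are needed and the result is reproducible.

The following standalone sketch is not part of the commit; it replays the same sort-and-segmented-sum idea on the host for a tiny hand-made CSR matrix (all names and values are illustrative only):

// Hypothetical host-side sketch of the new accumulation scheme.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // lhs: 2x4 CSR matrix, rhs: 2x3 dense matrix; compute dot(lhs.T, rhs) -> 4x3,
  // stored row-sparse (only rows touched by a nonzero column index are kept).
  std::vector<float> data_l   = {1.f, 2.f, 3.f};    // csr values
  std::vector<int>   indptr_l = {0, 2, 3};          // row pointers
  std::vector<int>   col_idx  = {1, 3, 1};          // column indices
  std::vector<float> rhs      = {10.f, 20.f, 30.f,  // rhs row 0
                                 40.f, 50.f, 60.f}; // rhs row 1
  const int num_rows_l = 2, num_cols_r = 3, nnz = 3;

  // Scatter indptr into per-nonzero row ids (what CsrRowScatterKernel produces).
  std::vector<int> row_idx(nnz);
  for (int r = 0; r < num_rows_l; ++r)
    for (int j = indptr_l[r]; j < indptr_l[r + 1]; ++j) row_idx[j] = r;

  // Sort the column indices, carrying the original positions (SortByKey).
  std::vector<int> original_idx(nnz);
  std::iota(original_idx.begin(), original_idx.end(), 0);
  std::vector<int> sorted_col(col_idx);
  std::sort(original_idx.begin(), original_idx.end(),
            [&](int a, int b) { return col_idx[a] < col_idx[b]; });
  std::sort(sorted_col.begin(), sorted_col.end());

  // Unique sorted column indices become the output row ids; lookup_table maps
  // a csr column id to its row in the row-sparse output (MarkLookupTable).
  std::vector<int> out_rows(sorted_col);
  out_rows.erase(std::unique(out_rows.begin(), out_rows.end()), out_rows.end());
  std::vector<int> lookup_table(4, -1);
  for (size_t i = 0; i < out_rows.size(); ++i) lookup_table[out_rows[i]] = static_cast<int>(i);

  // Segmented accumulation: the first position of each run of equal column
  // indices sums the whole run, so no atomics are needed.
  std::vector<float> out(out_rows.size() * num_cols_r, 0.f);
  for (int t = 0; t < nnz; ++t) {
    if (t != 0 && sorted_col[t - 1] == sorted_col[t]) continue;  // not a segment leader
    for (int k = 0; k < num_cols_r; ++k) {
      float acc = 0.f;
      for (int j = t; j < nnz && sorted_col[j] == sorted_col[t]; ++j) {
        const int idx = original_idx[j];
        acc += data_l[idx] * rhs[row_idx[idx] * num_cols_r + k];
      }
      out[lookup_table[sorted_col[t]] * num_cols_r + k] = acc;
    }
  }
  // Prints: row 1: 130 170 210
  //         row 3: 20 40 60
  for (size_t i = 0; i < out_rows.size(); ++i)
    std::printf("row %d: %g %g %g\n", out_rows[i],
                out[i * num_cols_r], out[i * num_cols_r + 1], out[i * num_cols_r + 2]);
  return 0;
}
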
268 changes: 147 additions & 121 deletions src/operator/tensor/dot-inl.cuh
@@ -31,6 +31,8 @@
#include "./sort_op.h"
#include "./util/tensor_util-inl.h"
#include "./util/tensor_util-inl.cuh"
#include "./indexing_op.h"


namespace mxnet {
namespace op {
@@ -289,51 +291,6 @@ struct DotCsrTransDnsDnsWarpBlockKernel {
}
};

/*!
* \brief GPU warp kernel of dot(csr.T, dns) = rsp
* Parallelization by columns: 1 warp computes one lhs column for one rhs column
*/
struct DotCsrTransDnsRspWarpKernel {
/*!
* \brief
* \param tid global thread id
* \param out output rsp matrix data
* \param row_flg_sum_out inclusive prefix sum array over 0/1 marked row flag array
* \param data_l csr matrix data
* \param indptr_l csr matrix row index pointer
* \param col_idx_l csr matrix column indices
* \param data_r dns matrix data
* \param num_cols_r dns matrix number of columns
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* out,
const nnvm::dim_t* row_flg_sum_out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t warp_id = tid / 32; // global warp id
const dim_t lane = tid & (32-1); // local thread id within warp
const dim_t icol = warp_id / num_cols_r; // lhs column that this warp computes
const dim_t kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Compute range of nnz elements in this column
const dim_t low = static_cast<dim_t>(indptr_l[icol]);
const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);

// Iterate through the nnz elements in this column
for (dim_t j = low+lane; j < high; j+=32) {
const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
const dim_t rsp_row = row_flg_sum_out[irow]-1;
const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
atomicAdd(static_cast<DType *>(&(out[rsp_row*num_cols_r+kcol])), val);
}
}
};
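
This warp kernel is the one the commit deletes. Each warp walks the nonzeros of one csr row for one rhs column, and nonzeros from different rows that share a column index all land in the same output row, so their warps add into the same out element through atomicAdd. The order of those floating-point additions depends on warp scheduling, which is the non-reproducible accumulation this change removes.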

/*!
* \brief GPU Kernel of dot(csr.T, rsp1) = rsp2
* Parallelization by rows: 1 thread/row
@@ -671,8 +628,60 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
});
}

struct DotCsrTransDnsRspKernel {
/*!
* \brief GPU kernel of dot(csr.T, dns) = rsp: one thread per (nonzero, output column) pair over pre-sorted csr column indices
* \param tid global thread id
* \param out output rsp matrix data
* \param lookup_table lookup table from row in lhs to row in dst
* \param sorted_indices csr matrix column indices in sorted order
* \param nnz number of non-zeros in csr matrix
* \param original_idx original indices to the unsorted csr column indices
* \param rhs dns rhs data
* \param val_array csr matrix data
* \param idx_array csr matrix row indices
* \param row_length length of a row in the output rsp matrix
*/
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int thread_id,
DType* out,
const IType* lookup_table,
const IType* sorted_indices,
const nnvm::dim_t nnz,
const IType* original_idx,
const DType* rhs,
const DType* val_array,
const IType* idx_array,
const nnvm::dim_t row_length) {
int tid = thread_id / row_length;
const nnvm::dim_t offset = thread_id % row_length;
if (tid == 0 || sorted_indices[tid - 1] != sorted_indices[tid]) {
DType acc = 0;
const IType src_row_idx = sorted_indices[tid];
const IType dst_row_idx = lookup_table[src_row_idx];
const IType out_offset = dst_row_idx * row_length + offset;
do {
const IType idx = original_idx[tid];
const DType val = val_array[idx];
const DType col_idx = idx_array[idx];
const IType rhs_offset = col_idx * row_length + offset;
acc += rhs[rhs_offset] * val;
tid++;
} while (tid < nnz && sorted_indices[tid - 1] == sorted_indices[tid]);
out[out_offset] = acc;
}
}
};
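
In the replacement kernel above, nnz * row_length threads are launched: tid = thread_id / row_length selects a nonzero in sorted-column order and offset = thread_id % row_length selects an output column. Only the first thread of each run of equal sorted_indices values does any work; it walks its whole run serially and writes its output element exactly once, so no two threads ever touch the same location and no atomics are needed. For instance (hypothetical values), with sorted_indices = {1, 1, 3} and row_length = 3, threads 0 to 2 sum both nonzeros of csr column 1 into output row lookup_table[1], threads 3 to 5 return immediately, and threads 6 to 8 handle column 3.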

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
return k;
}
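
Despite the comment, log2i returns floor(log2(a)) + 1, i.e. the number of bits needed to represent a (log2i(4) and log2i(5) are both 3). That is exactly what the radix sort below needs: num_bits = log2i(num_cols_l - 1) bounds the bit width of the largest possible column index passed to SortByKey.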

/*!
* \brief GPU Impl of dot(csr, dns) = rsp and dot(csr.T, dns) = rsp
* \brief GPU Impl of dot(csr.T, dns) = rsp
*/
inline void DotCsrDnsRspImpl(const OpContext& ctx,
const gpu& gpu_dev,
@@ -692,92 +701,114 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
}

using mshadow::Shape1;
using mshadow::Tensor;
using mxnet_op::Kernel;
using mxnet_op::set_zero;
using nnvm::dim_t;
using namespace csr;

const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob indptr_l = lhs.aux_data(kIndPtr);
const TBlob col_idx_l = lhs.aux_data(kIdx);
const TBlob& data_r = rhs;
size_t nnz = lhs.aux_data(kIdx).Size();

const dim_t num_rows_l = lhs.shape()[0];
const dim_t num_cols_l = lhs.shape()[1];
const dim_t num_cols_r = rhs.shape_[1];
const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
dim_t num_threads;
// TODO: remove kernel dependency on warpSize=32
if (threads_per_warp != 32) {
LOG(FATAL) << "DotCsrDnsRspImpl GPU kernels expect warpSize=32";
}

CHECK_EQ(ret->aux_type(rowsparse::kIdx), col_idx_l.type_flag_)
<< "Mismatch indices dtype detected";
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (trans_lhs) {
// Compute number of non-zero rows (nnr) of output matrix
// - alloc temp storage for row_flg array and for cub's prefix sum
// - mark non-zero columns of csr matrix in row_flg
// - compute inclusive prefix sum over marked array
// - copy last value (nnr_out) from device to host
dim_t* row_flg_out = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
row_flg_out,
row_flg_out,
num_cols_l,
mshadow::Stream<gpu>::GetStream(s));
mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(num_cols_l * sizeof(dim_t) +
temp_storage_bytes), s);
row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
num_threads = num_cols_l;
Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
num_threads = num_rows_l * threads_per_warp;
Kernel<MarkCsrColWarpKernel, gpu>::Launch(s, num_threads,
row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
num_rows_l, num_cols_l);
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
row_flg_out,
row_flg_out,
num_cols_l,
mshadow::Stream<gpu>::GetStream(s));
dim_t nnr_out = 0;
CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
cudaMemcpyDeviceToHost));
if (0 == nnr_out) {
FillZerosRspImpl(s, *ret);
return;
}
IType* col_idx_l_ptr = col_idx_l.dptr<IType>();
// temporary memory layout
size_t* nnr_ptr = nullptr;
IType* original_idx_ptr = nullptr;
IType* row_idx_ptr = nullptr;
IType* col_idx_copy_ptr = nullptr;
IType* lookup_table_ptr = nullptr;
char* temp_storage_ptr = nullptr;

// Allocate output matrix space
ret->CheckAndAlloc({Shape1(nnr_out)});
const TBlob data_out_blob = ret->data();
const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
MSHADOW_IDX_TYPE_SWITCH(row_idx_out_blob.type_flag_, RType, { // row idx type
DType* data_out = data_out_blob.dptr<DType>();
RType* row_idx_out = row_idx_out_blob.dptr<RType>();
num_threads = nnr_out * num_cols_r;
Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
num_threads = nnr_out;
Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
// estimate temp space for unique.
const size_t nnr_bytes = sizeof(size_t);
size_t unique_temp_bytes = 0;
size_t *null_ptr = nullptr;
size_t *null_dptr = nullptr;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
cub::DeviceSelect::Unique(NULL, unique_temp_bytes, null_dptr, null_dptr,
null_ptr, nnz, stream);
// the temp storage for sort and unique
size_t original_idx_bytes = nnz * sizeof(IType);
size_t row_idx_bytes = nnz * sizeof(IType);
size_t col_idx_copy_bytes = nnz * sizeof(IType);
size_t lookup_table_bytes = num_cols_l * sizeof(IType);
size_t sort_temp_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
size_t total_temp_bytes = std::max(sort_temp_bytes, unique_temp_bytes);

// layout: nnr, original_idx, row_idx, col_idx_copy, lookup_table, temp_storage
size_t total_workspace_bytes = nnr_bytes + original_idx_bytes + row_idx_bytes +
col_idx_copy_bytes +
lookup_table_bytes + total_temp_bytes;
// request temp space
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
// update individual temp space ptrs
nnr_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes);
row_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
original_idx_bytes);
col_idx_copy_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
original_idx_bytes + row_idx_bytes);
lookup_table_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
original_idx_bytes + row_idx_bytes +
col_idx_copy_bytes);
temp_storage_ptr = workspace.dptr_ + nnr_bytes + original_idx_bytes +
row_idx_bytes + col_idx_copy_bytes + lookup_table_bytes;

// Fill row_idx array of output matrix, using the row_flg values
num_threads = num_cols_l;
Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
row_idx_out, row_flg_out, num_cols_l);
// Fill original_idx
Kernel<range_fwd, gpu>::Launch(
s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// Make a copy of col_idx_l
Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
s, nnz, col_idx_copy_ptr, col_idx_l_ptr);

// Perform matrix-matrix multiply
num_threads = threads_per_warp * num_rows_l * num_cols_r;
Kernel<DotCsrTransDnsRspWarpKernel, gpu>::Launch(s, num_threads,
data_out, row_flg_out,
data_l.dptr<DType>(), indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(),
data_r.dptr<DType>(), num_cols_r);
});
// Construct the tensors needed for SortByKey
Tensor<gpu, 1, IType> col_idx_copy(col_idx_copy_ptr, Shape1(nnz), s);
Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(total_temp_bytes), s);

int num_bits = log2i(num_cols_l - 1);
SortByKey(col_idx_copy, original_idx, true, &temp_storage, 0, num_bits);

// over-allocate aux indices
ret->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnz));
// compute unique indices
IType* ret_idx_ptr = ret->aux_data(rowsparse::kIdx).dptr<IType>();
cub::DeviceSelect::Unique(temp_storage_ptr, unique_temp_bytes, col_idx_copy_ptr, ret_idx_ptr,
nnr_ptr, nnz, stream);
// retrieve num non-zero rows
size_t nnr = 0;
CUDA_CALL(cudaMemcpy(&nnr, nnr_ptr, nnr_bytes, cudaMemcpyDeviceToHost));
// allocate data
ret->CheckAndAllocData(mshadow::Shape2(nnz, num_cols_r));
// generate lookup table
Kernel<MarkLookupTable, gpu>::Launch(s, nnr, lookup_table_ptr, ret_idx_ptr);

// Scatter csr indptr to row id
Kernel<CsrRowScatterKernel, gpu>::Launch(
s, num_rows_l, indptr_l.dptr<IType>(), row_idx_ptr, num_rows_l);

Kernel<DotCsrTransDnsRspKernel, gpu>::Launch(s, nnz * num_cols_r,
ret->data().dptr<DType>(),
lookup_table_ptr, col_idx_copy_ptr, nnz,
original_idx_ptr, data_r.dptr<DType>(),
data_l.dptr<DType>(),
row_idx_ptr, num_cols_r);

// reshape aux data
ret->set_aux_shape(rowsparse::kIdx, Shape1(nnr));
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns) = rsp yet.";
}
Expand All @@ -786,6 +817,7 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
});
}
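
A note on the memory management above: the single requested workspace is carved into nnr, original_idx, row_idx, col_idx_copy, lookup_table and temp_storage in that order, with temp_storage sized as max(sort_temp_bytes, unique_temp_bytes) because SortByKey and cub::DeviceSelect::Unique run one after the other and can share it. The row-sparse output is over-allocated for the worst case (nnz distinct rows and an nnz x num_cols_r data blob), and its kIdx aux shape is shrunk to the actual nnr once Unique has counted the distinct column indices.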


/*!
* \brief GPU Impl of dot(csr, rsp1) = rsp2 and dot(csr.T, rsp1) = rsp2
* TODO: Optimize for GPU; this is a baseline implementation providing
@@ -990,12 +1022,6 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx,
});
}

// Returns integer log2(a) rounded up
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
return k;
}
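
(log2i is not removed; it only moves earlier in dot-inl.cuh, above DotCsrDnsRspImpl, which now calls it.)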

/*
* \brief GPU Impl of dot(dns, csr) = csr
17 changes: 2 additions & 15 deletions src/operator/tensor/indexing_op.cu
@@ -26,6 +26,7 @@

#include "./indexing_op.h"
#include "./util/tensor_util-inl.cuh"
#include "./util/tensor_util-inl.h"

namespace mxnet {
namespace op {
@@ -115,20 +116,6 @@ struct AddTakeGradRspDeterministicKernel {
}
};

/*
* \brief the kernel to generate a lookup table for positions of row ids
* \param i thread id
* \param out output table
* \param data the input row id in sorted order
*/
struct mark_lookup_table {
template<typename IType, typename DType>
MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
out[static_cast<nnvm::dim_t>(data[i])] = i;
}
};


template<>
void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
const TBlob& data,
@@ -252,7 +239,7 @@ void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx,
output.set_aux_shape(kIdx, Shape1(nnr));

// generate lookup table
Kernel<mark_lookup_table, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);
Kernel<MarkLookupTable, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);

// accumulate gradients
DType* grad_data = output.data().dptr<DType>();
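
Net effect in indexing_op.cu: the file-local mark_lookup_table kernel goes away and SparseEmbeddingDeterministicKernelLaunch switches to the shared MarkLookupTable that this commit adds to util/tensor_util-inl.h (last file below), so the embedding backward and the new dot(csr.T, dns) path reuse the same lookup-table generator. indexing_op.h likewise drops its include of dot-inl.h (next file), presumably to avoid a circular include now that dot-inl.cuh includes indexing_op.h.
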
1 change: 0 additions & 1 deletion src/operator/tensor/indexing_op.h
@@ -42,7 +42,6 @@
#include "./util/tensor_util-inl.h"
#include "../mxnet_op.h"
#include "./sort_op.h"
#include "./dot-inl.h"
#include "./init_op.h"
#include "./matrix_op-inl.h"
#include "../../engine/openmp.h"
14 changes: 14 additions & 0 deletions src/operator/tensor/util/tensor_util-inl.h
@@ -76,6 +76,20 @@ struct FillRspRowIdxKernel {
}
};

/*
* \brief the kernel to generate a lookup table for positions of row ids
* \param i thread id
* \param out output table
* \param data the input row id in sorted order
*/
struct MarkLookupTable {
template<typename IType, typename DType>
MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
out[static_cast<nnvm::dim_t>(data[i])] = i;
}
};
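
A tiny worked example of MarkLookupTable: if data holds the sorted, de-duplicated row ids {1, 3, 7}, the kernel writes out[1] = 0, out[3] = 1 and out[7] = 2, so a later kernel can map an original row id to its position in the compacted row-sparse storage with a single read.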


} // namespace op
} // namespace mxnet
