Ensure that excessively large sparse matmuls are not executed
on the GPU because of the 32-bit indexing optimization, and provide a
friendly error instead of wrong results.
Change: 118988179
David G. Andersen authored and tensorflower-gardener committed Apr 4, 2016
1 parent 5a0e8fb commit cca5d0a
Showing 3 changed files with 45 additions and 8 deletions.
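For context on the failure mode this commit guards against: the GPU kernel addresses tensors through 32-bit indices, so any dimension or element count at or above 2^31 would be silently truncated and produce wrong offsets rather than a crash. The standalone C++ snippet below is not part of the commit; it only illustrates that truncation using standard library types.

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // An element count just past 2^31 no longer fits in a 32-bit index.
  const int64_t num_elements = (int64_t{1} << 31) + 3;
  const int32_t as_32bit = static_cast<int32_t>(num_elements);  // wraps around
  std::cout << "64-bit count: " << num_elements << "\n"
            << "32-bit index: " << as_32bit << "\n";
  // The new check rejects such inputs on the GPU path up front instead of
  // computing with wrapped (and therefore wrong) offsets.
  const bool fits =
      num_elements <= int64_t{std::numeric_limits<int32_t>::max()};
  std::cout << "fits in int32: " << std::boolalpha << fits << "\n";
  return 0;
}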
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
@@ -1211,6 +1211,7 @@ tf_kernel_libraries(
"serialize_sparse_op",
],
deps = [
":bounds_check",
":fill_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
32 changes: 27 additions & 5 deletions tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -21,6 +21,7 @@ limitations under the License.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {
@@ -111,9 +112,24 @@ class SparseTensorDenseMatMulOp : public OpKernel {
}

Tensor scratch;
int nnz = a_values->NumElements();

if (std::is_same<Device, GPUDevice>::value) {
// The GPU implementation is optimized to use 32 bit indexing, so
// give a friendly error to the programmer early on if they exceed.
OP_REQUIRES(
ctx,
FastBoundsCheck(inner_left, std::numeric_limits<int>::max()) &&
FastBoundsCheck(inner_right, std::numeric_limits<int>::max()) &&
FastBoundsCheck(outer_left, std::numeric_limits<int>::max()) &&
FastBoundsCheck(outer_right, std::numeric_limits<int>::max()) &&
FastBoundsCheck(b->NumElements(),
std::numeric_limits<int>::max()) &&
FastBoundsCheck(out->NumElements(),
std::numeric_limits<int>::max()) &&
FastBoundsCheck(a_values->NumElements(),
std::numeric_limits<int>::max()),
errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
const int nnz = static_cast<const int>(a_values->NumElements());
// Need nnz length vec scratch space on the GPU.
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({nnz}), &scratch));
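FastBoundsCheck comes from the bounds_check.h header included above; it is essentially a branch-predictor-friendly 0 <= index < limit test, applied here to every extent that the GPU kernel will later hold in 32-bit indices. A rough sketch of the idea follows; the names are illustrative, not the TensorFlow declarations.

#include <cstdint>
#include <limits>

// Sketch of a FastBoundsCheck-style predicate: true iff 0 <= index < limit.
// Casting both operands to unsigned turns a negative index into a huge value,
// so a single comparison covers both the lower and the upper bound.
inline bool fast_bounds_check(int64_t index, int64_t limit) {
  return static_cast<uint64_t>(index) < static_cast<uint64_t>(limit);
}

// Mirrors the guard above: every size must stay below the 32-bit ceiling
// before the GPU path is allowed to proceed.
inline bool fits_in_32bit(int64_t n) {
  return fast_bounds_check(n, std::numeric_limits<int32_t>::max());
}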
@@ -207,6 +223,7 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
typename TTypes<T>::Vec scratch) {
const std::size_t nnz = a_values.size();
const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
const int lhs_index_a = ADJ_A ? 1 : 0;
const int rhs_index_a = ADJ_A ? 0 : 1;

@@ -220,8 +237,10 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
// Disable vectorization if the RHS of output is too small
auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
for (std::size_t i = 0; i < nnz; ++i) {
const int64 m = a_indices(i, lhs_index_a);
const int64 k = a_indices(i, rhs_index_a);
const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
CHECK_LT(k, lhs_right);
CHECK_LT(m, out.dimension(0));
const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
for (std::size_t n = 0; n < rhs_right; ++n) {
const T b_value = maybe_adjoint_b(k, n);
@@ -230,15 +249,18 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
}
} else {
for (std::size_t i = 0; i < nnz; ++i) {
const int64 m = a_indices(i, lhs_index_a);
const int64 k = a_indices(i, rhs_index_a);
const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);
CHECK_LT(m, out.dimension(0));
if (ADJ_B) {
CHECK_LT(k, b.dimension(1));
out.template chip<0>(m) +=
b.template chip<1>(k).unaryExpr(
Eigen::internal::scalar_conjugate_op<T>()) *
a_value;
} else {
CHECK_LT(k, b.dimension(0));
out.template chip<0>(m) += b.template chip<0>(k) * a_value;
}
}
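The CPU-side changes in this file follow a validate-then-use pattern for indices that arrive in a user-supplied sparse tensor: internal::SubtleMustCopy forces the index to be read exactly once, so the value that is bounds-checked is the value that is used, and CHECK_LT aborts before an out-of-range index can touch memory. The sketch below shows the same pattern in isolation; copy_once and scatter_add_rows are hypothetical stand-ins, not TensorFlow APIs.

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Hypothetical stand-in for internal::SubtleMustCopy: read the untrusted
// index once through a volatile view so it cannot change between the check
// and the use.
inline int64_t copy_once(const int64_t& x) {
  return *reinterpret_cast<const volatile int64_t*>(&x);
}

// Scatter-add values into dense output rows, validating each row index before
// it is used (analogous to the CHECK_LT calls in the functor above).
void scatter_add_rows(const std::vector<int64_t>& row_indices,
                      const std::vector<float>& values,
                      std::vector<float>* out) {
  const int64_t num_rows = static_cast<int64_t>(out->size());
  for (std::size_t i = 0; i < row_indices.size(); ++i) {
    const int64_t m = copy_once(row_indices[i]);
    if (m < 0 || m >= num_rows) std::abort();  // CHECK-style: fail loudly
    (*out)[static_cast<std::size_t>(m)] += values[i];
  }
}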
20 changes: 17 additions & 3 deletions tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -40,6 +40,7 @@ class SparseTensorDenseMatMulGPUGenerator {
rhs_index_a_(ADJ_A ? 0 : 1),
a_indices_(a_indices),
a_values_(a_values),
lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)),
maybe_adjoint_b_(
functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit,
ADJ_B>(b)) {}
@@ -49,9 +50,21 @@ class SparseTensorDenseMatMulGPUGenerator {
#ifdef __CUDA_ARCH__
const int j = j_and_ix[0];
const int ix = j_and_ix[1];
const int m = a_indices_(ix, lhs_index_a_);
const int k = a_indices_(ix, rhs_index_a_);
const T b_value = maybe_adjoint_b_(k, j);
int m = a_indices_(ix, lhs_index_a_);
int k = a_indices_(ix, rhs_index_a_);
assert(k < lhs_right_size);
assert(m < out_.dimension(0));
// If asserts are disabled, the caller is violating the sparse
// tensor index contract, and so we return invalid results.
// Force returning NaNs to try to signal that something is amiss.
T b_value;
if (k >= lhs_right_size || m >= out_.dimension(0)) {
m = 0;
k = 0;
b_value = std::numeric_limits<T>::quiet_NaN();
} else {
b_value = maybe_adjoint_b_(k, j);
}
atomicAdd(&out_(m, j), a_values_(ix) * b_value);
#else
assert(false && "This should only be run on the device");
@@ -66,6 +79,7 @@ class SparseTensorDenseMatMulGPUGenerator {
const int rhs_index_a_;
TTypes<const int64, 2>::Tensor32Bit a_indices_;
typename TTypes<const T, 1>::Tensor32Bit a_values_;
const int lhs_right_size;
functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
maybe_adjoint_b_;
};
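Device code cannot report an error status, and assert() is compiled out in release builds, so the generator above falls back to clamping a bad index to 0 and multiplying by NaN: the affected output entries become visibly poisoned instead of the kernel reading out-of-bounds memory. A host-side C++ analogue of that fallback is sketched below; accumulate_entry is an illustrative name, not the kernel's API.

#include <iostream>
#include <limits>
#include <vector>

// Host-side analogue of the GPU fallback (the real kernel also asserts in
// debug builds): an out-of-range row index is redirected to slot 0 and its
// contribution forced to NaN, so the corruption shows up in the result
// instead of becoming an out-of-bounds memory access.
void accumulate_entry(std::vector<float>& out, int m, float a_value,
                      float b_value) {
  if (m < 0 || m >= static_cast<int>(out.size())) {
    m = 0;
    b_value = std::numeric_limits<float>::quiet_NaN();
  }
  out[m] += a_value * b_value;  // NaN propagates, flagging the bad result
}

int main() {
  std::vector<float> out(4, 0.0f);
  accumulate_entry(out, 2, 2.0f, 3.0f);  // valid index: out[2] += 6
  accumulate_entry(out, 9, 1.0f, 1.0f);  // invalid index: poisons out[0]
  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';  // prints: nan 0 6 0
  return 0;
}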
