1 change: 1 addition & 0 deletions tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"

#include <algorithm>
#include <limits>
#include <unordered_set>

#include "absl/algorithm/container.h"
1 change: 1 addition & 0 deletions tensorflow/compiler/jit/graphcycles/graphcycles.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_

#include <limits>
#include <vector>

// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -513,7 +513,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.matmul,
mkl_op_registry::GetMklOpName(csinfo_.matmul),
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.leakyrelu,
mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
128 changes: 90 additions & 38 deletions tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -19,11 +19,15 @@ limitations under the License.

#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/threadpool.h"

namespace tensorflow {

@@ -124,20 +128,20 @@ class SparseTensorDenseMatMulOp : public OpKernel {
return;
}

functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), out->flat<T>());
if (a_values->NumElements() == 0 || b->NumElements() == 0) {
// If a has shape [x, 0] and b has shape [0, y], the
// output shape is [x, y] where x and y are non-zero, so we fill
// the output with zeros.
functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), out->flat<T>());
return;
}

#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
Device, T, Tindices, ADJ_A, \
ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
ADJ_B>::Compute(ctx, ctx->eigen_device<Device>(), out->matrix<T>(), \
a_indices->matrix<Tindices>(), a_values->vec<T>(), \
b->matrix<T>()); \
OP_REQUIRES_OK(ctx, functor_status); \
@@ -183,7 +187,7 @@ namespace functor {
template <> \
Status SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, \
ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
ADJ_B>::Compute(OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::Matrix out, \
TTypes<Tindices>::ConstMatrix a_indices, \
typename TTypes<T>::ConstVec a_values, \
typename TTypes<T>::ConstMatrix b); \
@@ -245,8 +249,10 @@ template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
// Vectorize certain operations above this size.
static constexpr std::size_t kNumVectorize = 32;

static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using MatrixMap = Eigen::Map<Matrix>;
static Status Compute(OpKernelContext* ctx, const CPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
typename TTypes<T>::ConstMatrix b) {
@@ -255,8 +261,17 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
const int lhs_index_a = ADJ_A ? 1 : 0;
const int rhs_index_a = ADJ_A ? 0 : 1;
static constexpr int32 kMinShards = 10;

out.setZero();

auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
const int32 num_threads = worker_threads.num_threads;
const int32 total_size = out.dimension(0);
VLOG(3) << "nnz=" << nnz << ", lhs_index_a=" << lhs_index_a
        << ", rhs_index_a=" << rhs_index_a << ", lhs_right=" << lhs_right
        << ", rhs_right=" << rhs_right << ", total_size=" << total_size;
const int64 block_size = std::max(4096, total_size / num_threads);

// TODO(ebrevdo): After many failed experiments, can't find a multi-threaded
// approach that achieves the performance of the single threaded
@@ -266,51 +281,88 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
// Disable vectorization if the RHS of output is too small
auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);

for (std::size_t i = 0; i < nnz; ++i) {
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
if (!FastBoundsCheck(k, lhs_right)) {
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
auto lambda = [&](Tindices block_begin, Tindices block_end, int tid) {
VLOG(3) << "block_begin=" << block_begin << ", block_end=" <<
block_end << ", tid=" << tid;

for (std::size_t i = 0; i < nnz; ++i) {
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
if (m < block_begin || m >= block_end) {
continue;
}
if (!FastBoundsCheck(k, lhs_right)) {
LOG(ERROR) << KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
continue;
}
if (!FastBoundsCheck(m, out.dimension(0))) {
LOG(ERROR) << MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
continue;
}
const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
for (std::size_t n = 0; n < rhs_right; ++n) {
const T b_value = maybe_adjoint_b(k, n);
out(m, n) += a_value * b_value;
}
}
if (!FastBoundsCheck(m, out.dimension(0))) {
return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
}
const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
for (std::size_t n = 0; n < rhs_right; ++n) {
const T b_value = maybe_adjoint_b(k, n);
out(m, n) += a_value * b_value;
}
}
return;
};
worker_threads.workers->ParallelForWithWorkerId(
total_size /* total */,
thread::ThreadPool::SchedulingParams(
thread::ThreadPool::SchedulingStrategy::
kFixedBlockSize /* strategy */,
absl::nullopt /* cost_per_unit */, block_size),
lambda
);

} else {
// Vectorization via Eigen.
const int b_chip_index = ADJ_B ? 1 : 0;

#define LOOP_NNZ(b_passed) \
for (std::size_t i = 0; i < nnz; ++i) { \
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \
if (!FastBoundsCheck(k, lhs_right)) { \
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
} \
if (!FastBoundsCheck(m, out.dimension(0))) { \
return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
} \
out.template chip<0>(m) += \
b_passed.template chip<b_chip_index>(k) * a_value; \
}
#define LOOP_NNZ_PARALLEL(b_passed) \
auto lambda = [&](Tindices block_begin, Tindices block_end, int tid) { \
VLOG(3) << "block_begin=" << block_begin << ", block_end=" << \
block_end << ", tid=" << tid; \
for (std::size_t i = 0; i < nnz; ++i) { \
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
if (m < block_begin || m >= block_end) { \
continue; \
} \
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
if (!FastBoundsCheck(k, lhs_right)) { \
LOG(ERROR) << KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
continue; \
} \
if (!FastBoundsCheck(m, out.dimension(0))) { \
LOG(ERROR) << MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
continue; \
} \
const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i); \
out.template chip<0>(m) += b_passed.template chip<b_chip_index>(k) * a_value; \
} \
return; \
}; \
worker_threads.workers->ParallelForWithWorkerId( \
total_size /* total */, \
thread::ThreadPool::SchedulingParams( \
thread::ThreadPool::SchedulingStrategy:: \
kFixedBlockSize /* strategy */, \
absl::nullopt /* cost_per_unit */, block_size), \
lambda \
);

if (ADJ_B) {
// Perform transpose and conjugation on B once, since we chip out B's
// columns in the nnz loop.
Eigen::array<int, 2> shuffle(1, 0); // preserve dimension order
Eigen::Tensor<T, 2, Eigen::ColMajor> col_major_conj_b =
b.swap_layout().shuffle(shuffle).conjugate();
LOOP_NNZ(col_major_conj_b);
LOOP_NNZ_PARALLEL(col_major_conj_b);
} else {
LOOP_NNZ(b);
LOOP_NNZ_PARALLEL(b);
}
#undef LOOP_NNZ
#undef LOOP_NNZ_PARALLEL
}
return Status::OK();
}
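The core of the CPU change above is row-block partitioning: output rows are split into fixed-size blocks, each block is handed to one worker, and every worker scans the full nonzero list but skips entries whose row falls outside its block, so no two workers ever write the same output row and no locking is needed. A minimal standalone sketch of that idea, using std::thread in place of TensorFlow's ThreadPool::ParallelForWithWorkerId and an illustrative Coo struct (both are assumptions for this sketch, not TensorFlow APIs):

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// One nonzero of the sparse matrix A: A(m, k) = v.
struct Coo { int64_t m; int64_t k; float v; };

// Sketch only: out is dense [rows x n] row-major, b is dense [kdim x n]
// row-major. Each worker owns a contiguous block of output rows, scans the
// whole nonzero list, and accumulates only entries whose row is in its
// block, so writes never overlap.
void SparseDenseMatMulBlocked(const std::vector<Coo>& a, const float* b,
                              float* out, int64_t rows, int64_t n,
                              int num_threads) {
  std::fill(out, out + rows * n, 0.0f);  // mirrors out.setZero() in the diff
  const int64_t block =
      std::max<int64_t>(4096, rows / std::max(num_threads, 1));
  std::vector<std::thread> workers;
  for (int64_t begin = 0; begin < rows; begin += block) {
    const int64_t end = std::min(begin + block, rows);
    workers.emplace_back([&, begin, end] {
      for (const Coo& e : a) {                    // full scan per worker
        if (e.m < begin || e.m >= end) continue;  // row outside my block
        for (int64_t j = 0; j < n; ++j) {
          out[e.m * n + j] += e.v * b[e.k * n + j];
        }
      }
    });
  }
  for (auto& w : workers) w.join();
}

The trade-off, consistent with the kFixedBlockSize scheduling and the 4096-row floor in the diff, is that every block costs a full pass over the nnz entries; a larger minimum block size bounds the number of redundant scans at the price of less parallelism.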
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

@@ -29,6 +30,7 @@ template <typename Device, typename T, typename Tindices, bool ADJ_A,
bool ADJ_B>
struct SparseTensorDenseMatMulFunctor {
static EIGEN_ALWAYS_INLINE Status Compute(
OpKernelContext* ctx,
const Device& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
@@ -63,7 +63,7 @@ namespace functor {
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
static EIGEN_ALWAYS_INLINE Status
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
Compute(OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
typename TTypes<T>::ConstMatrix b) {
44 changes: 42 additions & 2 deletions tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
@@ -83,9 +83,11 @@ static Graph* SparseTensorDenseMatmul(int nnz, int m, int k, int n,
BM_SparseTensorDenseMatmul##_##NNZ##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE);

#define BM_SparseTensorDenseMatmul(NNZ, M, K, N, TA, TB) \
BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, cpu); \
BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, gpu);
BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, cpu);

// BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, cpu); \
// BM_SparseTensorDenseMatmulDev(NNZ, M, K, N, TA, TB, gpu);
/*
BM_SparseTensorDenseMatmul(128, 8, 512, 1, false, false);
BM_SparseTensorDenseMatmul(128, 16, 512, 1, false, false);
BM_SparseTensorDenseMatmul(128, 128, 512, 1, false, false);
Expand Down Expand Up @@ -117,4 +119,42 @@ BM_SparseTensorDenseMatmul(16384, 4096, 4096, 4096, false, true);
BM_SparseTensorDenseMatmul(16384, 4096, 4096, 4096, true, false);
BM_SparseTensorDenseMatmul(16384, 4096, 4096, 4096, true, true);

BM_SparseTensorDenseMatmul(10240, 10240, 150000, 16, false, false);
BM_SparseTensorDenseMatmul(10240, 10240, 150000, 16, true, false);
BM_SparseTensorDenseMatmul(20480, 20480, 150000, 16, false, false);
BM_SparseTensorDenseMatmul(40960, 40960, 150000, 16, false, false);
BM_SparseTensorDenseMatmul(81920, 81920, 150000, 16, false, false);
BM_SparseTensorDenseMatmul(163840, 163840, 150000, 16, false, false);

BM_SparseTensorDenseMatmul(10240, 10240, 150000, 32, false, false);
BM_SparseTensorDenseMatmul(10240, 10240, 150000, 32, true, false);
BM_SparseTensorDenseMatmul(20480, 20480, 150000, 32, false, false);
BM_SparseTensorDenseMatmul(40960, 40960, 150000, 32, false, false);
BM_SparseTensorDenseMatmul(81920, 81920, 150000, 32, false, false);
BM_SparseTensorDenseMatmul(163840, 163840, 150000, 32, false, false);

BM_SparseTensorDenseMatmul(10240, 10240, 150000, 64, false, false);
BM_SparseTensorDenseMatmul(10240, 10240, 150000, 64, true, false);
BM_SparseTensorDenseMatmul(20480, 20480, 150000, 64, false, false);
BM_SparseTensorDenseMatmul(40960, 40960, 150000, 64, false, false);
BM_SparseTensorDenseMatmul(81920, 81920, 150000, 64, false, false);
BM_SparseTensorDenseMatmul(81920, 81920, 150000, 64, true, false);
BM_SparseTensorDenseMatmul(163840, 163840, 150000, 64, false, false);
BM_SparseTensorDenseMatmul(163840, 163840, 150000, 64, true, false);

BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 16, false, false);
BM_SparseTensorDenseMatmul(10240, 1250000, 10240, 16, true, false);
BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 32, false, false);
BM_SparseTensorDenseMatmul(10240, 1250000, 10240, 32, true, false);
BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 64, false, false);
BM_SparseTensorDenseMatmul(10240, 1250000, 10240, 64, true, false);
BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 96, false, false);
BM_SparseTensorDenseMatmul(10240, 1250000, 10240, 96, true, false);
*/

BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 192, false, false);
BM_SparseTensorDenseMatmul(10240, 10240, 1250000, 192, true, false);
BM_SparseTensorDenseMatmul(10240, 500000, 1024, 192, false, false);
BM_SparseTensorDenseMatmul(10240, 500000, 1024, 192, true, false);

} // end namespace tensorflow
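As a back-of-the-envelope check on how the new benchmark shapes shard under the CPU functor's formula block_size = max(4096, M / num_threads), here is a small sketch; the 16-thread pool is an assumed example value, since the real count comes from ctx->device()->tensorflow_cpu_worker_threads():

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int32_t num_threads = 16;  // assumed for illustration
  for (int32_t total_size : {10240, 500000}) {  // M values from the benchmarks
    // Same arithmetic as the CPU functor in sparse_tensor_dense_matmul_op.cc.
    const int64_t block_size = std::max<int32_t>(4096, total_size / num_threads);
    const int64_t num_blocks = (total_size + block_size - 1) / block_size;
    std::cout << "rows=" << total_size << " block_size=" << block_size
              << " blocks=" << num_blocks << "\n";
  }
}
// rows=10240  -> block_size=4096  -> 3 blocks (most of the pool idle)
// rows=500000 -> block_size=31250 -> 16 blocks (all threads busy)

This is why the M=500000 benchmarks exercise the parallel path fully, while the M=10240 shapes hit the 4096-row floor and leave most workers without a block.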