
Commit b8aeb79

mrkn authored and wesm committed
ARROW-854: [Format] Add tentative SparseTensor format
I'm interested in making a language-agnostic sparse tensor format. I believe one of the suitable places to do this is Apache Arrow, so let me propose my idea here.

First of all, my investigation found that there is no common memory layout for sparse tensor representations. This means we need some kind of conversion to share sparse tensors among different systems, even when the data is logically the same. It is the same situation as with dataframes, and this is why I believe Apache Arrow is the suitable place.

There are many formats for representing a sparse tensor. Most of them are specialized for matrices, which have two dimensions; there are few formats for general sparse tensors with more than two dimensions. I think the COO format is a suitable starting point because COO can handle any number of dimensions, and many systems support it. In my investigation, the systems that support COO are SciPy, dask, pydata/sparse, TensorFlow, and PyTorch.

Additionally, the CSR format for matrices may also be worth supporting from the start. CSR is efficient for extracting row slices, which may be important for extracting samples from tidy data, and it is supported by SciPy, MXNet, and R's Matrix library.

This pull request adds my prototype definition of the SparseTensor format. I designed this prototype format to be extensible so that we can support additional sparse formats. At minimum we will need an additional sparse tensor format for more than two dimensions beyond COO, so we will need this extensibility.

Author: Kenta Murata <mrkn@mrkn.jp>

Closes #2546 from mrkn/sparse_tensor_proposal and squashes the following commits:

148bff8 <Kenta Murata> make format
d57e56f <Kenta Murata> Merge sparse_tensor_format.h into sparse_tensor.h
880bbc4 <Kenta Murata> Rename too-verbose function name
c83ea6a <Kenta Murata> Add type aliases of sparse tensor types
90e8b31 <Kenta Murata> Rename sparse tensor classes
07a6518 <Kenta Murata> Use substitution instead of constructor call
37a0a14 <Kenta Murata> Remove needless function declaration
97e85bd <Kenta Murata> Use std::make_shared
3dd434c <Kenta Murata> Capitalize member function name
6ef6ad0 <Kenta Murata> Apply code formatter
6f29158 <Kenta Murata> Mark APIs for sparse tensor as EXPERIMENTAL
ff3ea71 <Kenta Murata> Rename length to non_zero_length in SparseTensor
f782303 <Kenta Murata> Return Status::IOError instead of DCHECK if message header type is not matched
7e814de <Kenta Murata> Put EXPERIMENTAL mark in comments
357860d <Kenta Murata> Fix typo in comments
43d8eea <Kenta Murata> Fix coding style
99b1d1d <Kenta Murata> Add missing ARROW_EXPORT specifiers
401ae80 <Kenta Murata> Fix SparseCSRIndex::ToString and add tests
9e457ac <Kenta Murata> Remove needless virtual specifiers
3b1db7d <Kenta Murata> Add SparseTensorBase::Equals
d6a8c38 <Kenta Murata> Unify Tensor.fbs and SparseTensor.fbs
b3a62eb <Kenta Murata> Fix format
6bc9e29 <Kenta Murata> Support IPC read and write of SparseTensor
1d90427 <Kenta Murata> Fix format
51a83bf <Kenta Murata> Add SparseTensorFormat
93c03ad <Kenta Murata> Add SparseIndex::ToString()
021b46b <Kenta Murata> Add SparseTensorBase
ed3984d <Kenta Murata> Add SparseIndex::format_type
4251b4d <Kenta Murata> Add SparseCSRIndex
433c9b4 <Kenta Murata> Change COO index matrix to column-major in a format description
392a25b <Kenta Murata> Implement SparseTensor and SparseCOOIndex
b24f3c3 <Kenta Murata> Insert additional padding in sparse tensor format
c508db0 <Kenta Murata> Write sparse tensor format in IPC.md
2b50040 <Kenta Murata> Add an example of the CSR format in comment
76c56dd <Kenta Murata> Make indptr of CSR a buffer
d7e653f <Kenta Murata> Add an example of COO format in comment
866b2c1 <Kenta Murata> Add header comments in SparseTensor.fbs
aa9b8a4 <Kenta Murata> Add SparseTensor.fbs in FBS_SRC
1f16ffe <Kenta Murata> Fix syntax error in SparseTensor.fbs
c3bc6ed <Kenta Murata> Add tentative SparseTensor format
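For readers unfamiliar with the two layouts the proposal starts from, here is a small self-contained sketch (not part of this commit, and not the Arrow SparseTensor memory layout; the names and the coordinate ordering are illustrative only) showing the same 4x4 matrix encoded in COO and CSR form:

#include <cstdint>
#include <vector>

// Dense matrix used for illustration:
//   | 1 0 0 2 |
//   | 0 0 3 0 |
//   | 0 0 0 0 |
//   | 0 4 0 5 |

// COO: one (row, col) coordinate pair per non-zero value.
// Generalizes to N dimensions by storing N coordinates per value.
std::vector<int64_t> coo_coords = {0, 0,  0, 3,  1, 2,  3, 1,  3, 3};
std::vector<double>  coo_values = {1, 2, 3, 4, 5};

// CSR: matrices only. indptr has (rows + 1) entries; row r's non-zeros live in
// values[indptr[r] .. indptr[r+1]) with their column numbers in indices,
// which is why row slices are cheap to extract.
std::vector<int64_t> csr_indptr  = {0, 2, 3, 3, 5};
std::vector<int64_t> csr_indices = {0, 3, 2, 1, 3};
std::vector<double>  csr_values  = {1, 2, 3, 4, 5};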
1 parent 6b496f7 commit b8aeb79

File tree: 19 files changed, +1661 -3 lines


cpp/src/arrow/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,7 @@ set(ARROW_SRCS
   table.cc
   table_builder.cc
   tensor.cc
+  sparse_tensor.cc
   type.cc
   visitor.cc

@@ -286,6 +287,7 @@ ADD_ARROW_TEST(type-test)
 ADD_ARROW_TEST(table-test)
 ADD_ARROW_TEST(table_builder-test)
 ADD_ARROW_TEST(tensor-test)
+ADD_ARROW_TEST(sparse_tensor-test)

 ADD_ARROW_BENCHMARK(builder-benchmark)
 ADD_ARROW_BENCHMARK(column-benchmark)

cpp/src/arrow/compare.cc

Lines changed: 93 additions & 0 deletions
@@ -30,6 +30,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/sparse_tensor.h"
 #include "arrow/status.h"
 #include "arrow/tensor.h"
 #include "arrow/type.h"
@@ -782,6 +783,98 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
   return are_equal;
 }
 
+namespace {
+
+template <typename LeftSparseIndexType, typename RightSparseIndexType>
+struct SparseTensorEqualsImpl {
+  static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
+                      const SparseTensorImpl<RightSparseIndexType>& right) {
+    // TODO(mrkn): should we support the equality among different formats?
+    return false;
+  }
+};
+
+template <typename SparseIndexType>
+struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
+  static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
+                      const SparseTensorImpl<SparseIndexType>& right) {
+    DCHECK(left.type()->id() == right.type()->id());
+    DCHECK(left.shape() == right.shape());
+    DCHECK(left.non_zero_length() == right.non_zero_length());
+
+    const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
+    const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
+
+    if (!left_index.Equals(right_index)) {
+      return false;
+    }
+
+    const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
+    const int byte_width = size_meta.bit_width() / CHAR_BIT;
+    DCHECK_GT(byte_width, 0);
+
+    const uint8_t* left_data = left.data()->data();
+    const uint8_t* right_data = right.data()->data();
+
+    return memcmp(left_data, right_data,
+                  static_cast<size_t>(byte_width * left.non_zero_length()));
+  }
+};
+
+template <typename SparseIndexType>
+inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
+                                           const SparseTensor& right) {
+  switch (right.format_id()) {
+    case SparseTensorFormat::COO: {
+      const auto& right_coo =
+          checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
+      return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(left,
+                                                                              right_coo);
+    }
+
+    case SparseTensorFormat::CSR: {
+      const auto& right_csr =
+          checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
+      return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(left,
+                                                                              right_csr);
+    }
+
+    default:
+      return false;
+  }
+}
+
+}  // namespace
+
+bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
+  if (&left == &right) {
+    return true;
+  } else if (left.type()->id() != right.type()->id()) {
+    return false;
+  } else if (left.size() == 0) {
+    return true;
+  } else if (left.shape() != right.shape()) {
+    return false;
+  } else if (left.non_zero_length() != right.non_zero_length()) {
+    return false;
+  }
+
+  switch (left.format_id()) {
+    case SparseTensorFormat::COO: {
+      const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
+      return SparseTensorEqualsImplDispatch(left_coo, right);
+    }
+
+    case SparseTensorFormat::CSR: {
+      const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
+      return SparseTensorEqualsImplDispatch(left_csr, right);
    }
+
+    default:
+      return false;
+  }
+}
+
 bool TypeEquals(const DataType& left, const DataType& right) {
   bool are_equal;
   // The arrays are the same object

cpp/src/arrow/compare.h

Lines changed: 4 additions & 0 deletions
@@ -29,12 +29,16 @@ namespace arrow {
 class Array;
 class DataType;
 class Tensor;
+class SparseTensor;
 
 /// Returns true if the arrays are exactly equal
 bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right);
 
 bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right);
 
+/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal
+bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left, const SparseTensor& right);
+
 /// Returns true if the arrays are approximately equal. For non-floating point
 /// types, this is equivalent to ArrayEquals(left, right)
 bool ARROW_EXPORT ArrayApproxEquals(const Array& left, const Array& right);
cpp/src/arrow/ipc/message.cc

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ class Message::MessageImpl {
         return Message::RECORD_BATCH;
       case flatbuf::MessageHeader_Tensor:
         return Message::TENSOR;
+      case flatbuf::MessageHeader_SparseTensor:
+        return Message::SPARSE_TENSOR;
       default:
         return Message::NONE;
     }

cpp/src/arrow/ipc/message.h

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ constexpr int kMaxNestingDepth = 64;
 /// \brief An IPC message including metadata and body
 class ARROW_EXPORT Message {
  public:
-  enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH, TENSOR };
+  enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH, TENSOR, SPARSE_TENSOR };
 
   /// \brief Construct message, but do not validate
   ///
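A brief, hypothetical sketch of how consumer code might branch on the new enum value when handling a decoded IPC message (only Message::Type and SPARSE_TENSOR come from the change above; the surrounding dispatch function is an assumption for illustration):

#include "arrow/ipc/message.h"

// Illustrative dispatch only: decide how to handle a decoded IPC message.
void HandleMessage(const arrow::ipc::Message& message) {
  switch (message.type()) {
    case arrow::ipc::Message::RECORD_BATCH:
      // ... existing record batch path ...
      break;
    case arrow::ipc::Message::SPARSE_TENSOR:
      // New in this commit: the message body carries a sparse tensor payload.
      break;
    default:
      break;
  }
}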

cpp/src/arrow/ipc/metadata-internal.cc

Lines changed: 148 additions & 0 deletions
@@ -31,6 +31,7 @@
 #include "arrow/ipc/Tensor_generated.h"  // IWYU pragma: keep
 #include "arrow/ipc/message.h"
 #include "arrow/ipc/util.h"
+#include "arrow/sparse_tensor.h"
 #include "arrow/status.h"
 #include "arrow/tensor.h"
 #include "arrow/type.h"
@@ -50,6 +51,7 @@ using DictionaryOffset = flatbuffers::Offset<flatbuf::DictionaryEncoding>;
 using FieldOffset = flatbuffers::Offset<flatbuf::Field>;
 using KeyValueOffset = flatbuffers::Offset<flatbuf::KeyValue>;
 using RecordBatchOffset = flatbuffers::Offset<flatbuf::RecordBatch>;
+using SparseTensorOffset = flatbuffers::Offset<flatbuf::SparseTensor>;
 using Offset = flatbuffers::Offset<void>;
 using FBString = flatbuffers::Offset<flatbuffers::String>;
 
@@ -781,6 +783,106 @@ Status WriteTensorMessage(const Tensor& tensor, int64_t buffer_start_offset,
                         body_length, out);
 }
 
+Status MakeSparseTensorIndexCOO(FBB& fbb, const SparseCOOIndex& sparse_index,
+                                const std::vector<BufferMetadata>& buffers,
+                                flatbuf::SparseTensorIndex* fb_sparse_index_type,
+                                Offset* fb_sparse_index, size_t* num_buffers) {
+  *fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseTensorIndexCOO;
+  const BufferMetadata& indices_metadata = buffers[0];
+  flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length);
+  *fb_sparse_index = flatbuf::CreateSparseTensorIndexCOO(fbb, &indices).Union();
+  *num_buffers = 1;
+  return Status::OK();
+}
+
+Status MakeSparseMatrixIndexCSR(FBB& fbb, const SparseCSRIndex& sparse_index,
+                                const std::vector<BufferMetadata>& buffers,
+                                flatbuf::SparseTensorIndex* fb_sparse_index_type,
+                                Offset* fb_sparse_index, size_t* num_buffers) {
+  *fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseMatrixIndexCSR;
+  const BufferMetadata& indptr_metadata = buffers[0];
+  const BufferMetadata& indices_metadata = buffers[1];
+  flatbuf::Buffer indptr(indptr_metadata.offset, indptr_metadata.length);
+  flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length);
+  *fb_sparse_index = flatbuf::CreateSparseMatrixIndexCSR(fbb, &indptr, &indices).Union();
+  *num_buffers = 2;
+  return Status::OK();
+}
+
+Status MakeSparseTensorIndex(FBB& fbb, const SparseIndex& sparse_index,
+                             const std::vector<BufferMetadata>& buffers,
+                             flatbuf::SparseTensorIndex* fb_sparse_index_type,
+                             Offset* fb_sparse_index, size_t* num_buffers) {
+  switch (sparse_index.format_id()) {
+    case SparseTensorFormat::COO:
+      RETURN_NOT_OK(MakeSparseTensorIndexCOO(
+          fbb, checked_cast<const SparseCOOIndex&>(sparse_index), buffers,
+          fb_sparse_index_type, fb_sparse_index, num_buffers));
+      break;
+
+    case SparseTensorFormat::CSR:
+      RETURN_NOT_OK(MakeSparseMatrixIndexCSR(
+          fbb, checked_cast<const SparseCSRIndex&>(sparse_index), buffers,
+          fb_sparse_index_type, fb_sparse_index, num_buffers));
+      break;
+
+    default:
+      std::stringstream ss;
+      ss << "Unsupporoted sparse tensor format:: " << sparse_index.ToString()
+         << std::endl;
+      return Status::NotImplemented(ss.str());
+  }
+
+  return Status::OK();
+}
+
+Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor, int64_t body_length,
+                        const std::vector<BufferMetadata>& buffers,
+                        SparseTensorOffset* offset) {
+  flatbuf::Type fb_type_type;
+  Offset fb_type;
+  RETURN_NOT_OK(
+      TensorTypeToFlatbuffer(fbb, *sparse_tensor.type(), &fb_type_type, &fb_type));
+
+  using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
+  std::vector<TensorDimOffset> dims;
+  for (int i = 0; i < sparse_tensor.ndim(); ++i) {
+    FBString name = fbb.CreateString(sparse_tensor.dim_name(i));
+    dims.push_back(flatbuf::CreateTensorDim(fbb, sparse_tensor.shape()[i], name));
+  }
+
+  auto fb_shape = fbb.CreateVector(dims);
+
+  flatbuf::SparseTensorIndex fb_sparse_index_type;
+  Offset fb_sparse_index;
+  size_t num_index_buffers = 0;
+  RETURN_NOT_OK(MakeSparseTensorIndex(fbb, *sparse_tensor.sparse_index(), buffers,
+                                      &fb_sparse_index_type, &fb_sparse_index,
+                                      &num_index_buffers));
+
+  const BufferMetadata& data_metadata = buffers[num_index_buffers];
+  flatbuf::Buffer data(data_metadata.offset, data_metadata.length);
+
+  const int64_t non_zero_length = sparse_tensor.non_zero_length();
+
+  *offset =
+      flatbuf::CreateSparseTensor(fbb, fb_type_type, fb_type, fb_shape, non_zero_length,
+                                  fb_sparse_index_type, fb_sparse_index, &data);
+
+  return Status::OK();
+}
+
+Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_length,
+                                const std::vector<BufferMetadata>& buffers,
+                                std::shared_ptr<Buffer>* out) {
+  FBB fbb;
+  SparseTensorOffset fb_sparse_tensor;
+  RETURN_NOT_OK(
+      MakeSparseTensor(fbb, sparse_tensor, body_length, buffers, &fb_sparse_tensor));
+  return WriteFBMessage(fbb, flatbuf::MessageHeader_SparseTensor,
+                        fb_sparse_tensor.Union(), body_length, out);
+}
+
 Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length,
                               const std::vector<FieldMetadata>& nodes,
                               const std::vector<BufferMetadata>& buffers,
@@ -933,6 +1035,52 @@ Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type
   return TypeFromFlatbuffer(tensor->type_type(), tensor->type(), {}, type);
 }
 
+Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+                               std::vector<int64_t>* shape,
+                               std::vector<std::string>* dim_names,
+                               int64_t* non_zero_length,
+                               SparseTensorFormat::type* sparse_tensor_format_id) {
+  auto message = flatbuf::GetMessage(metadata.data());
+  if (message->header_type() != flatbuf::MessageHeader_SparseTensor) {
+    return Status::IOError("Header of flatbuffer-encoded Message is not SparseTensor.");
+  }
+  if (message->header() == nullptr) {
+    return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
+  }
+
+  auto sparse_tensor = reinterpret_cast<const flatbuf::SparseTensor*>(message->header());
+  int ndim = static_cast<int>(sparse_tensor->shape()->size());
+
+  for (int i = 0; i < ndim; ++i) {
+    auto dim = sparse_tensor->shape()->Get(i);
+
+    shape->push_back(dim->size());
+    auto fb_name = dim->name();
+    if (fb_name == 0) {
+      dim_names->push_back("");
+    } else {
+      dim_names->push_back(fb_name->str());
+    }
+  }
+
+  *non_zero_length = sparse_tensor->non_zero_length();
+
+  switch (sparse_tensor->sparseIndex_type()) {
+    case flatbuf::SparseTensorIndex_SparseTensorIndexCOO:
+      *sparse_tensor_format_id = SparseTensorFormat::COO;
+      break;
+
+    case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
+      *sparse_tensor_format_id = SparseTensorFormat::CSR;
+      break;
+
+    default:
+      return Status::Invalid("Unrecognized sparse index type");
+  }
+
+  return TypeFromFlatbuffer(sparse_tensor->type_type(), sparse_tensor->type(), {}, type);
+}
+
 // ----------------------------------------------------------------------
 // Implement message writing
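To make the metadata path above concrete, here is a minimal sketch of consuming GetSparseTensorMetadata (assumptions: the wrapper function is hypothetical, `metadata` already holds a flatbuffer-encoded SparseTensor message, and the call site lives in arrow::ipc::internal where this function and RETURN_NOT_OK are visible):

Status DescribeSparseTensorMessage(const Buffer& metadata) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length = 0;
  SparseTensorFormat::type format_id;

  // Parses type, shape, dimension names, non-zero count, and index format
  // out of the flatbuffer header, as implemented above.
  RETURN_NOT_OK(GetSparseTensorMetadata(metadata, &type, &shape, &dim_names,
                                        &non_zero_length, &format_id));

  // format_id is now SparseTensorFormat::COO or SparseTensorFormat::CSR.
  return Status::OK();
}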

cpp/src/arrow/ipc/metadata-internal.h

Lines changed: 12 additions & 0 deletions
@@ -33,13 +33,15 @@
 #include "arrow/ipc/dictionary.h"  // IYWU pragma: keep
 #include "arrow/ipc/message.h"
 #include "arrow/memory_pool.h"
+#include "arrow/sparse_tensor.h"
 #include "arrow/status.h"
 
 namespace arrow {
 
 class DataType;
 class Schema;
 class Tensor;
+class SparseTensor;
 
 namespace flatbuf = org::apache::arrow::flatbuf;
 
@@ -103,6 +105,12 @@ Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type
                          std::vector<int64_t>* shape, std::vector<int64_t>* strides,
                          std::vector<std::string>* dim_names);
 
+// EXPERIMENTAL: Extracting metadata of a sparse tensor from the message
+Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+                               std::vector<int64_t>* shape,
+                               std::vector<std::string>* dim_names, int64_t* length,
+                               SparseTensorFormat::type* sparse_tensor_format_id);
+
 /// Write a serialized message metadata with a length-prefix and padding to an
 /// 8-byte offset. Does not make assumptions about whether the stream is
 /// aligned already
@@ -137,6 +145,10 @@ Status WriteRecordBatchMessage(const int64_t length, const int64_t body_length,
 Status WriteTensorMessage(const Tensor& tensor, const int64_t buffer_start_offset,
                           std::shared_ptr<Buffer>* out);
 
+Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_length,
+                                const std::vector<BufferMetadata>& buffers,
+                                std::shared_ptr<Buffer>* out);
+
 Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dictionaries,
                        const std::vector<FileBlock>& record_batches,
                        DictionaryMemo* dictionary_memo, io::OutputStream* out);
