Skip to content

Commit cdaf988

Browse files
rokmrkn
authored and committed
ARROW-4223: [Python] Support scipy.sparse integration
This is to resolve [ARROW-4223](https://issues.apache.org/jira/browse/ARROW-4223). Closes #4779 from rok/ARROW-4223 and squashes the following commits: eca3885 <Rok> Adding type check to from_scipy. d5484bf <Rok> Fixing scipy->sparse_tensor tests for dtype=f2. 4d7d2b0 <Rok> Implementing review feedback. 22e864d <Rok> Fixing deserialization issue. Rebasing for new tensor type names. 927bf5d <Kenta Murata> Add SparseCSRIndex::Make b06429d <Kenta Murata> Add SparseCOOIndex::Make 90cbadf <Kenta Murata> Extract a common part from ReadSparseTensorPayload and ReadSparseTensor f4d2e1e <Rok> Enabling serialization with pydata/sparse. db1fb5a <Rok> Applying review feedback for python tests. 2bc2534 <Rok> Re-enabling test_sparse_tensor_coo_components_serialization. 11068eb <Rok> Adding from_scipy and to scipy methods to SparseTensorCOO and SparseTensorCSR. b13604d <Rok> Temporarily disabling test_sparse_tensor_csr_components_serialization test. 56df620 <Kenta Murata> Prevent copying buffers on component serialization of a SparseTensor 958b354 <Rok> Changes to GetSparseTensorMessage to enable SparseTensor to components serialization. c54f005 <Rok> Changes to GetSparseTensorMessage. Enabling comparison for SparseTensor roundtrip test. 247cdbd <Rok> Adding scipy.sparse integration. Lead-authored-by: Rok <rok@mihevc.org> Co-authored-by: Kenta Murata <mrkn@mrkn.jp> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 19e9451 commit cdaf988

22 files changed

+976
-167
lines changed

cpp/src/arrow/ipc/metadata_internal.cc

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,39 +1245,54 @@ Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>
12451245
}
12461246
int ndim = static_cast<int>(sparse_tensor->shape()->size());
12471247

1248-
for (int i = 0; i < ndim; ++i) {
1249-
auto dim = sparse_tensor->shape()->Get(i);
1248+
if (shape || dim_names) {
1249+
for (int i = 0; i < ndim; ++i) {
1250+
auto dim = sparse_tensor->shape()->Get(i);
12501251

1251-
shape->push_back(dim->size());
1252-
auto fb_name = dim->name();
1253-
if (fb_name == 0) {
1254-
dim_names->push_back("");
1255-
} else {
1256-
dim_names->push_back(fb_name->str());
1252+
if (shape) {
1253+
shape->push_back(dim->size());
1254+
}
1255+
1256+
if (dim_names) {
1257+
auto fb_name = dim->name();
1258+
if (fb_name == 0) {
1259+
dim_names->push_back("");
1260+
} else {
1261+
dim_names->push_back(fb_name->str());
1262+
}
1263+
}
12571264
}
12581265
}
12591266

1260-
*non_zero_length = sparse_tensor->non_zero_length();
1267+
if (non_zero_length) {
1268+
*non_zero_length = sparse_tensor->non_zero_length();
1269+
}
12611270

1262-
switch (sparse_tensor->sparseIndex_type()) {
1263-
case flatbuf::SparseTensorIndex_SparseTensorIndexCOO:
1264-
*sparse_tensor_format_id = SparseTensorFormat::COO;
1265-
break;
1271+
if (sparse_tensor_format_id) {
1272+
switch (sparse_tensor->sparseIndex_type()) {
1273+
case flatbuf::SparseTensorIndex_SparseTensorIndexCOO:
1274+
*sparse_tensor_format_id = SparseTensorFormat::COO;
1275+
break;
12661276

1267-
case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
1268-
*sparse_tensor_format_id = SparseTensorFormat::CSR;
1269-
break;
1277+
case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
1278+
*sparse_tensor_format_id = SparseTensorFormat::CSR;
1279+
break;
12701280

1271-
default:
1272-
return Status::Invalid("Unrecognized sparse index type");
1281+
default:
1282+
return Status::Invalid("Unrecognized sparse index type");
1283+
}
12731284
}
12741285

12751286
auto type_data = sparse_tensor->type();
12761287
if (type_data == nullptr) {
12771288
return Status::IOError(
12781289
"Type-pointer in custom metadata of flatbuffer-encoded SparseTensor is null.");
12791290
}
1280-
return ConcreteTypeFromFlatbuffer(sparse_tensor->type_type(), type_data, {}, type);
1291+
if (type) {
1292+
return ConcreteTypeFromFlatbuffer(sparse_tensor->type_type(), type_data, {}, type);
1293+
} else {
1294+
return Status::OK();
1295+
}
12811296
}
12821297

12831298
} // namespace internal

cpp/src/arrow/ipc/read_write_test.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,8 +1188,8 @@ void TestSparseTensorRoundTrip<IndexValueType>::CheckSparseTensorRoundTrip(
11881188

11891189
ASSERT_OK(mmap_->Seek(0));
11901190

1191-
ASSERT_OK(WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length,
1192-
default_memory_pool()));
1191+
ASSERT_OK(
1192+
WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));
11931193

11941194
const auto& sparse_index =
11951195
checked_cast<const SparseCOOIndex&>(*sparse_tensor.sparse_index());
@@ -1224,8 +1224,8 @@ void TestSparseTensorRoundTrip<IndexValueType>::CheckSparseTensorRoundTrip(
12241224

12251225
ASSERT_OK(mmap_->Seek(0));
12261226

1227-
ASSERT_OK(WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length,
1228-
default_memory_pool()));
1227+
ASSERT_OK(
1228+
WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));
12291229

12301230
const auto& sparse_index =
12311231
checked_cast<const SparseCSRIndex&>(*sparse_tensor.sparse_index());
@@ -1285,8 +1285,10 @@ TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexRowMajor) {
12851285
0, 2, 0, 0, 2, 2, 1, 0, 1, 1, 0, 3,
12861286
1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
12871287
const int sizeof_index_value = sizeof(c_index_value_type);
1288-
auto si = this->MakeSparseCOOIndex(
1289-
{12, 3}, {sizeof_index_value * 3, sizeof_index_value}, coords_values);
1288+
std::shared_ptr<SparseCOOIndex> si;
1289+
ASSERT_OK(SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
1290+
{sizeof_index_value * 3, sizeof_index_value},
1291+
Buffer::Wrap(coords_values), &si));
12901292

12911293
std::vector<int64_t> shape = {2, 3, 4};
12921294
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
@@ -1328,8 +1330,10 @@ TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexColumnMajor) {
13281330
0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2,
13291331
0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
13301332
const int sizeof_index_value = sizeof(c_index_value_type);
1331-
auto si = this->MakeSparseCOOIndex(
1332-
{12, 3}, {sizeof_index_value, sizeof_index_value * 12}, coords_values);
1333+
std::shared_ptr<SparseCOOIndex> si;
1334+
ASSERT_OK(SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
1335+
{sizeof_index_value, sizeof_index_value * 12},
1336+
Buffer::Wrap(coords_values), &si));
13331337

13341338
std::vector<int64_t> shape = {2, 3, 4};
13351339
std::vector<std::string> dim_names = {"foo", "bar", "baz"};

cpp/src/arrow/ipc/reader.cc

Lines changed: 130 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -933,32 +933,150 @@ Status MakeSparseTensorWithSparseCSRIndex(
933933
return Status::OK();
934934
}
935935

936-
} // namespace
937-
938-
Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
939-
std::shared_ptr<SparseTensor>* out) {
940-
std::shared_ptr<DataType> type;
941-
std::vector<int64_t> shape;
942-
std::vector<std::string> dim_names;
943-
int64_t non_zero_length;
944-
SparseTensorFormat::type sparse_tensor_format_id;
945-
936+
Status ReadSparseTensorMetadata(const Buffer& metadata,
937+
std::shared_ptr<DataType>* out_type,
938+
std::vector<int64_t>* out_shape,
939+
std::vector<std::string>* out_dim_names,
940+
int64_t* out_non_zero_length,
941+
SparseTensorFormat::type* out_format_id,
942+
const flatbuf::SparseTensor** out_fb_sparse_tensor,
943+
const flatbuf::Buffer** out_buffer) {
946944
RETURN_NOT_OK(internal::GetSparseTensorMetadata(
947-
metadata, &type, &shape, &dim_names, &non_zero_length, &sparse_tensor_format_id));
945+
metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));
948946

949947
const flatbuf::Message* message;
950948
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
949+
951950
auto sparse_tensor = message->header_as_SparseTensor();
952951
if (sparse_tensor == nullptr) {
953952
return Status::IOError(
954953
"Header-type of flatbuffer-encoded Message is not SparseTensor.");
955954
}
956-
const flatbuf::Buffer* buffer = sparse_tensor->data();
955+
*out_fb_sparse_tensor = sparse_tensor;
956+
957+
auto buffer = sparse_tensor->data();
957958
if (!BitUtil::IsMultipleOf8(buffer->offset())) {
958959
return Status::Invalid(
959960
"Buffer of sparse index data did not start on 8-byte aligned offset: ",
960961
buffer->offset());
961962
}
963+
*out_buffer = buffer;
964+
965+
return Status::OK();
966+
}
967+
968+
} // namespace
969+
970+
namespace internal {
971+
972+
namespace {
973+
974+
Status GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
975+
size_t* buffer_count) {
976+
switch (format_id) {
977+
case SparseTensorFormat::COO:
978+
*buffer_count = 2;
979+
break;
980+
981+
case SparseTensorFormat::CSR:
982+
*buffer_count = 3;
983+
break;
984+
985+
default:
986+
return Status::Invalid("Unrecognized sparse tensor format");
987+
}
988+
989+
return Status::OK();
990+
}
991+
992+
Status CheckSparseTensorBodyBufferCount(
993+
const IpcPayload& payload, SparseTensorFormat::type sparse_tensor_format_id) {
994+
size_t expected_body_buffer_count;
995+
996+
RETURN_NOT_OK(GetSparseTensorBodyBufferCount(sparse_tensor_format_id,
997+
&expected_body_buffer_count));
998+
if (payload.body_buffers.size() != expected_body_buffer_count) {
999+
return Status::Invalid("Invalid body buffer count for a sparse tensor");
1000+
}
1001+
1002+
return Status::OK();
1003+
}
1004+
1005+
} // namespace
1006+
1007+
Status ReadSparseTensorBodyBufferCount(const Buffer& metadata, size_t* buffer_count) {
1008+
SparseTensorFormat::type format_id;
1009+
1010+
RETURN_NOT_OK(internal::GetSparseTensorMetadata(metadata, nullptr, nullptr, nullptr,
1011+
nullptr, &format_id));
1012+
return GetSparseTensorBodyBufferCount(format_id, buffer_count);
1013+
}
1014+
1015+
Status ReadSparseTensorPayload(const IpcPayload& payload,
1016+
std::shared_ptr<SparseTensor>* out) {
1017+
std::shared_ptr<DataType> type;
1018+
std::vector<int64_t> shape;
1019+
std::vector<std::string> dim_names;
1020+
int64_t non_zero_length;
1021+
SparseTensorFormat::type sparse_tensor_format_id;
1022+
const flatbuf::SparseTensor* sparse_tensor;
1023+
const flatbuf::Buffer* buffer;
1024+
1025+
RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
1026+
&non_zero_length, &sparse_tensor_format_id,
1027+
&sparse_tensor, &buffer));
1028+
1029+
RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id));
1030+
1031+
switch (sparse_tensor_format_id) {
1032+
case SparseTensorFormat::COO: {
1033+
std::shared_ptr<SparseCOOIndex> sparse_index;
1034+
std::shared_ptr<DataType> indices_type;
1035+
RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
1036+
sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
1037+
RETURN_NOT_OK(SparseCOOIndex::Make(indices_type, shape, non_zero_length,
1038+
payload.body_buffers[0], &sparse_index));
1039+
return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
1040+
non_zero_length, payload.body_buffers[1],
1041+
out);
1042+
}
1043+
1044+
case SparseTensorFormat::CSR: {
1045+
std::shared_ptr<SparseCSRIndex> sparse_index;
1046+
std::shared_ptr<DataType> indptr_type;
1047+
std::shared_ptr<DataType> indices_type;
1048+
RETURN_NOT_OK(internal::GetSparseCSRIndexMetadata(
1049+
sparse_tensor->sparseIndex_as_SparseMatrixIndexCSR(), &indptr_type,
1050+
&indices_type));
1051+
ARROW_CHECK_EQ(indptr_type, indices_type);
1052+
RETURN_NOT_OK(SparseCSRIndex::Make(indices_type, shape, non_zero_length,
1053+
payload.body_buffers[0], payload.body_buffers[1],
1054+
&sparse_index));
1055+
return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
1056+
non_zero_length, payload.body_buffers[2],
1057+
out);
1058+
}
1059+
1060+
default:
1061+
return Status::Invalid("Unsupported sparse index format");
1062+
}
1063+
}
1064+
1065+
} // namespace internal
1066+
1067+
Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
1068+
std::shared_ptr<SparseTensor>* out) {
1069+
std::shared_ptr<DataType> type;
1070+
std::vector<int64_t> shape;
1071+
std::vector<std::string> dim_names;
1072+
int64_t non_zero_length;
1073+
SparseTensorFormat::type sparse_tensor_format_id;
1074+
const flatbuf::SparseTensor* sparse_tensor;
1075+
const flatbuf::Buffer* buffer;
1076+
1077+
RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
1078+
&non_zero_length, &sparse_tensor_format_id,
1079+
&sparse_tensor, &buffer));
9621080

9631081
std::shared_ptr<Buffer> data;
9641082
RETURN_NOT_OK(file->ReadAt(buffer->offset(), buffer->length(), &data));

cpp/src/arrow/ipc/reader.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
#include "arrow/ipc/dictionary.h"
2727
#include "arrow/ipc/message.h"
2828
#include "arrow/ipc/options.h"
29+
#include "arrow/ipc/writer.h"
2930
#include "arrow/record_batch.h"
31+
#include "arrow/sparse_tensor.h"
3032
#include "arrow/util/visibility.h"
3133

3234
namespace arrow {
@@ -286,6 +288,27 @@ Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* ou
286288
ARROW_EXPORT
287289
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out);
288290

291+
namespace internal {
292+
293+
// These internal APIs may change without warning or deprecation
294+
295+
/// \brief EXPERIMENTAL: Read the number of body buffers of a sparse tensor from its metadata
296+
/// \param[in] metadata a Buffer containing the sparse tensor metadata
297+
/// \param[out] buffer_count the returned count of the body buffers
298+
/// \return Status
299+
ARROW_EXPORT
300+
Status ReadSparseTensorBodyBufferCount(const Buffer& metadata, size_t* buffer_count);
301+
302+
/// \brief EXPERIMENTAL: Read arrow::SparseTensor from an IpcPayload
303+
/// \param[in] payload an IpcPayload containing a serialized SparseTensor
304+
/// \param[out] out the returned SparseTensor
305+
/// \return Status
306+
ARROW_EXPORT
307+
Status ReadSparseTensorPayload(const IpcPayload& payload,
308+
std::shared_ptr<SparseTensor>* out);
309+
310+
} // namespace internal
311+
289312
} // namespace ipc
290313
} // namespace arrow
291314

cpp/src/arrow/ipc/writer.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,7 @@ Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* poo
814814
} // namespace internal
815815

816816
Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
817-
int32_t* metadata_length, int64_t* body_length,
818-
MemoryPool* pool) {
817+
int32_t* metadata_length, int64_t* body_length) {
819818
internal::IpcPayload payload;
820819
internal::SparseTensorSerializer writer(0, &payload);
821820
RETURN_NOT_OK(writer.Assemble(sparse_tensor));
@@ -824,6 +823,18 @@ Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* ds
824823
return internal::WriteIpcPayload(payload, IpcOptions::Defaults(), dst, metadata_length);
825824
}
826825

826+
Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* pool,
827+
std::unique_ptr<Message>* out) {
828+
internal::IpcPayload payload;
829+
RETURN_NOT_OK(internal::GetSparseTensorPayload(sparse_tensor, pool, &payload));
830+
831+
const std::shared_ptr<Buffer> metadata = payload.metadata;
832+
const std::shared_ptr<Buffer> buffer = *payload.body_buffers.data();
833+
834+
out->reset(new Message(metadata, buffer));
835+
return Status::OK();
836+
}
837+
827838
Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
828839
// emulates the behavior of Write without actually writing
829840
auto options = IpcOptions::Defaults();
@@ -1029,7 +1040,7 @@ class StreamBookKeeper {
10291040
int64_t position_;
10301041
};
10311042

1032-
/// A IpcPayloadWriter implementation that writes to a IPC stream
1043+
/// An IpcPayloadWriter implementation that writes to an IPC stream
10331044
/// (with an end-of-stream marker)
10341045
class PayloadStreamWriter : public internal::IpcPayloadWriter,
10351046
protected StreamBookKeeper {

0 commit comments

Comments
 (0)