Skip to content

Commit c54f005

Browse files
committed
Changes to GetSparseTensorMessage. Enabling comparison for SparseTensor roundtrip test.
1 parent 247cdbd commit c54f005

File tree

3 files changed

+38
-18
lines changed

3 files changed

+38
-18
lines changed

cpp/src/arrow/ipc/writer.cc

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -725,23 +725,6 @@ Status GetTensorMessage(const Tensor& tensor, MemoryPool* pool,
725725
return Status::OK();
726726
}
727727

728-
Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* pool,
729-
std::unique_ptr<Message>* out) {
730-
const SparseTensor* sparse_tensor_to_write = &sparse_tensor;
731-
std::unique_ptr<SparseTensor> temp_sparse_tensor;
732-
733-
const auto& type = checked_cast<const FixedWidthType&>(*sparse_tensor.type());
734-
const int elem_size = type.bit_width() / 8;
735-
int64_t body_length = sparse_tensor.size() * elem_size;
736-
const std::vector<internal::BufferMetadata> buffers;
737-
738-
std::shared_ptr<Buffer> metadata;
739-
RETURN_NOT_OK(internal::WriteSparseTensorMessage(*sparse_tensor_to_write, body_length,
740-
buffers, &metadata));
741-
out->reset(new Message(metadata, sparse_tensor_to_write->data()));
742-
return Status::OK();
743-
}
744-
745728
namespace internal {
746729

747730
class SparseTensorSerializer {
@@ -840,6 +823,30 @@ Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* ds
840823
return internal::WriteIpcPayload(payload, IpcOptions::Defaults(), dst, metadata_length);
841824
}
842825

826+
/// \brief Build an IPC Message for a sparse tensor.
///
/// Assembles the IPC payload with internal::GetSparseTensorPayload and wraps
/// the payload's metadata together with its first body buffer in a Message.
///
/// \param[in] sparse_tensor the SparseTensor to convert
/// \param[in] pool the MemoryPool forwarded to the payload construction
/// \param[out] out receives the newly created Message
/// \return the error from payload construction if it fails, Status::Invalid
/// if the payload carries no body buffers, Status::OK otherwise
Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* pool,
                              std::unique_ptr<Message>* out) {
  internal::IpcPayload payload;
  RETURN_NOT_OK(internal::GetSparseTensorPayload(sparse_tensor, pool, &payload));

  // Reading the first element of an empty vector is undefined behavior, so
  // reject such a payload explicitly instead of dereferencing blindly.
  if (payload.body_buffers.empty()) {
    return Status::Invalid("Sparse tensor payload contains no body buffers");
  }

  // NOTE(review): only the first body buffer is attached to the Message, as
  // before this change; confirm that the remaining buffers are not needed here.
  out->reset(new Message(payload.metadata, payload.body_buffers.front()));
  return Status::OK();
}
837+
838+
/// \brief Write a dictionary as an IPC payload to an output stream.
///
/// Builds the dictionary payload with the default IpcOptions and writes it
/// to dst with those same options.
///
/// \param[in] dictionary_id the id passed on to GetDictionaryPayload
/// \param[in] dictionary the dictionary values to write
/// \param[in] buffer_start_offset not used by this implementation
/// \param[in] dst the OutputStream to write to
/// \param[out] metadata_length filled in by internal::WriteIpcPayload
/// \param[out] body_length set to the body length computed in the payload
/// \param[in] pool the MemoryPool forwarded to the payload construction
/// \return Status of payload construction or of the write
Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary,
                       int64_t buffer_start_offset, io::OutputStream* dst,
                       int32_t* metadata_length, int64_t* body_length, MemoryPool* pool) {
  auto options = IpcOptions::Defaults();
  internal::IpcPayload payload;
  RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, dictionary, options, pool, &payload));

  // The body size is computed in the payload
  *body_length = payload.body_length;
  // Forward the options used to build the payload, matching the
  // four-argument WriteIpcPayload call made by WriteSparseTensor.
  return internal::WriteIpcPayload(payload, options, dst, metadata_length);
}
849+
843850
Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
844851
// emulates the behavior of Write without actually writing
845852
auto options = IpcOptions::Defaults();
@@ -1045,7 +1052,14 @@ class StreamBookKeeper {
10451052
int64_t position_;
10461053
};
10471054

1055+
// End of stream marker
// NOTE(review): this constant comes from the incoming side of a merge conflict
// that was committed unresolved; the HEAD side did not define it. Confirm it
// is still referenced and not defined elsewhere before keeping it.
constexpr int32_t kEos = 0;

/// An IpcPayloadWriter implementation that writes to an IPC stream
10491063
/// (with an end-of-stream marker)
10501064
class PayloadStreamWriter : public internal::IpcPayloadWriter,
10511065
protected StreamBookKeeper {

cpp/src/arrow/ipc/writer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadat
292292
///
293293
/// The message is written out as follows:
294294
/// \code
295-
/// <metadata size> <metadata> <sparse tensor data> <sparse tensor index>
295+
/// <metadata size> <metadata> <sparse index> <sparse tensor body>
296296
/// \endcode
297297
///
298298
/// \param[in] sparse_tensor the SparseTensor to write

python/pyarrow/tests/test_serialization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def assert_equal(obj1, obj2):
117117
assert obj1.equals(obj2)
118118
elif isinstance(obj1, pa.Tensor) and isinstance(obj2, pa.Tensor):
119119
assert obj1.equals(obj2)
120+
elif isinstance(obj1, pa.SparseTensorCOO) and \
121+
isinstance(obj2, pa.SparseTensorCOO):
122+
assert obj1.equals(obj2)
123+
elif isinstance(obj1, pa.SparseTensorCSR) and \
124+
isinstance(obj2, pa.SparseTensorCSR):
125+
assert obj1.equals(obj2)
120126
elif isinstance(obj1, pa.RecordBatch) and isinstance(obj2, pa.RecordBatch):
121127
assert obj1.equals(obj2)
122128
elif isinstance(obj1, pa.Table) and isinstance(obj2, pa.Table):

0 commit comments

Comments
 (0)