Skip to content

Commit 90e8b31

Browse files
committed
Rename sparse tensor classes
- SparseTensorBase to SparseTensor - SparseTensor<...> to SparseTensorImpl<...>
1 parent 07a6518 commit 90e8b31

File tree

13 files changed

+103
-103
lines changed

13 files changed

+103
-103
lines changed

cpp/src/arrow/compare.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -787,17 +787,17 @@ namespace {
787787

788788
template <typename LeftSparseIndexType, typename RightSparseIndexType>
789789
struct SparseTensorEqualsImpl {
790-
static bool Compare(const SparseTensor<LeftSparseIndexType>& left,
791-
const SparseTensor<RightSparseIndexType>& right) {
790+
static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
791+
const SparseTensorImpl<RightSparseIndexType>& right) {
792792
// TODO(mrkn): should we support the equality among different formats?
793793
return false;
794794
}
795795
};
796796

797797
template <typename SparseIndexType>
798798
struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
799-
static bool Compare(const SparseTensor<SparseIndexType>& left,
800-
const SparseTensor<SparseIndexType>& right) {
799+
static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
800+
const SparseTensorImpl<SparseIndexType>& right) {
801801
DCHECK(left.type()->id() == right.type()->id());
802802
DCHECK(left.shape() == right.shape());
803803
DCHECK(left.non_zero_length() == right.non_zero_length());
@@ -821,19 +821,19 @@ struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
821821
}
822822
};
823823

824-
template <typename SparseTensorType>
825-
inline bool SparseTensorEqualsImplDispatch(const SparseTensor<SparseTensorType>& left,
826-
const SparseTensorBase& right) {
824+
template <typename SparseIndexType>
825+
inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
826+
const SparseTensor& right) {
827827
switch (right.sparse_tensor_format_id()) {
828828
case SparseTensorFormat::COO: {
829-
const auto& right_coo = checked_cast<const SparseTensor<SparseCOOIndex>&>(right);
830-
return SparseTensorEqualsImpl<SparseTensorType, SparseCOOIndex>::Compare(left,
829+
const auto& right_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
830+
return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(left,
831831
right_coo);
832832
}
833833

834834
case SparseTensorFormat::CSR: {
835-
const auto& right_csr = checked_cast<const SparseTensor<SparseCSRIndex>&>(right);
836-
return SparseTensorEqualsImpl<SparseTensorType, SparseCSRIndex>::Compare(left,
835+
const auto& right_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
836+
return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(left,
837837
right_csr);
838838
}
839839

@@ -844,7 +844,7 @@ inline bool SparseTensorEqualsImplDispatch(const SparseTensor<SparseTensorType>&
844844

845845
} // namespace
846846

847-
bool SparseTensorEquals(const SparseTensorBase& left, const SparseTensorBase& right) {
847+
bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
848848
if (&left == &right) {
849849
return true;
850850
} else if (left.type()->id() != right.type()->id()) {
@@ -859,12 +859,12 @@ bool SparseTensorEquals(const SparseTensorBase& left, const SparseTensorBase& ri
859859

860860
switch (left.sparse_tensor_format_id()) {
861861
case SparseTensorFormat::COO: {
862-
const auto& left_coo = checked_cast<const SparseTensor<SparseCOOIndex>&>(left);
862+
const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
863863
return SparseTensorEqualsImplDispatch(left_coo, right);
864864
}
865865

866866
case SparseTensorFormat::CSR: {
867-
const auto& left_csr = checked_cast<const SparseTensor<SparseCSRIndex>&>(left);
867+
const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
868868
return SparseTensorEqualsImplDispatch(left_csr, right);
869869
}
870870

cpp/src/arrow/compare.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ namespace arrow {
2929
class Array;
3030
class DataType;
3131
class Tensor;
32-
class SparseTensorBase;
32+
class SparseTensor;
3333

3434
/// Returns true if the arrays are exactly equal
3535
bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right);
3636

3737
bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right);
3838

3939
/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal
40-
bool ARROW_EXPORT SparseTensorEquals(const SparseTensorBase& left,
41-
const SparseTensorBase& right);
40+
bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left,
41+
const SparseTensor& right);
4242

4343
/// Returns true if the arrays are approximately equal. For non-floating point
4444
/// types, this is equivalent to ArrayEquals(left, right)

cpp/src/arrow/ipc/metadata-internal.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ Status MakeSparseTensorIndex(FBB& fbb, const SparseIndex& sparse_index,
836836
return Status::OK();
837837
}
838838

839-
Status MakeSparseTensor(FBB& fbb, const SparseTensorBase& sparse_tensor,
839+
Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor,
840840
int64_t body_length, const std::vector<BufferMetadata>& buffers,
841841
SparseTensorOffset* offset) {
842842
flatbuf::Type fb_type_type;
@@ -872,7 +872,7 @@ Status MakeSparseTensor(FBB& fbb, const SparseTensorBase& sparse_tensor,
872872
return Status::OK();
873873
}
874874

875-
Status WriteSparseTensorMessage(const SparseTensorBase& sparse_tensor,
875+
Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor,
876876
int64_t body_length,
877877
const std::vector<BufferMetadata>& buffers,
878878
std::shared_ptr<Buffer>* out) {

cpp/src/arrow/ipc/metadata-internal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace arrow {
4141
class DataType;
4242
class Schema;
4343
class Tensor;
44-
class SparseTensorBase;
44+
class SparseTensor;
4545

4646
namespace flatbuf = org::apache::arrow::flatbuf;
4747

@@ -145,7 +145,7 @@ Status WriteRecordBatchMessage(const int64_t length, const int64_t body_length,
145145
Status WriteTensorMessage(const Tensor& tensor, const int64_t buffer_start_offset,
146146
std::shared_ptr<Buffer>* out);
147147

148-
Status WriteSparseTensorMessage(const SparseTensorBase& sparse_tensor,
148+
Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor,
149149
int64_t body_length,
150150
const std::vector<BufferMetadata>& buffers,
151151
std::shared_ptr<Buffer>* out);

cpp/src/arrow/ipc/read-write-test.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,14 +851,14 @@ class TestSparseTensorRoundTrip : public ::testing::Test, public IpcTestFixture
851851
void TearDown() { io::MemoryMapFixture::TearDown(); }
852852

853853
template <typename SparseIndexType>
854-
void CheckSparseTensorRoundTrip(const SparseTensor<SparseIndexType>& tensor) {
854+
void CheckSparseTensorRoundTrip(const SparseTensorImpl<SparseIndexType>& tensor) {
855855
GTEST_FAIL();
856856
}
857857
};
858858

859859
template <>
860860
void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCOOIndex>(
861-
const SparseTensor<SparseCOOIndex>& tensor) {
861+
const SparseTensorImpl<SparseCOOIndex>& tensor) {
862862
const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
863863
const int elem_size = type.bit_width() / 8;
864864

@@ -878,7 +878,7 @@ void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCOOIndex>(
878878

879879
ASSERT_OK(mmap_->Seek(0));
880880

881-
std::shared_ptr<SparseTensorBase> result;
881+
std::shared_ptr<SparseTensor> result;
882882
ASSERT_OK(ReadSparseTensor(mmap_.get(), &result));
883883

884884
const auto& resulted_sparse_index =
@@ -890,7 +890,7 @@ void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCOOIndex>(
890890

891891
template <>
892892
void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCSRIndex>(
893-
const SparseTensor<SparseCSRIndex>& tensor) {
893+
const SparseTensorImpl<SparseCSRIndex>& tensor) {
894894
const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
895895
const int elem_size = type.bit_width() / 8;
896896

@@ -911,7 +911,7 @@ void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCSRIndex>(
911911

912912
ASSERT_OK(mmap_->Seek(0));
913913

914-
std::shared_ptr<SparseTensorBase> result;
914+
std::shared_ptr<SparseTensor> result;
915915
ASSERT_OK(ReadSparseTensor(mmap_.get(), &result));
916916

917917
const auto& resulted_sparse_index =
@@ -934,7 +934,7 @@ TEST_F(TestSparseTensorRoundTrip, WithSparseCOOIndex) {
934934

935935
auto data = Buffer::Wrap(values);
936936
NumericTensor<Int64Type> t(data, shape, {}, dim_names);
937-
SparseTensor<SparseCOOIndex> st(t);
937+
SparseTensorImpl<SparseCOOIndex> st(t);
938938

939939
CheckSparseTensorRoundTrip(st);
940940
}
@@ -951,7 +951,7 @@ TEST_F(TestSparseTensorRoundTrip, WithSparseCSRIndex) {
951951

952952
auto data = Buffer::Wrap(values);
953953
NumericTensor<Int64Type> t(data, shape, {}, dim_names);
954-
SparseTensor<SparseCSRIndex> st(t);
954+
SparseTensorImpl<SparseCSRIndex> st(t);
955955

956956
CheckSparseTensorRoundTrip(st);
957957
}

cpp/src/arrow/ipc/reader.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -772,24 +772,24 @@ Status MakeSparseTensorWithSparseCOOIndex(
772772
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
773773
const std::vector<std::string>& dim_names,
774774
const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
775-
const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensorBase>* out) {
776-
*out = std::make_shared<SparseTensor<SparseCOOIndex>>(sparse_index, type, data, shape, dim_names);
775+
const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
776+
*out = std::make_shared<SparseTensorImpl<SparseCOOIndex>>(sparse_index, type, data, shape, dim_names);
777777
return Status::OK();
778778
}
779779

780780
Status MakeSparseTensorWithSparseCSRIndex(
781781
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
782782
const std::vector<std::string>& dim_names,
783783
const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
784-
const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensorBase>* out) {
785-
*out = std::make_shared<SparseTensor<SparseCSRIndex>>(sparse_index, type, data, shape, dim_names);
784+
const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
785+
*out = std::make_shared<SparseTensorImpl<SparseCSRIndex>>(sparse_index, type, data, shape, dim_names);
786786
return Status::OK();
787787
}
788788

789789
} // namespace
790790

791791
Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
792-
std::shared_ptr<SparseTensorBase>* out) {
792+
std::shared_ptr<SparseTensor>* out) {
793793
std::shared_ptr<DataType> type;
794794
std::vector<int64_t> shape;
795795
std::vector<std::string> dim_names;
@@ -830,12 +830,12 @@ Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
830830
}
831831
}
832832

833-
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensorBase>* out) {
833+
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out) {
834834
io::BufferReader buffer_reader(message.body());
835835
return ReadSparseTensor(*message.metadata(), &buffer_reader, out);
836836
}
837837

838-
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensorBase>* out) {
838+
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* out) {
839839
std::unique_ptr<Message> message;
840840
RETURN_NOT_OK(ReadContiguousPayload(file, &message));
841841
DCHECK_EQ(message->type(), Message::SPARSE_TENSOR);

cpp/src/arrow/ipc/reader.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Buffer;
3333
class Schema;
3434
class Status;
3535
class Tensor;
36-
class SparseTensorBase;
36+
class SparseTensor;
3737

3838
namespace io {
3939

@@ -242,15 +242,15 @@ Status ReadTensor(const Message& message, std::shared_ptr<Tensor>* out);
242242
/// \param[out] out the read sparse tensor
243243
/// \return Status
244244
ARROW_EXPORT
245-
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensorBase>* out);
245+
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* out);
246246

247247
/// \brief EXPERIMENTAL: Read arrow::SparseTensor from IPC message
248248
///
249249
/// \param[in] message a Message containing the tensor metadata and body
250250
/// \param[out] out the read sparse tensor
251251
/// \return Status
252252
ARROW_EXPORT
253-
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensorBase>* out);
253+
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out);
254254

255255
} // namespace ipc
256256
} // namespace arrow

cpp/src/arrow/ipc/writer.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,12 +703,12 @@ class SparseTensorSerializer {
703703
return Status::OK();
704704
}
705705

706-
Status SerializeMetadata(const SparseTensorBase& sparse_tensor) {
706+
Status SerializeMetadata(const SparseTensor& sparse_tensor) {
707707
return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_,
708708
&out_->metadata);
709709
}
710710

711-
Status Assemble(const SparseTensorBase& sparse_tensor) {
711+
Status Assemble(const SparseTensor& sparse_tensor) {
712712
if (buffer_meta_.size() > 0) {
713713
buffer_meta_.clear();
714714
out_->body_buffers.clear();
@@ -753,15 +753,15 @@ class SparseTensorSerializer {
753753
int64_t buffer_start_offset_;
754754
};
755755

756-
Status GetSparseTensorPayload(const SparseTensorBase& sparse_tensor, MemoryPool* pool,
756+
Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
757757
IpcPayload* out) {
758758
SparseTensorSerializer writer(0, out);
759759
return writer.Assemble(sparse_tensor);
760760
}
761761

762762
} // namespace internal
763763

764-
Status WriteSparseTensor(const SparseTensorBase& sparse_tensor, io::OutputStream* dst,
764+
Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
765765
int32_t* metadata_length, int64_t* body_length,
766766
MemoryPool* pool) {
767767
internal::IpcPayload payload;

cpp/src/arrow/ipc/writer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Schema;
3636
class Status;
3737
class Table;
3838
class Tensor;
39-
class SparseTensorBase;
39+
class SparseTensor;
4040

4141
namespace io {
4242

@@ -280,7 +280,7 @@ Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadat
280280
// \param[out] metadata_length the actual metadata length, including padding
281281
// \param[out] body_length the actual message body length
282282
ARROW_EXPORT
283-
Status WriteSparseTensor(const SparseTensorBase& sparse_tensor, io::OutputStream* dst,
283+
Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
284284
int32_t* metadata_length, int64_t* body_length,
285285
MemoryPool* pool);
286286

cpp/src/arrow/sparse_tensor-test.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@
3333
namespace arrow {
3434

3535
static inline void CheckSparseIndexFormatType(SparseTensorFormat::type expected,
36-
const SparseTensorBase& sparse_tensor) {
36+
const SparseTensor& sparse_tensor) {
3737
ASSERT_EQ(expected, sparse_tensor.sparse_tensor_format_id());
3838
ASSERT_EQ(expected, sparse_tensor.sparse_index()->format_id());
3939
}
4040

4141
TEST(TestSparseCOOTensor, CreationEmptyTensor) {
4242
std::vector<int64_t> shape = {2, 3, 4};
43-
SparseTensor<SparseCOOIndex> st1(int64(), shape);
43+
SparseTensorImpl<SparseCOOIndex> st1(int64(), shape);
4444

4545
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
46-
SparseTensor<SparseCOOIndex> st2(int64(), shape, dim_names);
46+
SparseTensorImpl<SparseCOOIndex> st2(int64(), shape, dim_names);
4747

4848
ASSERT_EQ(0, st1.non_zero_length());
4949
ASSERT_EQ(0, st2.non_zero_length());
@@ -68,8 +68,8 @@ TEST(TestSparseCOOTensor, CreationFromNumericTensor) {
6868
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
6969
NumericTensor<Int64Type> tensor1(buffer, shape);
7070
NumericTensor<Int64Type> tensor2(buffer, shape, {}, dim_names);
71-
SparseTensor<SparseCOOIndex> st1(tensor1);
72-
SparseTensor<SparseCOOIndex> st2(tensor2);
71+
SparseTensorImpl<SparseCOOIndex> st1(tensor1);
72+
SparseTensorImpl<SparseCOOIndex> st2(tensor2);
7373

7474
CheckSparseIndexFormatType(SparseTensorFormat::COO, st1);
7575

@@ -133,8 +133,8 @@ TEST(TestSparseCOOTensor, CreationFromTensor) {
133133
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
134134
Tensor tensor1(int64(), buffer, shape);
135135
Tensor tensor2(int64(), buffer, shape, {}, dim_names);
136-
SparseTensor<SparseCOOIndex> st1(tensor1);
137-
SparseTensor<SparseCOOIndex> st2(tensor2);
136+
SparseTensorImpl<SparseCOOIndex> st1(tensor1);
137+
SparseTensorImpl<SparseCOOIndex> st2(tensor2);
138138

139139
ASSERT_EQ(12, st1.non_zero_length());
140140
ASSERT_TRUE(st1.is_mutable());
@@ -195,8 +195,8 @@ TEST(TestSparseCSRMatrix, CreationFromNumericTensor2D) {
195195
NumericTensor<Int64Type> tensor1(buffer, shape);
196196
NumericTensor<Int64Type> tensor2(buffer, shape, {}, dim_names);
197197

198-
SparseTensor<SparseCSRIndex> st1(tensor1);
199-
SparseTensor<SparseCSRIndex> st2(tensor2);
198+
SparseTensorImpl<SparseCSRIndex> st1(tensor1);
199+
SparseTensorImpl<SparseCSRIndex> st2(tensor2);
200200

201201
CheckSparseIndexFormatType(SparseTensorFormat::CSR, st1);
202202

0 commit comments

Comments
 (0)