Skip to content

Commit ed3984d

Browse files
committed
Add SparseIndex::format_type
1 parent 4251b4d commit ed3984d

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

cpp/src/arrow/sparse_tensor.cc

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,6 @@ namespace arrow {
2727

2828
namespace {
2929

30-
template <typename T>
31-
struct SparseIndexTraits {};
32-
33-
template <>
34-
struct SparseIndexTraits<SparseCOOIndex> {
35-
static inline const char* name() { return "SparseCOOIndex"; }
36-
};
37-
38-
template <>
39-
struct SparseIndexTraits<SparseCSRIndex> {
40-
static inline const char* name() { return "SparseCSRIndex"; }
41-
};
42-
4330
// ----------------------------------------------------------------------
4431
// SparseTensorConverter
4532

@@ -49,9 +36,7 @@ class SparseTensorConverter {
4936
explicit SparseTensorConverter(const NumericTensor<TYPE>&) {}
5037

5138
Status Convert() {
52-
std::string sparse_index_name(SparseIndexTraits<SparseIndexType>::name());
53-
return Status::NotImplemented(sparse_index_name +
54-
std::string(" is not supported yet."));
39+
return Status::Invalid("Unsupported sparse index");
5540
}
5641
};
5742

@@ -303,7 +288,7 @@ INSTANTIATE_SPARSE_TENSOR_CONVERTER(SparseCSRIndex);
303288

304289
// Constructor with a column-major NumericTensor
305290
SparseCOOIndex::SparseCOOIndex(const std::shared_ptr<CoordsTensor>& coords)
306-
: SparseIndex(coords->shape()[0]), coords_(coords) {
291+
: SparseIndex(SparseIndex::COO, coords->shape()[0]), coords_(coords) {
307292
DCHECK(coords_->is_column_major());
308293
}
309294

@@ -313,7 +298,7 @@ SparseCOOIndex::SparseCOOIndex(const std::shared_ptr<CoordsTensor>& coords)
313298
// Constructor with two index vectors
314299
SparseCSRIndex::SparseCSRIndex(const std::shared_ptr<IndexTensor>& indptr,
315300
const std::shared_ptr<IndexTensor>& indices)
316-
: SparseIndex(indices->shape()[0]), indptr_(indptr), indices_(indices) {
301+
: SparseIndex(SparseIndex::CSR, indices->shape()[0]), indptr_(indptr), indices_(indices) {
317302
DCHECK_EQ(1, indptr_->ndim());
318303
DCHECK_EQ(1, indices_->ndim());
319304
}

cpp/src/arrow/sparse_tensor.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,19 @@ namespace arrow {
3131

3232
class ARROW_EXPORT SparseIndex {
3333
public:
34-
explicit SparseIndex(int64_t length) : length_(length) {}
34+
enum format_type {
35+
COO,
36+
CSR
37+
};
38+
39+
explicit SparseIndex(format_type format_type_id, int64_t length)
40+
: format_type_id_(format_type_id), length_(length) {}
41+
42+
format_type format_type_id() const { return format_type_id_; }
3543
int64_t length() const { return length_; }
3644

3745
protected:
46+
format_type format_type_id_;
3847
int64_t length_;
3948
};
4049

@@ -45,6 +54,8 @@ class ARROW_EXPORT SparseCOOIndex : public SparseIndex {
4554
public:
4655
using CoordsTensor = NumericTensor<Int64Type>;
4756

57+
static constexpr SparseIndex::format_type format_type_id = SparseIndex::COO;
58+
4859
virtual ~SparseCOOIndex() = default;
4960

5061
// Constructor with a column-major NumericTensor
@@ -63,6 +74,8 @@ class ARROW_EXPORT SparseCSRIndex : public SparseIndex {
6374
public:
6475
using IndexTensor = NumericTensor<Int64Type>;
6576

77+
static constexpr SparseIndex::format_type format_type_id = SparseIndex::COO;
78+
6679
virtual ~SparseCSRIndex() = default;
6780

6881
// Constructor with two index vectors

0 commit comments

Comments
 (0)