Skip to content
This repository was archived by the owner on Oct 21, 2024. It is now read-only.

Commit a322ff5

Browse files
committed
Adding tests for multiple index value types for SparseCSFIndex.
1 parent f44d92c commit a322ff5

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

cpp/src/arrow/sparse_tensor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ class SparseTensorConverter<TYPE, SparseCSFIndex>
477477
if (row > 1) tree_split = true;
478478

479479
indices[column * nonzero_count + counts[column]] =
480-
coords->Value<IndexValueType>({row, column});
480+
static_cast<c_index_value_type>(
481+
coords->Value<IndexValueType>({row, column}));
481482
indptr[column * (nonzero_count + 1) + counts[column]] =
482483
static_cast<c_index_value_type>(counts[column + 1]);
483484
++counts[column];

cpp/src/arrow/sparse_tensor_test.cc

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,4 +1014,110 @@ TEST_F(TestSparseCSFTensor, CreationFromTensor) {
10141014
ASSERT_OK(st->ToTensor(&dt));
10151015
ASSERT_TRUE(tensor.Equals(*dt));
10161016
}
1017+
1018+
template <typename IndexValueType>
1019+
class TestSparseCSFTensorForIndexValueType
1020+
: public TestSparseCSFTensorBase<IndexValueType> {
1021+
protected:
1022+
std::shared_ptr<SparseCSFIndex> MakeSparseCSFIndex(
1023+
std::vector<typename IndexValueType::c_type>& indptr_values,
1024+
std::vector<typename IndexValueType::c_type>& indices_values,
1025+
const std::vector<int64_t>& indptr_offsets,
1026+
const std::vector<int64_t>& indices_offsets,
1027+
const std::vector<int64_t>& indptr_shape, const std::vector<int64_t>& indices_shape,
1028+
const std::vector<int64_t>& axis_order) const {
1029+
auto indptr_data = Buffer::Wrap(indptr_values);
1030+
auto indices_data = Buffer::Wrap(indices_values);
1031+
auto indptr =
1032+
std::make_shared<NumericTensor<IndexValueType>>(indptr_data, indptr_shape);
1033+
auto indices =
1034+
std::make_shared<NumericTensor<IndexValueType>>(indices_data, indices_shape);
1035+
return std::make_shared<SparseCSFIndex>(indptr, indices, indptr_offsets,
1036+
indices_offsets, axis_order);
1037+
}
1038+
1039+
template <typename CValueType>
1040+
std::shared_ptr<SparseCSFTensor> MakeSparseTensor(
1041+
const std::shared_ptr<SparseCSFIndex>& si,
1042+
std::vector<CValueType>& sparse_values) const {
1043+
auto data = Buffer::Wrap(sparse_values);
1044+
return std::make_shared<SparseCSFTensor>(si,
1045+
CTypeTraits<CValueType>::type_singleton(),
1046+
data, this->shape_, this->dim_names_);
1047+
}
1048+
};
1049+
1050+
TYPED_TEST_CASE_P(TestSparseCSFTensorForIndexValueType);
1051+
1052+
TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, ToTensor) {
1053+
using IndexValueType = TypeParam;
1054+
using c_index_value_type = typename IndexValueType::c_type;
1055+
1056+
std::vector<int64_t> data_values = {1, 2, 3, 4, 5, 6, 7, 8};
1057+
std::vector<c_index_value_type> indptr_values = {0, 2, 3, 0, 1, 3, 4, 0, 2, 4, 5, 8};
1058+
std::vector<c_index_value_type> indices_values = {1, 2, 1, 2, 2, 1, 1, 2, 2,
1059+
2, 3, 1, 3, 1, 1, 2, 3};
1060+
std::vector<int64_t> indices_offsets = {0, 2, 5, 9};
1061+
std::vector<int64_t> indptr_offsets = {0, 3, 7};
1062+
std::vector<int64_t> axis_order = {0, 1, 2, 3};
1063+
std::vector<int64_t> sparse_tensor_shape({3, 3, 3, 4});
1064+
std::vector<int64_t> indptr_shape({12});
1065+
std::vector<int64_t> indices_shape({17});
1066+
std::vector<std::string> dim_names({"a", "b", "c", "d"});
1067+
1068+
std::vector<int64_t> dense_values = {
1069+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1070+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1071+
1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1072+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 8};
1073+
1074+
std::shared_ptr<Buffer> data_buffer = Buffer::Wrap(data_values);
1075+
std::shared_ptr<Buffer> indptr_buffer = Buffer::Wrap(indptr_values);
1076+
std::shared_ptr<Buffer> indices_buffer = Buffer::Wrap(indices_values);
1077+
std::shared_ptr<Buffer> dense_data = Buffer::Wrap(dense_values);
1078+
1079+
std::shared_ptr<SparseCSFIndex> si =
1080+
this->MakeSparseCSFIndex(indptr_values, indices_values, indptr_offsets,
1081+
indices_offsets, indptr_shape, indices_shape, axis_order);
1082+
std::shared_ptr<SparseCSFTensor> st = this->MakeSparseTensor(si, data_values);
1083+
1084+
ASSERT_EQ(8, st->non_zero_length());
1085+
1086+
std::shared_ptr<Tensor> dt;
1087+
ASSERT_OK(st->ToTensor(&dt));
1088+
Tensor tensor(int64(), dense_data, sparse_tensor_shape, {});
1089+
ASSERT_TRUE(tensor.Equals(*dt));
1090+
1091+
std::shared_ptr<SparseCSFIndex> si2 =
1092+
arrow::internal::checked_pointer_cast<SparseCSFIndex>(
1093+
this->sparse_tensor_from_dense_->sparse_index());
1094+
1095+
ASSERT_EQ(si->indices()->type(), si2->indices()->type());
1096+
ASSERT_TRUE(si->indptr()->Equals(*si2->indptr()));
1097+
ASSERT_TRUE(si->indices()->Equals(*si2->indices()));
1098+
ASSERT_TRUE(si->indptr_offsets() == si2->indptr_offsets());
1099+
ASSERT_TRUE(si->indices_offsets() == si2->indices_offsets());
1100+
ASSERT_TRUE(si->indices_offsets() == si2->indices_offsets());
1101+
ASSERT_TRUE(si->axis_order() == si2->axis_order());
1102+
1103+
ASSERT_TRUE(si->Equals(*si2));
1104+
ASSERT_TRUE(st->data()->Equals(*this->sparse_tensor_from_dense_->data()));
1105+
// ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
1106+
}
1107+
1108+
REGISTER_TYPED_TEST_CASE_P(TestSparseCSFTensorForIndexValueType, ToTensor);
1109+
1110+
INSTANTIATE_TYPED_TEST_CASE_P(TestInt8, TestSparseCSFTensorForIndexValueType, Int8Type);
1111+
INSTANTIATE_TYPED_TEST_CASE_P(TestUInt8, TestSparseCSFTensorForIndexValueType, UInt8Type);
1112+
// INSTANTIATE_TYPED_TEST_CASE_P(TestInt16, TestSparseCSFTensorForIndexValueType,
1113+
// Int16Type); INSTANTIATE_TYPED_TEST_CASE_P(TestUInt16,
1114+
// TestSparseCSFTensorForIndexValueType,UInt16Type);
1115+
// INSTANTIATE_TYPED_TEST_CASE_P(TestInt32, TestSparseCSFTensorForIndexValueType,
1116+
// Int32Type);
1117+
INSTANTIATE_TYPED_TEST_CASE_P(TestUInt32, TestSparseCSFTensorForIndexValueType,
1118+
UInt32Type);
1119+
INSTANTIATE_TYPED_TEST_CASE_P(TestInt64, TestSparseCSFTensorForIndexValueType, Int64Type);
1120+
INSTANTIATE_TYPED_TEST_CASE_P(TestUInt64, TestSparseCSFTensorForIndexValueType,
1121+
UInt64Type);
1122+
10171123
} // namespace arrow

0 commit comments

Comments
 (0)