Skip to content

Commit 4cfef06

Browse files
mrknpitrou
authored andcommitted
Add a new test of csr sparse matrix creation from non-contiguous tensor
1 parent d9f32f1 commit 4cfef06

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

cpp/src/arrow/sparse_tensor-test.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,45 @@ TEST(TestSparseCSRMatrix, CreationFromNumericTensor2D) {
293293
ASSERT_EQ(std::vector<int64_t>({0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}), indices_values);
294294
}
295295

296+
TEST(TestSparseCSRMatrix, CreationFromNonContiguousTensor) {
297+
std::vector<int64_t> shape = {6, 4};
298+
std::vector<int64_t> values = {1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0,
299+
5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 11, 0, 0, 0, 12, 0,
300+
13, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 0, 0, 16, 0};
301+
std::vector<int64_t> strides = {64, 16};
302+
std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
303+
Tensor tensor(int64(), buffer, shape, strides);
304+
SparseTensorImpl<SparseCSRIndex> st(tensor);
305+
306+
ASSERT_EQ(12, st.non_zero_length());
307+
ASSERT_TRUE(st.is_mutable());
308+
309+
const int64_t* ptr = reinterpret_cast<const int64_t*>(st.raw_data());
310+
for (int i = 0; i < 6; ++i) {
311+
ASSERT_EQ(i + 1, ptr[i]);
312+
}
313+
for (int i = 0; i < 6; ++i) {
314+
ASSERT_EQ(i + 11, ptr[i + 6]);
315+
}
316+
317+
const auto& si = internal::checked_cast<const SparseCSRIndex&>(*st.sparse_index());
318+
ASSERT_EQ(1, si.indptr()->ndim());
319+
ASSERT_EQ(1, si.indices()->ndim());
320+
321+
const int64_t* indptr_begin = reinterpret_cast<const int64_t*>(si.indptr()->raw_data());
322+
std::vector<int64_t> indptr_values(indptr_begin,
323+
indptr_begin + si.indptr()->shape()[0]);
324+
325+
ASSERT_EQ(7, indptr_values.size());
326+
ASSERT_EQ(std::vector<int64_t>({0, 2, 4, 6, 8, 10, 12}), indptr_values);
327+
328+
const int64_t* indices_begin =
329+
reinterpret_cast<const int64_t*>(si.indices()->raw_data());
330+
std::vector<int64_t> indices_values(indices_begin,
331+
indices_begin + si.indices()->shape()[0]);
332+
333+
ASSERT_EQ(12, indices_values.size());
334+
ASSERT_EQ(std::vector<int64_t>({0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}), indices_values);
335+
}
336+
296337
} // namespace arrow

0 commit comments

Comments
 (0)