Skip to content

Commit e89edc6

Browse files
rokpitrou
authored andcommitted
Refactoring to_numpy methods.
1 parent 3fcc192 commit e89edc6

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

cpp/src/arrow/python/numpy_convert.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,14 @@ Status SparseTensorCOOToNdarray(const std::shared_ptr<SparseTensorCOO>& sparse_t
293293
PyArray_Descr* dtype_coords = PyArray_DescrNewFromType(type_num_coords);
294294
RETURN_IF_PYERROR();
295295

296+
const int ndim_coords = sparse_tensor->ndim();
297+
std::vector<npy_intp> npy_shape_coords(ndim_coords);
298+
299+
for (int i = 0; i < ndim_coords; ++i) {
300+
npy_shape_coords[i] = sparse_index_coords->shape()[i];
301+
}
302+
296303
std::vector<npy_intp> npy_shape_data({sparse_index.non_zero_length(), 1});
297-
std::vector<npy_intp> npy_shape_coords({sparse_index_coords->shape()[0], 2});
298304

299305
const void* immutable_data = nullptr;
300306
if (sparse_tensor->data()) {
@@ -355,8 +361,20 @@ Status SparseTensorCSRToNdarray(const std::shared_ptr<SparseTensorCSR>& sparse_t
355361
sparse_index.indices();
356362

357363
std::vector<npy_intp> npy_shape_data({sparse_index.non_zero_length(), 1});
358-
std::vector<npy_intp> npy_shape_indptr({sparse_index_indptr->shape()[0], 1});
359-
std::vector<npy_intp> npy_shape_indices({sparse_index_indices->shape()[0], 1});
364+
365+
const int ndim_indptr = sparse_index_indptr->ndim();
366+
std::vector<npy_intp> npy_shape_indptr(ndim_indptr);
367+
368+
for (int i = 0; i < ndim_indptr; ++i) {
369+
npy_shape_indptr[i] = sparse_index_indptr->shape()[i];
370+
}
371+
372+
const int ndim_indices = sparse_index_indices->ndim();
373+
std::vector<npy_intp> npy_shape_indices(ndim_indices);
374+
375+
for (int i = 0; i < ndim_indices; ++i) {
376+
npy_shape_indices[i] = sparse_index_indices->shape()[i];
377+
}
360378

361379
const void* immutable_data = nullptr;
362380
if (sparse_tensor->data()) {

python/pyarrow/tests/test_sparse_tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,13 @@ def ne(a, b):
7777

7878
data = np.random.randn(10, 6)[::, ::2]
7979
sparse_tensor1 = sparse_tensor_type.from_dense_numpy(data)
80-
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(np.ascontiguousarray(data))
80+
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
81+
np.ascontiguousarray(data))
8182
eq(sparse_tensor1, sparse_tensor2)
8283
data = data.copy()
8384
data[9, 0] = 1.0
84-
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(np.ascontiguousarray(data))
85+
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
86+
np.ascontiguousarray(data))
8587
ne(sparse_tensor1, sparse_tensor2)
8688

8789

0 commit comments

Comments
 (0)