@@ -60,7 +60,8 @@ class SparseTensorConverter {
 // ----------------------------------------------------------------------
 // IncrementIndex for SparseCOOIndex and SparseCSFIndex
 
-void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_t> shape) {
+inline void IncrementIndex(std::vector<int64_t>& coord,
+                           const std::vector<int64_t>& shape) {
   const int64_t ndim = shape.size();
   ++coord[ndim - 1];
   if (coord[ndim - 1] == shape[ndim - 1]) {
@@ -73,8 +74,8 @@ void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_t> shap
   }
 }
 
-void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_t> shape,
-                    std::vector<int64_t> axis_order) {
+inline void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_t>& shape,
+                           const std::vector<int64_t>& axis_order) {
   const int64_t ndim = shape.size();
   const int64_t last_axis = axis_order[ndim - 1];
   ++coord[last_axis];
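
Note: both overloads advance coord through the tensor in odometer fashion, bumping the last dimension (or the last dimension in axis_order) and carrying into earlier ones. A minimal standalone sketch of that behaviour, assuming the elided carry loop simply zeroes the exhausted dimension and increments the next outer one (illustration only, not the patched code):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Row-major "odometer" increment: bump the last coordinate, carry leftwards.
    void IncrementIndexSketch(std::vector<int64_t>& coord,
                              const std::vector<int64_t>& shape) {
      const int64_t ndim = static_cast<int64_t>(shape.size());
      ++coord[ndim - 1];
      for (int64_t d = ndim - 1; d > 0 && coord[d] == shape[d]; --d) {
        coord[d] = 0;    // reset the exhausted dimension
        ++coord[d - 1];  // and carry into the next outer one
      }
    }

    int main() {
      std::vector<int64_t> coord{0, 0};
      const std::vector<int64_t> shape{2, 3};
      for (int64_t n = 0; n < 6; ++n) {  // visits (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
        std::cout << coord[0] << "," << coord[1] << "\n";
        IncrementIndexSketch(coord, shape);
      }
      return 0;
    }

The final call past the last element only produces a coordinate that is never read again, which appears to be why the if (n > 1) guards around the calls in the converters below can be dropped.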
@@ -161,9 +162,7 @@ class SparseTensorConverter<TYPE, SparseCOOIndex>
           *indices++ = static_cast<c_index_value_type>(coord[i]);
         }
       }
-      if (n > 1) {
-        IncrementIndex(coord, shape);
-      }
+      IncrementIndex(coord, shape);
     }
   }
 
@@ -496,11 +495,9 @@ class SparseTensorConverter<TYPE, SparseCSFIndex>
 
       for (int64_t i = 0; i < ndim; ++i) {
         int64_t dimension = axis_order[i];
-        bool change = coord[dimension] != previous_coord[dimension];
-
-        if (tree_split || change) {
-          if (change) tree_split = true;
 
+        tree_split = tree_split || (coord[dimension] != previous_coord[dimension]);
+        if (tree_split) {
           if (i < ndim - 1) {
             RETURN_NOT_OK(indptr_buffer_builders[i].Append(
                 static_cast<c_index_value_type>(counts[i + 1])));
@@ -512,9 +509,7 @@ class SparseTensorConverter<TYPE, SparseCSFIndex>
         }
         previous_coord = coord;
       }
-      if (n > 1) {
-        IncrementIndex(coord, shape, axis_order);
-      }
+      IncrementIndex(coord, shape, axis_order);
     }
   }
 
@@ -682,25 +677,26 @@ Status MakeSparseTensorFromTensor(const Tensor& tensor,
 namespace {
 
 template <typename TYPE, typename IndexValueType>
-void ExpandSparseCSFTensorValues(int64_t dimension, int64_t offset, int64_t first_ptr,
-                                 int64_t last_ptr, const SparseCSFIndex* sparse_index,
-                                 const TYPE* raw_data, const std::vector<int64_t> strides,
-                                 const std::vector<int64_t> axis_order, TYPE* out) {
+void ExpandSparseCSFTensorValues(int64_t dimension, int64_t dense_offset,
+                                 int64_t first_ptr, int64_t last_ptr,
+                                 const SparseCSFIndex& sparse_index, const TYPE* raw_data,
+                                 const std::vector<int64_t>& strides,
+                                 const std::vector<int64_t>& axis_order, TYPE* out) {
   int64_t ndim = axis_order.size();
 
   for (int64_t i = first_ptr; i < last_ptr; ++i) {
-    int64_t tmp_offset =
-        offset + sparse_index->indices()[dimension]->Value<IndexValueType>({i}) *
-                     strides[axis_order[dimension]];
+    int64_t tmp_dense_offset =
+        dense_offset + sparse_index.indices()[dimension]->Value<IndexValueType>({i}) *
+                           strides[axis_order[dimension]];
 
     if (dimension < ndim - 1) {
       ExpandSparseCSFTensorValues<TYPE, IndexValueType>(
-          dimension + 1, tmp_offset,
-          sparse_index->indptr()[dimension]->Value<IndexValueType>({i}),
-          sparse_index->indptr()[dimension]->Value<IndexValueType>({i + 1}), sparse_index,
+          dimension + 1, tmp_dense_offset,
+          sparse_index.indptr()[dimension]->Value<IndexValueType>({i}),
+          sparse_index.indptr()[dimension]->Value<IndexValueType>({i + 1}), sparse_index,
           raw_data, strides, axis_order, out);
     } else {
-      out[tmp_offset] = raw_data[i];
+      out[tmp_dense_offset] = raw_data[i];
     }
   }
 }
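
Note: the recursion walks one level of the CSF tree per dimension, using indptr to delimit each node's children and indices to recover the coordinate along that dimension, accumulating a dense offset via strides. A simplified standalone analogue using plain vectors in place of the arrow::Tensor columns, and assuming the natural axis order (illustration only, names are not from the patch):

    #include <cstdint>
    #include <vector>

    // indices[d][i]  : coordinate along dimension d of node i at level d
    // indptr[d][i..] : range of children that node i owns at level d + 1
    // values[i]      : non-zero value stored at leaf i
    void ExpandCSFSketch(int64_t dim, int64_t dense_offset, int64_t first, int64_t last,
                         const std::vector<std::vector<int64_t>>& indptr,
                         const std::vector<std::vector<int64_t>>& indices,
                         const std::vector<double>& values,
                         const std::vector<int64_t>& strides, std::vector<double>& out) {
      const int64_t ndim = static_cast<int64_t>(indices.size());
      for (int64_t i = first; i < last; ++i) {
        const int64_t offset = dense_offset + indices[dim][i] * strides[dim];
        if (dim < ndim - 1) {
          // Recurse into the children of node i on the next level.
          ExpandCSFSketch(dim + 1, offset, indptr[dim][i], indptr[dim][i + 1], indptr,
                          indices, values, strides, out);
        } else {
          out[offset] = values[i];  // leaf: write the stored non-zero value
        }
      }
    }

For a 2x3 matrix with non-zeros at (0, 1) = 5 and (1, 2) = 7, one would have indices = {{0, 1}, {1, 2}}, indptr = {{0, 1, 2}}, values = {5, 7} and strides = {3, 1}; calling ExpandCSFSketch(0, 0, 0, 2, indptr, indices, values, strides, out) over a zero-filled out of length 6 writes out[1] = 5 and out[5] = 7.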
@@ -797,7 +793,7 @@ Status MakeTensorFromSparseTensor(MemoryPool* pool, const SparseTensor* sparse_t
           internal::checked_cast<const SparseCSFIndex&>(*sparse_tensor->sparse_index());
 
       ExpandSparseCSFTensorValues<value_type, IndexValueType>(
-          0, 0, 0, sparse_index.indptr()[0]->size() - 1, &sparse_index, raw_data, strides,
+          0, 0, 0, sparse_index.indptr()[0]->size() - 1, sparse_index, raw_data, strides,
           sparse_index.axis_order(), values);
       *out = std::make_shared<Tensor>(sparse_tensor->type(), values_buffer,
                                       sparse_tensor->shape(), empty_strides,
@@ -995,11 +991,11 @@ inline Status CheckSparseCSFIndexValidity(const std::shared_ptr<DataType>& indpt
   }
   if (num_indptrs + 1 != num_indices) {
     return Status::Invalid(
-        "Length of indices must be equal to length of inptrs + 1 for SparseCSFIndex.");
+        "Length of indices must be equal to length of indptrs + 1 for SparseCSFIndex.");
   }
   if (axis_order_size != num_indices) {
     return Status::Invalid(
-        "Length of indices must be equal number of dimensions for SparseCSFIndex.");
+        "Length of indices must be equal to number of dimensions for SparseCSFIndex.");
   }
   return Status::OK();
 }
@@ -1045,6 +1041,16 @@ SparseCSFIndex::SparseCSFIndex(std::vector<std::shared_ptr<Tensor>>& indptr,
 
 std::string SparseCSFIndex::ToString() const { return std::string("SparseCSFIndex"); }
 
+bool SparseCSFIndex::Equals(const SparseCSFIndex& other) const {
+  for (int64_t i = 0; i < static_cast<int64_t>(indices().size()); ++i) {
+    if (!indices()[i]->Equals(*other.indices()[i])) return false;
+  }
+  for (int64_t i = 0; i < static_cast<int64_t>(indptr().size()); ++i) {
+    if (!indptr()[i]->Equals(*other.indptr()[i])) return false;
+  }
+  return axis_order() == other.axis_order();
+}
+
 // ----------------------------------------------------------------------
 // SparseTensor
 
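Note: the new Equals compares each indices and indptr tensor pairwise and then the axis_order vector. The loops index other's vectors with this object's sizes, so callers are presumably expected to compare indices of equal dimensionality. A hypothetical call site (helper name and header path are illustrative, not part of the patch):

    #include <arrow/sparse_tensor.h>  // assumed location of SparseCSFIndex

    // e.g. inside a round-trip test: two independently built CSF indices
    // should describe the same sparsity structure.
    bool SameStructure(const arrow::SparseCSFIndex& a, const arrow::SparseCSFIndex& b) {
      return a.Equals(b);  // tensor-wise indices/indptr equality plus axis_order
    }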