@@ -57,6 +57,37 @@ class SparseTensorConverter {
5757 Status Convert () { return Status::Invalid (" Unsupported sparse index" ); }
5858};
5959
60+ // ----------------------------------------------------------------------
61+ // IncrementIndex for SparseCOOIndex and SparseCSFIndex
62+
63+ void IncrementIndex (std::vector<int64_t >& coord, const std::vector<int64_t > shape) {
64+ const int64_t ndim = shape.size ();
65+ ++coord[ndim - 1 ];
66+ if (coord[ndim - 1 ] == shape[ndim - 1 ]) {
67+ int64_t d = ndim - 1 ;
68+ while (d > 0 && coord[d] == shape[d]) {
69+ coord[d] = 0 ;
70+ ++coord[d - 1 ];
71+ --d;
72+ }
73+ }
74+ }
75+
76+ void IncrementIndex (std::vector<int64_t >& coord, const std::vector<int64_t > shape,
77+ std::vector<int64_t > axis_order) {
78+ const int64_t ndim = shape.size ();
79+ const int64_t last_axis = axis_order[ndim - 1 ];
80+ ++coord[last_axis];
81+ if (coord[last_axis] == shape[last_axis]) {
82+ int64_t d = ndim - 1 ;
83+ while (d > 0 && coord[axis_order[d]] == shape[axis_order[d]]) {
84+ coord[axis_order[d]] = 0 ;
85+ ++coord[axis_order[d - 1 ]];
86+ --d;
87+ }
88+ }
89+ }
90+
6091// ----------------------------------------------------------------------
6192// SparseTensorConverter for SparseCOOIndex
6293
@@ -130,15 +161,8 @@ class SparseTensorConverter<TYPE, SparseCOOIndex>
130161 *indices++ = static_cast <c_index_value_type>(coord[i]);
131162 }
132163 }
133- // increment index
134- ++coord[ndim - 1 ];
135- if (n > 1 && coord[ndim - 1 ] == shape[ndim - 1 ]) {
136- int64_t d = ndim - 1 ;
137- while (d > 0 && coord[d] == shape[d]) {
138- coord[d] = 0 ;
139- ++coord[d - 1 ];
140- --d;
141- }
164+ if (n > 1 ) {
165+ IncrementIndex (coord, shape);
142166 }
143167 }
144168 }
@@ -488,16 +512,8 @@ class SparseTensorConverter<TYPE, SparseCSFIndex>
488512 }
489513 previous_coord = coord;
490514 }
491- // increment index
492- int64_t last_axis = axis_order[ndim - 1 ];
493- ++coord[last_axis];
494- if (n > 1 && coord[last_axis] == shape[last_axis]) {
495- int64_t d = ndim - 1 ;
496- while (d > 0 && coord[axis_order[d]] == shape[axis_order[d]]) {
497- coord[axis_order[d]] = 0 ;
498- ++coord[axis_order[d - 1 ]];
499- --d;
500- }
515+ if (n > 1 ) {
516+ IncrementIndex (coord, shape, axis_order);
501517 }
502518 }
503519 }
0 commit comments