@@ -515,24 +515,25 @@ template <typename TYPE, typename IndexValueType>
515515void assign_values (int64_t dimension_index, int64_t offset, int64_t first_ptr,
516516 int64_t last_ptr, const SparseCSFIndex* sparse_index,
517517 const int64_t * raw_data, const std::vector<int64_t > strides,
518- TYPE* out) {
519- auto indices_offset = sparse_index->indices_offsets ()[dimension_index];
520- auto indptr_offset = sparse_index->indptr_offsets ()[dimension_index];
518+ const std::vector<int64_t > axis_order, TYPE* out) {
519+ auto dimension = axis_order[dimension_index];
520+ auto indices_offset = sparse_index->indices_offsets ()[dimension];
521+ auto indptr_offset = sparse_index->indptr_offsets ()[dimension];
521522 int64_t ndim = sparse_index->indices_offsets ().size ();
522523
523- if (dimension_index == 0 && ndim > 1 )
524- last_ptr = sparse_index->indptr_offsets ()[dimension_index + 1 ] - 1 ;
524+ if (dimension == 0 && ndim > 1 )
525+ last_ptr = sparse_index->indptr_offsets ()[dimension + 1 ] - 1 ;
525526
526527 for (int64_t i = first_ptr; i < last_ptr; ++i) {
527528 int64_t tmp_offset =
528529 offset + sparse_index->indices ()->Value <IndexValueType>({indices_offset + i}) *
529- strides[dimension_index ];
530+ strides[dimension ];
530531 if (dimension_index < ndim - 1 )
531532 assign_values<TYPE, IndexValueType>(
532- dimension_index + 1 , tmp_offset,
533+ dimension + 1 , tmp_offset,
533534 sparse_index->indptr ()->Value <IndexValueType>({indptr_offset + i}),
534535 sparse_index->indptr ()->Value <IndexValueType>({indptr_offset + i + 1 }),
535- sparse_index, raw_data, strides, out);
536+ sparse_index, raw_data, strides, axis_order, out);
536537 else
537538 out[tmp_offset] = static_cast <TYPE>(raw_data[i]);
538539 }
@@ -625,7 +626,8 @@ Status MakeTensorFromSparseTensor(MemoryPool* pool, const SparseTensor* sparse_t
625626 internal::checked_cast<const SparseCSFIndex&>(*sparse_tensor->sparse_index ());
626627 assign_values<value_type, IndexValueType>(
627628 0 , 0 , 0 , 0 , &sparse_index,
628- reinterpret_cast <const int64_t *>(sparse_tensor->raw_data ()), strides, values);
629+ reinterpret_cast <const int64_t *>(sparse_tensor->raw_data ()), strides,
630+ sparse_index.axis_order (), values);
629631 *out = std::make_shared<Tensor>(sparse_tensor->type (), values_buffer,
630632 sparse_tensor->shape ());
631633 return Status::OK ();
0 commit comments