1616// under the License.
1717
1818#pragma once
19- #include " arrow/extension/tensor_internal.h"
2019
2120#include < cstdint>
2221#include < vector>
2322
2423#include " arrow/array/array_nested.h"
25- #include " arrow/tensor.h"
26- #include " arrow/status.h"
27- #include " arrow/util/checked_cast.h"
28- #include " arrow/util/int_util_overflow.h"
29- #include " arrow/util/sort_internal.h"
30- #include " arrow/util/print_internal.h"
3124
3225namespace arrow ::internal {
3326
3427ARROW_EXPORT
35- inline Status IsPermutationValid (const std::vector<int64_t >& permutation) {
36- const auto size = static_cast <int64_t >(permutation.size ());
37- std::vector<uint8_t > dim_seen (size, 0 );
38-
39- for (const auto p : permutation) {
40- if (p < 0 || p >= size || dim_seen[p] != 0 ) {
41- return Status::Invalid (
42- " Permutation indices for " , size,
43- " dimensional tensors must be unique and within [0, " , size - 1 ,
44- " ] range. Got: " , ::arrow::internal::PrintVector{permutation, " ," });
45- }
46- dim_seen[p] = 1 ;
47- }
48- return Status::OK ();
49- }
28+ Status IsPermutationValid (const std::vector<int64_t >& permutation);
5029
5130ARROW_EXPORT
52- inline Status ComputeStrides (const std::shared_ptr<DataType>& value_type,
31+ Status ComputeStrides (const std::shared_ptr<DataType>& value_type,
5332 const std::vector<int64_t >& shape,
5433 const std::vector<int64_t >& permutation,
55- std::vector<int64_t >* strides) {
56- auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
57- if (permutation.empty ()) {
58- return internal::ComputeRowMajorStrides (*fixed_width_type.get (), shape, strides);
59- }
60- const int byte_width = value_type->byte_width ();
61-
62- int64_t remaining = 0 ;
63- if (!shape.empty () && shape.front () > 0 ) {
64- remaining = byte_width;
65- for (auto i : permutation) {
66- if (i > 0 ) {
67- if (internal::MultiplyWithOverflow (remaining, shape[i], &remaining)) {
68- return Status::Invalid (
69- " Strides computed from shape would not fit in 64-bit integer" );
70- }
71- }
72- }
73- }
74-
75- if (remaining == 0 ) {
76- strides->assign (shape.size (), byte_width);
77- return Status::OK ();
78- }
79-
80- strides->push_back (remaining);
81- for (auto i : permutation) {
82- if (i > 0 ) {
83- remaining /= shape[i];
84- strides->push_back (remaining);
85- }
86- }
87- Permute (permutation, strides);
88-
89- return Status::OK ();
90- }
34+ std::vector<int64_t >* strides);
9135
92- } // namespace arrow::internal
36+ } // namespace arrow::internal
0 commit comments