|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | #pragma once |
19 | | -#include "arrow/extension/tensor_internal.h" |
20 | 19 |
|
21 | 20 | #include <cstdint> |
22 | 21 | #include <vector> |
23 | 22 |
|
24 | 23 | #include "arrow/array/array_nested.h" |
25 | | -#include "arrow/tensor.h" |
26 | | -#include "arrow/status.h" |
27 | | -#include "arrow/util/checked_cast.h" |
28 | | -#include "arrow/util/int_util_overflow.h" |
29 | | -#include "arrow/util/sort_internal.h" |
30 | | -#include "arrow/util/print_internal.h" |
31 | 24 |
|
32 | 25 | namespace arrow::internal { |
33 | 26 |
|
34 | 27 | ARROW_EXPORT |
35 | | -inline Status IsPermutationValid(const std::vector<int64_t>& permutation) { |
36 | | - const auto size = static_cast<int64_t>(permutation.size()); |
37 | | - std::vector<uint8_t> dim_seen(size, 0); |
38 | | - |
39 | | - for (const auto p : permutation) { |
40 | | - if (p < 0 || p >= size || dim_seen[p] != 0) { |
41 | | - return Status::Invalid( |
42 | | - "Permutation indices for ", size, |
43 | | - " dimensional tensors must be unique and within [0, ", size - 1, |
44 | | - "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","}); |
45 | | - } |
46 | | - dim_seen[p] = 1; |
47 | | - } |
48 | | - return Status::OK(); |
49 | | -} |
| 28 | +Status IsPermutationValid(const std::vector<int64_t>& permutation); |
50 | 29 |
|
51 | 30 | ARROW_EXPORT |
52 | | -inline Status ComputeStrides(const std::shared_ptr<DataType>& value_type, |
| 31 | +Status ComputeStrides(const std::shared_ptr<DataType>& value_type, |
53 | 32 | const std::vector<int64_t>& shape, |
54 | 33 | const std::vector<int64_t>& permutation, |
55 | | - std::vector<int64_t>* strides) { |
56 | | - auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type); |
57 | | - if (permutation.empty()) { |
58 | | - return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides); |
59 | | - } |
60 | | - const int byte_width = value_type->byte_width(); |
61 | | - |
62 | | - int64_t remaining = 0; |
63 | | - if (!shape.empty() && shape.front() > 0) { |
64 | | - remaining = byte_width; |
65 | | - for (auto i : permutation) { |
66 | | - if (i > 0) { |
67 | | - if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) { |
68 | | - return Status::Invalid( |
69 | | - "Strides computed from shape would not fit in 64-bit integer"); |
70 | | - } |
71 | | - } |
72 | | - } |
73 | | - } |
74 | | - |
75 | | - if (remaining == 0) { |
76 | | - strides->assign(shape.size(), byte_width); |
77 | | - return Status::OK(); |
78 | | - } |
79 | | - |
80 | | - strides->push_back(remaining); |
81 | | - for (auto i : permutation) { |
82 | | - if (i > 0) { |
83 | | - remaining /= shape[i]; |
84 | | - strides->push_back(remaining); |
85 | | - } |
86 | | - } |
87 | | - Permute(permutation, strides); |
88 | | - |
89 | | - return Status::OK(); |
90 | | -} |
| 34 | + std::vector<int64_t>* strides); |
91 | 35 |
|
92 | 36 | } // namespace arrow::internal |
0 commit comments