3737
3838namespace rj = arrow::rapidjson;
3939
40- namespace arrow {
41-
42- namespace extension {
43-
44- namespace {
45-
46- Status ComputeStrides (const FixedWidthType& type, const std::vector<int64_t >& shape,
47- const std::vector<int64_t >& permutation,
48- std::vector<int64_t >* strides) {
49- if (permutation.empty ()) {
50- return internal::ComputeRowMajorStrides (type, shape, strides);
51- }
52-
53- const int byte_width = type.byte_width ();
54-
55- int64_t remaining = 0 ;
56- if (!shape.empty () && shape.front () > 0 ) {
57- remaining = byte_width;
58- for (auto i : permutation) {
59- if (i > 0 ) {
60- if (internal::MultiplyWithOverflow (remaining, shape[i], &remaining)) {
61- return Status::Invalid (
62- " Strides computed from shape would not fit in 64-bit integer" );
63- }
64- }
65- }
66- }
67-
68- if (remaining == 0 ) {
69- strides->assign (shape.size (), byte_width);
70- return Status::OK ();
71- }
72-
73- strides->push_back (remaining);
74- for (auto i : permutation) {
75- if (i > 0 ) {
76- remaining /= shape[i];
77- strides->push_back (remaining);
78- }
79- }
80- internal::Permute (permutation, strides);
81-
82- return Status::OK ();
83- }
84-
85- } // namespace
40+ namespace arrow ::extension {
8641
8742bool FixedShapeTensorType::ExtensionEquals (const ExtensionType& other) const {
8843 if (extension_name () != other.extension_name ()) {
@@ -237,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
237192 }
238193
239194 std::vector<int64_t > strides;
240- RETURN_NOT_OK (ComputeStrides (value_type, shape, permutation, &strides));
195+ RETURN_NOT_OK (
196+ internal::ComputeStrides (ext_type.value_type (), shape, permutation, &strides));
241197 const auto start_position = array->offset () * byte_width;
242198 const auto size = std::accumulate (shape.begin (), shape.end (), static_cast <int64_t >(1 ),
243199 std::multiplies<>());
@@ -376,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
376332 internal::Permute<int64_t >(permutation, &shape);
377333
378334 std::vector<int64_t > tensor_strides;
379- const auto * fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get ());
380335 ARROW_RETURN_NOT_OK (
381- ComputeStrides (*fw_value_type , shape, permutation, &tensor_strides));
336+ internal:: ComputeStrides (value_type , shape, permutation, &tensor_strides));
382337
383338 const auto & raw_buffer = this ->storage ()->data ()->child_data [0 ]->buffers [1 ];
384339 ARROW_ASSIGN_OR_RAISE (
@@ -412,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
412367
413368const std::vector<int64_t >& FixedShapeTensorType::strides () {
414369 if (strides_.empty ()) {
415- auto value_type = internal::checked_cast<FixedWidthType*>(this ->value_type_ .get ());
416370 std::vector<int64_t > tensor_strides;
417- ARROW_CHECK_OK (
418- ComputeStrides (*value_type, this -> shape (), this ->permutation (), &tensor_strides));
371+ ARROW_CHECK_OK (internal::ComputeStrides ( this -> value_type_ , this -> shape (),
372+ this ->permutation (), &tensor_strides));
419373 strides_ = tensor_strides;
420374 }
421375 return strides_;
@@ -430,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
430384 return maybe_type.MoveValueUnsafe ();
431385}
432386
433- } // namespace extension
434- } // namespace arrow
387+ } // namespace arrow::extension
0 commit comments