Skip to content

Commit 55cc40b

Browse files
committed
Python wrapper
1 parent 2328b6e commit 55cc40b

24 files changed

+1559
-122
lines changed

cpp/src/arrow/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,8 @@ if(ARROW_JSON)
910910
arrow_add_object_library(ARROW_JSON
911911
extension/fixed_shape_tensor.cc
912912
extension/opaque.cc
913+
extension/tensor_internal.cc
914+
extension/variable_shape_tensor.cc
913915
json/options.cc
914916
json/chunked_builder.cc
915917
json/chunker.cc

cpp/src/arrow/extension/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
set(CANONICAL_EXTENSION_TESTS bool8_test.cc uuid_test.cc)
1919

2020
if(ARROW_JSON)
21-
list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc)
21+
list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc)
2222
endif()
2323

2424
add_arrow_test(test

cpp/src/arrow/extension/fixed_shape_tensor.cc

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -37,52 +37,7 @@
3737

3838
namespace rj = arrow::rapidjson;
3939

40-
namespace arrow {
41-
42-
namespace extension {
43-
44-
namespace {
45-
46-
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
47-
const std::vector<int64_t>& permutation,
48-
std::vector<int64_t>* strides) {
49-
if (permutation.empty()) {
50-
return internal::ComputeRowMajorStrides(type, shape, strides);
51-
}
52-
53-
const int byte_width = type.byte_width();
54-
55-
int64_t remaining = 0;
56-
if (!shape.empty() && shape.front() > 0) {
57-
remaining = byte_width;
58-
for (auto i : permutation) {
59-
if (i > 0) {
60-
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
61-
return Status::Invalid(
62-
"Strides computed from shape would not fit in 64-bit integer");
63-
}
64-
}
65-
}
66-
}
67-
68-
if (remaining == 0) {
69-
strides->assign(shape.size(), byte_width);
70-
return Status::OK();
71-
}
72-
73-
strides->push_back(remaining);
74-
for (auto i : permutation) {
75-
if (i > 0) {
76-
remaining /= shape[i];
77-
strides->push_back(remaining);
78-
}
79-
}
80-
internal::Permute(permutation, strides);
81-
82-
return Status::OK();
83-
}
84-
85-
} // namespace
40+
namespace arrow::extension {
8641

8742
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
8843
if (extension_name() != other.extension_name()) {
@@ -237,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
237192
}
238193

239194
std::vector<int64_t> strides;
240-
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
195+
RETURN_NOT_OK(
196+
internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
241197
const auto start_position = array->offset() * byte_width;
242198
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
243199
std::multiplies<>());
@@ -376,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
376332
internal::Permute<int64_t>(permutation, &shape);
377333

378334
std::vector<int64_t> tensor_strides;
379-
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
380335
ARROW_RETURN_NOT_OK(
381-
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
336+
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));
382337

383338
const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
384339
ARROW_ASSIGN_OR_RAISE(
@@ -412,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
412367

413368
const std::vector<int64_t>& FixedShapeTensorType::strides() {
414369
if (strides_.empty()) {
415-
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
416370
std::vector<int64_t> tensor_strides;
417-
ARROW_CHECK_OK(
418-
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
371+
ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
372+
this->permutation(), &tensor_strides));
419373
strides_ = tensor_strides;
420374
}
421375
return strides_;
@@ -430,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
430384
return maybe_type.MoveValueUnsafe();
431385
}
432386

433-
} // namespace extension
434-
} // namespace arrow
387+
} // namespace arrow::extension

cpp/src/arrow/extension/fixed_shape_tensor.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
#include "arrow/extension_type.h"
2121

22-
namespace arrow {
23-
namespace extension {
22+
namespace arrow::extension {
2423

2524
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
2625
public:
@@ -126,5 +125,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
126125
const std::vector<int64_t>& permutation = {},
127126
const std::vector<std::string>& dim_names = {});
128127

129-
} // namespace extension
130-
} // namespace arrow
128+
} // namespace arrow::extension

0 commit comments

Comments
 (0)