diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 07ea8930ff0d7..143fb13ddcca6 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -520,6 +520,7 @@ endif() if(ARROW_JSON) list(APPEND ARROW_SRCS + extension/fixed_shape_tensor.cc json/options.cc json/chunked_builder.cc json/chunker.cc @@ -856,6 +857,7 @@ endif() if(ARROW_JSON) add_subdirectory(json) + add_subdirectory(extension) endif() if(ARROW_ORC) diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt new file mode 100644 index 0000000000000..c15c42874d4de --- /dev/null +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_arrow_test(test + SOURCES + fixed_shape_tensor_test.cc + PREFIX + "arrow-fixed-shape-tensor") + +arrow_install_all_headers("arrow/extension") diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc new file mode 100644 index 0000000000000..8b0ed43df5c66 --- /dev/null +++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/util/int_util_overflow.h" +#include "arrow/util/logging.h" +#include "arrow/util/sort.h" + +#include +#include + +namespace rj = arrow::rapidjson; + +namespace arrow { +namespace extension { + +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = static_cast(other); + + auto is_permutation_trivial = [](const std::vector& permutation) { + for (size_t i = 1; i < permutation.size(); ++i) { + if (permutation[i - 1] + 1 != permutation[i]) { + return false; + } + } + return true; + }; + const bool permutation_equivalent = + ((permutation_ == other_ext.permutation()) || + (permutation_.empty() && is_permutation_trivial(other_ext.permutation())) || + (is_permutation_trivial(permutation_) && other_ext.permutation().empty())); + + return (storage_type()->Equals(other_ext.storage_type())) && + (this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) && + permutation_equivalent; +} + +std::string FixedShapeTensorType::Serialize() const { + rj::Document document; + document.SetObject(); + rj::Document::AllocatorType& allocator = document.GetAllocator(); + + rj::Value shape(rj::kArrayType); + for (auto v : shape_) { + shape.PushBack(v, allocator); + } + document.AddMember(rj::Value("shape", allocator), shape, allocator); + + if (!permutation_.empty()) { + rj::Value permutation(rj::kArrayType); + for (auto v : permutation_) { + permutation.PushBack(v, allocator); + } + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); + } + + if (!dim_names_.empty()) { + rj::Value dim_names(rj::kArrayType); + for (std::string v : dim_names_) { + dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); + } + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); + } + + rj::StringBuffer buffer; + rj::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result> FixedShapeTensorType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::FIXED_SIZE_LIST) { + return Status::Invalid("Expected FixedSizeList storage type, got ", + storage_type->ToString()); + } + auto value_type = + internal::checked_pointer_cast(storage_type)->value_type(); + rj::Document document; + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || + !document.HasMember("shape") || !document["shape"].IsArray()) { + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); + } + + std::vector shape; + for (auto& x : document["shape"].GetArray()) { + shape.emplace_back(x.GetInt64()); + } + std::vector permutation; + if (document.HasMember("permutation")) { + for (auto& x : document["permutation"].GetArray()) { + permutation.emplace_back(x.GetInt64()); + } + if (shape.size() != permutation.size()) { + return Status::Invalid("Invalid permutation"); + } + } + std::vector dim_names; + if (document.HasMember("dim_names")) { + for (auto& x : document["dim_names"].GetArray()) { + dim_names.emplace_back(x.GetString()); + } + if (shape.size() != dim_names.size()) { + return Status::Invalid("Invalid dim_names"); + } + } + + return fixed_shape_tensor(value_type, shape, permutation, dim_names); +} + +std::shared_ptr FixedShapeTensorType::MakeArray( + std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.fixed_shape_tensor", + static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> FixedShapeTensorType::Make( + const std::shared_ptr& value_type, const std::vector& shape, + const std::vector& permutation, const std::vector& dim_names) { + if (!permutation.empty() && shape.size() != permutation.size()) { + return Status::Invalid("permutation size must match shape size. Expected: ", + shape.size(), " Got: ", permutation.size()); + } + if (!dim_names.empty() && shape.size() != dim_names.size()) { + return Status::Invalid("dim_names size must match shape size. Expected: ", + shape.size(), " Got: ", dim_names.size()); + } + const auto size = std::accumulate(shape.begin(), shape.end(), static_cast(1), + std::multiplies<>()); + return std::make_shared(value_type, static_cast(size), + shape, permutation, dim_names); +} + +std::shared_ptr fixed_shape_tensor(const std::shared_ptr& value_type, + const std::vector& shape, + const std::vector& permutation, + const std::vector& dim_names) { + auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names); + ARROW_DCHECK_OK(maybe_type.status()); + return maybe_type.MoveValueUnsafe(); +} + +} // namespace extension +} // namespace arrow diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h new file mode 100644 index 0000000000000..4ee2b894ee8be --- /dev/null +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. +/// This is a canonical arrow extension type. +/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html +class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr& value_type, const int32_t& size, + const std::vector& shape, + const std::vector& permutation = {}, + const std::vector& dim_names = {}) + : ExtensionType(fixed_size_list(value_type, size)), + value_type_(value_type), + shape_(shape), + permutation_(permutation), + dim_names_(dim_names) {} + + std::string extension_name() const override { return "arrow.fixed_shape_tensor"; } + + /// Number of dimensions of tensor elements + size_t ndim() { return shape_.size(); } + + /// Shape of tensor elements + const std::vector shape() const { return shape_; } + + /// Value type of tensor elements + const std::shared_ptr value_type() const { return value_type_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered physically. + const std::vector& dim_names() const { return dim_names_; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + + /// Create a FixedShapeTensorArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + /// \brief Create a FixedShapeTensorType instance + static Result> Make( + const std::shared_ptr& value_type, const std::vector& shape, + const std::vector& permutation = {}, + const std::vector& dim_names = {}); + + private: + std::shared_ptr storage_type_; + std::shared_ptr value_type_; + std::vector shape_; + std::vector permutation_; + std::vector dim_names_; +}; + +/// \brief Return a FixedShapeTensorType instance. +ARROW_EXPORT std::shared_ptr fixed_shape_tensor( + const std::shared_ptr& storage_type, const std::vector& shape, + const std::vector& permutation = {}, + const std::vector& dim_names = {}); + +} // namespace extension +} // namespace arrow diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc new file mode 100644 index 0000000000000..16ba9d2014e40 --- /dev/null +++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; +using extension::FixedShapeTensorArray; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = internal::checked_pointer_cast( + fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_)); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector shape_; + std::vector cell_shape_; + std::shared_ptr value_type_; + std::shared_ptr cell_type_; + std::vector dim_names_; + std::shared_ptr ext_type_; + std::vector values_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr& batch, + std::shared_ptr* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a registered dummy type at runtime to allow for IPC deserialization + auto registered_type = GetExtensionType("arrow.fixed_shape_tensor"); + ASSERT_TRUE(registered_type->type_id == Type::EXTENSION); +} + +TEST_F(TestExtensionType, CreateExtensionType) { + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); + ASSERT_FALSE(ext_type_->Equals(*cell_type_)); + ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_)); + ASSERT_EQ(ext_type_->Serialize(), serialized_); + ASSERT_OK_AND_ASSIGN(auto ds, + ext_type_->Deserialize(ext_type_->storage_type(), serialized_)); + auto deserialized = std::reinterpret_pointer_cast(ds); + ASSERT_TRUE(deserialized->Equals(*ext_type_)); + + // Test FixedShapeTensorType methods + ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION); + ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size()); + ASSERT_EQ(exact_ext_type->shape(), cell_shape_); + ASSERT_EQ(exact_ext_type->value_type(), value_type_); + ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: permutation size must match shape size."), + FixedShapeTensorType::Make(value_type_, cell_shape_, {0})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: dim_names size must match shape size."), + FixedShapeTensorType::Make(value_type_, cell_shape_, {}, {"x"})); +} + +TEST_F(TestExtensionType, EqualsCases) { + auto ext_type_permutation_1 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}, {"x", "y"}); + auto ext_type_permutation_2 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}, {"x", "y"}); + auto ext_type_no_permutation = fixed_shape_tensor(int64(), {3, 4}, {}, {"x", "y"}); + + ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_permutation_1)); + + ASSERT_FALSE(fixed_shape_tensor(int32(), {3, 4}, {}, {"x", "y"}) + ->Equals(ext_type_no_permutation)); + ASSERT_FALSE(fixed_shape_tensor(int64(), {2, 4}, {}, {"x", "y"}) + ->Equals(ext_type_no_permutation)); + ASSERT_FALSE(fixed_shape_tensor(int64(), {3, 4}, {}, {"H", "W"}) + ->Equals(ext_type_no_permutation)); + + ASSERT_TRUE(ext_type_no_permutation->Equals(ext_type_permutation_1)); + ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_no_permutation)); + ASSERT_FALSE(ext_type_no_permutation->Equals(ext_type_permutation_2)); + ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_no_permutation)); + ASSERT_FALSE(ext_type_permutation_1->Equals(ext_type_permutation_2)); + ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1)); +} + +TEST_F(TestExtensionType, CreateFromArray) { + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + + std::vector> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared(arr_data); + ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); + auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr); + ASSERT_EQ(ext_arr->length(), shape_[0]); + ASSERT_EQ(ext_arr->null_count(), 0); +} + +void CheckSerializationRoundtrip(const std::shared_ptr& ext_type) { + auto fst_type = internal::checked_pointer_cast(ext_type); + auto serialized = fst_type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + fst_type->Deserialize(fst_type->storage_type(), serialized)); + ASSERT_TRUE(fst_type->Equals(*deserialized)); +} + +void CheckDeserializationRaises(const std::shared_ptr& storage_type, + const std::string& serialized, + const std::string& expected_message) { + auto fst_type = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4})); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_message), + fst_type->Deserialize(storage_type, serialized)); +} + +TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { + CheckSerializationRoundtrip(ext_type_); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"})); + + auto storage_type = fixed_size_list(int64(), 12); + CheckDeserializationRaises(boolean(), R"({"shape":[3,4]})", + "Expected FixedSizeList storage type, got bool"); + CheckDeserializationRaises(storage_type, R"({"dim_names":["x","y"]})", + "Invalid serialized JSON data"); + CheckDeserializationRaises(storage_type, R"({"shape":(3,4)})", + "Invalid serialized JSON data"); + CheckDeserializationRaises(storage_type, R"({"shape":[3,4],"permutation":[1,0,2]})", + "Invalid permutation"); + CheckDeserializationRaises(storage_type, R"({"shape":[3],"dim_names":["x","y"]})", + "Invalid dim_names"); +} + +TEST_F(TestExtensionType, RoudtripBatch) { + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + + std::vector> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared(arr_data); + ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); + auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr); + + // Pass extension array, expect getting back extension array + std::shared_ptr read_batch; + auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); + + // Pass extension metadata and storage array, expect getting back extension array + std::shared_ptr read_batch2; + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + ext_field = field(/*name=*/"f0", /*type=*/cell_type_, /*nullable=*/true, + /*metadata=*/ext_metadata); + auto batch2 = RecordBatch::Make(schema({ext_field}), fsla_arr->length(), {fsla_arr}); + RoundtripBatch(batch2, &read_batch2); + CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); +} + +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index e579b6910231f..1199336763ddb 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -26,6 +26,10 @@ #include "arrow/array/util.h" #include "arrow/chunked_array.h" +#include "arrow/config.h" +#ifdef ARROW_JSON +#include "arrow/extension/fixed_shape_tensor.h" +#endif #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" @@ -139,6 +143,14 @@ namespace internal { static void CreateGlobalRegistry() { g_registry = std::make_shared(); + +#ifdef ARROW_JSON + // Register canonical extension types + auto ext_type = + checked_pointer_cast(extension::fixed_shape_tensor(int64(), {})); + + ARROW_CHECK_OK(g_registry->RegisterType(ext_type)); +#endif } } // namespace internal