From b0317f2b2b62b3be9beb8d834aa51b776fb0179e Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 20 Aug 2024 17:04:33 +0900 Subject: [PATCH 01/32] GH-43707: [Python] Fix compilation on Cython<3 (#43765) ### Rationale for this change Fix compilation on Cython < 3 ### What changes are included in this PR? Add an explicit cast ### Are these changes tested? N/A ### Are there any user-facing changes? No * GitHub Issue: #43707 Authored-by: David Li Signed-off-by: Joris Van den Bossche --- python/pyarrow/types.pxi | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 93d68fb847890..dcd2b61c33411 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -5328,8 +5328,9 @@ def opaque(DataType storage_type, str type_name not None, str vendor_name not No cdef: c_string c_type_name = tobytes(type_name) c_string c_vendor_name = tobytes(vendor_name) - shared_ptr[CDataType] c_type = make_shared[COpaqueType]( + shared_ptr[COpaqueType] c_opaque_type = make_shared[COpaqueType]( storage_type.sp_type, c_type_name, c_vendor_name) + shared_ptr[CDataType] c_type = static_pointer_cast[CDataType, COpaqueType](c_opaque_type) OpaqueType out = OpaqueType.__new__(OpaqueType) out.init(c_type) return out From cc3c868aea7317a58447658f1c165ad352cd4865 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:57:57 +0200 Subject: [PATCH 02/32] MINOR: [Documentation] Add installation of ninja-build to Python Development docs (#43600) ### Rationale for this change Otherwise, you get a CMake error: ``` CMake Error: CMake was unable to find a build program corresponding to "Ninja". CMAKE_MAKE_PROGRAM is not set. You probably need to select a different build tool. ``` Authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Signed-off-by: Joris Van den Bossche --- docs/source/developers/python.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/developers/python.rst b/docs/source/developers/python.rst index 2f3e892ce8ede..6beea55e66b86 100644 --- a/docs/source/developers/python.rst +++ b/docs/source/developers/python.rst @@ -267,7 +267,7 @@ On Debian/Ubuntu, you need the following minimal set of dependencies: .. code-block:: - $ sudo apt-get install build-essential cmake python3-dev + $ sudo apt-get install build-essential ninja-build cmake python3-dev Now, let's create a Python virtual environment with all Python dependencies in the same folder as the repositories, and a target installation folder: From 525881987d0b9b4f464c3e3593a9a7b4e3c767d0 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:25:19 -0400 Subject: [PATCH 03/32] GH-17682: [C++][Python] Bool8 Extension Type Implementation (#43488) ### Rationale for this change C++ and Python implementations of #43234 ### What changes are included in this PR? - Implement C++ `Bool8Type`, `Bool8Array`, `Bool8Scalar`, and tests - Implement Python bindings to C++, as well as zero-copy numpy conversion methods - TODO: docs waiting for rebase on #43458 ### Are these changes tested? Yes ### Are there any user-facing changes? 
Bool8 extension type will be available in C++ and Python libraries * GitHub Issue: #17682 Authored-by: Joel Lubinitsky Signed-off-by: Felipe Oliveira Carvalho --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 6 + cpp/src/arrow/extension/bool8.cc | 61 ++++++++ cpp/src/arrow/extension/bool8.h | 58 ++++++++ cpp/src/arrow/extension/bool8_test.cc | 91 ++++++++++++ cpp/src/arrow/extension_type.cc | 7 +- python/pyarrow/__init__.py | 7 +- python/pyarrow/array.pxi | 114 ++++++++++++++- python/pyarrow/includes/libarrow.pxd | 9 ++ python/pyarrow/lib.pxd | 3 + python/pyarrow/public-api.pxi | 2 + python/pyarrow/scalar.pxi | 23 ++- python/pyarrow/tests/test_extension_type.py | 152 ++++++++++++++++++++ python/pyarrow/tests/test_misc.py | 3 + python/pyarrow/types.pxi | 74 ++++++++++ 15 files changed, 604 insertions(+), 7 deletions(-) create mode 100644 cpp/src/arrow/extension/bool8.cc create mode 100644 cpp/src/arrow/extension/bool8.h create mode 100644 cpp/src/arrow/extension/bool8_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index fb785e1e9571b..fb7253b6fd69d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -906,6 +906,7 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON + extension/bool8.cc extension/fixed_shape_tensor.cc extension/opaque.cc json/options.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 6741ab602f50b..fcd5fa529ab56 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,6 +15,12 @@ # specific language governing permissions and limitations # under the License. +add_arrow_test(test + SOURCES + bool8_test.cc + PREFIX + "arrow-extension-bool8") + add_arrow_test(test SOURCES fixed_shape_tensor_test.cc diff --git a/cpp/src/arrow/extension/bool8.cc b/cpp/src/arrow/extension/bool8.cc new file mode 100644 index 0000000000000..c081f0c2b2866 --- /dev/null +++ b/cpp/src/arrow/extension/bool8.cc @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/extension/bool8.h" +#include "arrow/util/logging.h" + +namespace arrow::extension { + +bool Bool8Type::ExtensionEquals(const ExtensionType& other) const { + return extension_name() == other.extension_name(); +} + +std::string Bool8Type::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() << ">"; + return ss.str(); +} + +std::string Bool8Type::Serialize() const { return ""; } + +Result> Bool8Type::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::INT8) { + return Status::Invalid("Expected INT8 storage type, got ", storage_type->ToString()); + } + if (serialized_data != "") { + return Status::Invalid("Serialize data must be empty, got ", serialized_data); + } + return bool8(); +} + +std::shared_ptr Bool8Type::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.bool8", + internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> Bool8Type::Make() { + return std::make_shared(); +} + +std::shared_ptr bool8() { return std::make_shared(); } + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/bool8.h b/cpp/src/arrow/extension/bool8.h new file mode 100644 index 0000000000000..02e629b28a867 --- /dev/null +++ b/cpp/src/arrow/extension/bool8.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" + +namespace arrow::extension { + +/// \brief Bool8 is an alternate representation for boolean +/// arrays using 8 bits instead of 1 bit per value. The underlying +/// storage type is int8. +class ARROW_EXPORT Bool8Array : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Bool8 is an alternate representation for boolean +/// arrays using 8 bits instead of 1 bit per value. The underlying +/// storage type is int8. +class ARROW_EXPORT Bool8Type : public ExtensionType { + public: + /// \brief Construct a Bool8Type. + Bool8Type() : ExtensionType(int8()) {} + + std::string extension_name() const override { return "arrow.bool8"; } + std::string ToString(bool show_metadata = false) const override; + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + + /// Create a Bool8Array from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + static Result> Make(); +}; + +/// \brief Return a Bool8Type instance. 
+ARROW_EXPORT std::shared_ptr bool8(); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/bool8_test.cc b/cpp/src/arrow/extension/bool8_test.cc new file mode 100644 index 0000000000000..eabcfcf62d32c --- /dev/null +++ b/cpp/src/arrow/extension/bool8_test.cc @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/bool8.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { + +TEST(Bool8Type, Basics) { + auto type = internal::checked_pointer_cast(extension::bool8()); + auto type2 = internal::checked_pointer_cast(extension::bool8()); + ASSERT_EQ("arrow.bool8", type->extension_name()); + ASSERT_EQ(*type, *type); + ASSERT_NE(*arrow::null(), *type); + ASSERT_EQ(*type, *type2); + ASSERT_EQ(*arrow::int8(), *type->storage_type()); + ASSERT_EQ("", type->Serialize()); + ASSERT_EQ("extension", type->ToString(false)); +} + +TEST(Bool8Type, CreateFromArray) { + auto type = internal::checked_pointer_cast(extension::bool8()); + auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]"); + auto array = ExtensionType::WrapArray(type, storage); + ASSERT_EQ(5, array->length()); + ASSERT_EQ(1, array->null_count()); +} + +TEST(Bool8Type, Deserialize) { + auto type = internal::checked_pointer_cast(extension::bool8()); + ASSERT_OK_AND_ASSIGN(auto deserialized, type->Deserialize(type->storage_type(), "")); + ASSERT_EQ(*type, *deserialized); + ASSERT_NOT_OK(type->Deserialize(type->storage_type(), "must be empty")); + ASSERT_EQ(*type, *deserialized); + ASSERT_NOT_OK(type->Deserialize(uint8(), "")); + ASSERT_EQ(*type, *deserialized); +} + +TEST(Bool8Type, MetadataRoundTrip) { + auto type = internal::checked_pointer_cast(extension::bool8()); + std::string serialized = type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + type->Deserialize(type->storage_type(), serialized)); + ASSERT_EQ(*type, *deserialized); +} + +TEST(Bool8Type, BatchRoundTrip) { + auto type = internal::checked_pointer_cast(extension::bool8()); + + auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]"); + auto array = ExtensionType::WrapArray(type, storage); + auto batch = + RecordBatch::Make(schema({field("field", type)}), array->length(), {array}); + + std::shared_ptr written; + { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + 
ASSERT_OK(batch_reader->ReadNext(&written)); + } + + ASSERT_EQ(*batch->schema(), *written->schema()); + ASSERT_BATCHES_EQUAL(*batch, *written); +} + +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index cf8dda7a85df4..685018f7de7b8 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -28,6 +28,7 @@ #include "arrow/chunked_array.h" #include "arrow/config.h" #ifdef ARROW_JSON +#include "arrow/extension/bool8.h" #include "arrow/extension/fixed_shape_tensor.h" #endif #include "arrow/status.h" @@ -146,10 +147,12 @@ static void CreateGlobalRegistry() { #ifdef ARROW_JSON // Register canonical extension types - auto ext_type = + auto fst_ext_type = checked_pointer_cast(extension::fixed_shape_tensor(int64(), {})); + ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type)); - ARROW_CHECK_OK(g_registry->RegisterType(ext_type)); + auto bool8_ext_type = checked_pointer_cast(extension::bool8()); + ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type)); #endif } diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index aa7bab9f97e05..807bcdc315036 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -174,6 +174,7 @@ def print_entry(label, value): run_end_encoded, fixed_shape_tensor, opaque, + bool8, field, type_for_alias, DataType, DictionaryType, StructType, @@ -184,7 +185,7 @@ def print_entry(label, value): FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, RunEndEncodedType, FixedShapeTensorType, OpaqueType, - PyExtensionType, UnknownExtensionType, + Bool8Type, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, KeyValueMetadata, @@ -218,7 +219,7 @@ def print_entry(label, value): MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, - scalar, NA, _NULL as NULL, Scalar, + Bool8Array, scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, @@ -235,7 +236,7 @@ def print_entry(label, value): FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, RunEndEncodedScalar, ExtensionScalar, - FixedShapeTensorScalar, OpaqueScalar) + FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 6c40a21db96ca..4c3eb93232634 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1581,7 +1581,7 @@ cdef class Array(_PandasConvertible): def to_numpy(self, zero_copy_only=True, writable=False): """ - Return a NumPy view or copy of this array (experimental). + Return a NumPy view or copy of this array. By default, tries to return a view of this array. This is only supported for primitive arrays with the same memory layout as NumPy @@ -4476,6 +4476,118 @@ cdef class OpaqueArray(ExtensionArray): """ +cdef class Bool8Array(ExtensionArray): + """ + Concrete class for bool8 extension arrays. 
+ + Examples + -------- + Define the extension type for an bool8 array + + >>> import pyarrow as pa + >>> bool8_type = pa.bool8() + + Create an extension array + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> pa.ExtensionArray.from_storage(bool8_type, storage) + + [ + -1, + 0, + 1, + 2, + null + ] + """ + + def to_numpy(self, zero_copy_only=True, writable=False): + """ + Return a NumPy bool view or copy of this array. + + By default, tries to return a view of this array. This is only + supported for arrays without any nulls. + + Parameters + ---------- + zero_copy_only : bool, default True + If True, an exception will be raised if the conversion to a numpy + array would require copying the underlying data (e.g. in presence + of nulls). + writable : bool, default False + For numpy arrays created with zero copy (view on the Arrow data), + the resulting array is not writable (Arrow data is immutable). + By setting this to True, a copy of the array is made to ensure + it is writable. + + Returns + ------- + array : numpy.ndarray + """ + if not writable: + try: + return self.storage.to_numpy().view(np.bool_) + except ArrowInvalid as e: + if zero_copy_only: + raise e + + return _pc().not_equal(self.storage, 0).to_numpy(zero_copy_only=zero_copy_only, writable=writable) + + @staticmethod + def from_storage(Int8Array storage): + """ + Construct Bool8Array from Int8Array storage. + + Parameters + ---------- + storage : Int8Array + The underlying storage for the result array. + + Returns + ------- + bool8_array : Bool8Array + """ + return ExtensionArray.from_storage(bool8(), storage) + + @staticmethod + def from_numpy(obj): + """ + Convert numpy array to a bool8 extension array without making a copy. + The input array must be 1-dimensional, with either bool_ or int8 dtype. 
+ + Parameters + ---------- + obj : numpy.ndarray + + Returns + ------- + bool8_array : Bool8Array + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> arr = np.array([True, False, True], dtype=np.bool_) + >>> pa.Bool8Array.from_numpy(arr) + + [ + 1, + 0, + 1 + ] + """ + + if obj.ndim != 1: + raise ValueError(f"Cannot convert {obj.ndim}-D array to bool8 array") + + if obj.dtype not in [np.bool_, np.int8]: + raise TypeError(f"Array dtype {obj.dtype} incompatible with bool8 storage") + + storage_arr = array(obj.view(np.int8), type=int8()) + return Bool8Array.from_storage(storage_arr) + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9b008d150f1f1..a54a1db292f70 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2895,6 +2895,15 @@ cdef extern from "arrow/extension/opaque.h" namespace "arrow::extension" nogil: pass +cdef extern from "arrow/extension/bool8.h" namespace "arrow::extension" nogil: + cdef cppclass CBool8Type" arrow::extension::Bool8Type"(CExtensionType): + + @staticmethod + CResult[shared_ptr[CDataType]] Make() + + cdef cppclass CBool8Array" arrow::extension::Bool8Array"(CExtensionArray): + pass + cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: cdef enum CCompressionType" arrow::Compression::type": CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 2cb302d20a8ac..e3625c1815274 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -214,6 +214,9 @@ cdef class FixedShapeTensorType(BaseExtensionType): cdef: const CFixedShapeTensorType* tensor_ext_type +cdef class Bool8Type(BaseExtensionType): + cdef: + const CBool8Type* bool8_ext_type cdef class OpaqueType(BaseExtensionType): cdef: diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 2f9fc1c554209..19a26bd6c683d 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -126,6 +126,8 @@ cdef api object pyarrow_wrap_data_type( out = FixedShapeTensorType.__new__(FixedShapeTensorType) elif ext_type.extension_name() == b"arrow.opaque": out = OpaqueType.__new__(OpaqueType) + elif ext_type.extension_name() == b"arrow.bool8": + out = Bool8Type.__new__(Bool8Type) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 12a99c2aece63..72ae2aee5f8b3 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1091,6 +1091,18 @@ cdef class OpaqueScalar(ExtensionScalar): """ +cdef class Bool8Scalar(ExtensionScalar): + """ + Concrete class for bool8 extension scalar. + """ + + def as_py(self): + """ + Return this scalar as a Python object. 
+ """ + py_val = super().as_py() + return None if py_val is None else py_val != 0 + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, @@ -1199,6 +1211,11 @@ def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): type = ensure_type(type, allow_none=True) pool = maybe_unbox_memory_pool(memory_pool) + extension_type = None + if type is not None and type.id == _Type_EXTENSION: + extension_type = type + type = type.storage_type + if _is_array_like(value): value = get_values(value, &is_pandas_object) @@ -1223,4 +1240,8 @@ def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): # retrieve the scalar from the first position scalar = GetResultValue(array.get().GetScalar(0)) - return Scalar.wrap(scalar) + result = Scalar.wrap(scalar) + + if extension_type is not None: + result = ExtensionScalar.from_storage(extension_type, result) + return result diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 58c54189f223e..b04ee85ec99ad 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1707,3 +1707,155 @@ def test_opaque_type(pickle_module, storage_type, storage): # cast extension type -> storage type inner = arr.cast(storage_type) assert inner == storage + + +def test_bool8_type(pickle_module): + bool8_type = pa.bool8() + storage_type = pa.int8() + assert bool8_type.extension_name == "arrow.bool8" + assert bool8_type.storage_type == storage_type + assert str(bool8_type) == "extension" + + assert bool8_type == bool8_type + assert bool8_type == pa.bool8() + assert bool8_type != storage_type + + # Pickle roundtrip + result = pickle_module.loads(pickle_module.dumps(bool8_type)) + assert result == bool8_type + + # IPC roundtrip + storage = pa.array([-1, 0, 1, 2, None], storage_type) + arr = pa.ExtensionArray.from_storage(bool8_type, storage) + assert isinstance(arr, pa.Bool8Array) + + # extension is registered by default + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) + + assert batch.column(0).type.extension_name == "arrow.bool8" + assert isinstance(batch.column(0), pa.Bool8Array) + + # cast storage -> extension type + result = storage.cast(bool8_type) + assert result == arr + + # cast extension type -> storage type + inner = arr.cast(storage_type) + assert inner == storage + + +def test_bool8_to_bool_conversion(): + bool_arr = pa.array([True, False, True, True, None], pa.bool_()) + bool8_arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + + # cast extension type -> arrow boolean type + assert bool8_arr.cast(pa.bool_()) == bool_arr + + # cast arrow boolean type -> extension type, expecting canonical values + canonical_storage = pa.array([1, 0, 1, 1, None], pa.int8()) + canonical_bool8_arr = pa.ExtensionArray.from_storage(pa.bool8(), canonical_storage) + assert bool_arr.cast(pa.bool8()) == canonical_bool8_arr + + +def test_bool8_to_numpy_conversion(): + arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + + # cannot zero-copy with nulls + with pytest.raises( + pa.ArrowInvalid, + match="Needed to copy 1 chunks with 1 nulls, but zero_copy_only was True", + ): + arr.to_numpy() + + # nullable conversion possible with a copy, but dest dtype is object + assert np.array_equal( + arr.to_numpy(zero_copy_only=False), + np.array([True, False, True, True, None], dtype=np.object_), + 
) + + # zero-copy possible with non-null array + np_arr_no_nulls = np.array([True, False, True, True], dtype=np.bool_) + arr_no_nulls = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2], pa.int8()), + ) + + arr_to_np = arr_no_nulls.to_numpy() + assert np.array_equal(arr_to_np, np_arr_no_nulls) + + # same underlying buffer + assert arr_to_np.ctypes.data == arr_no_nulls.buffers()[1].address + + # if the user requests a writable array, a copy should be performed + arr_to_np_writable = arr_no_nulls.to_numpy(zero_copy_only=False, writable=True) + assert np.array_equal(arr_to_np_writable, np_arr_no_nulls) + + # different underlying buffer + assert arr_to_np_writable.ctypes.data != arr_no_nulls.buffers()[1].address + + +def test_bool8_from_numpy_conversion(): + np_arr_no_nulls = np.array([True, False, True, True], dtype=np.bool_) + canonical_bool8_arr_no_nulls = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([1, 0, 1, 1], pa.int8()), + ) + + arr_from_np = pa.Bool8Array.from_numpy(np_arr_no_nulls) + assert arr_from_np == canonical_bool8_arr_no_nulls + + # same underlying buffer + assert arr_from_np.buffers()[1].address == np_arr_no_nulls.ctypes.data + + # conversion only valid for 1-D arrays + with pytest.raises( + ValueError, + match="Cannot convert 2-D array to bool8 array", + ): + pa.Bool8Array.from_numpy( + np.array([[True, False], [False, True]], dtype=np.bool_), + ) + + with pytest.raises( + ValueError, + match="Cannot convert 0-D array to bool8 array", + ): + pa.Bool8Array.from_numpy(np.bool_()) + + # must use compatible storage type + with pytest.raises( + TypeError, + match="Array dtype float64 incompatible with bool8 storage", + ): + pa.Bool8Array.from_numpy(np.array([1, 2, 3], dtype=np.float64)) + + +def test_bool8_scalar(): + assert pa.ExtensionScalar.from_storage(pa.bool8(), -1).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), 0).as_py() is False + assert pa.ExtensionScalar.from_storage(pa.bool8(), 1).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), 2).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), None).as_py() is None + + arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + assert arr[0].as_py() is True + assert arr[1].as_py() is False + assert arr[2].as_py() is True + assert arr[3].as_py() is True + assert arr[4].as_py() is None + + assert pa.scalar(-1, type=pa.bool8()).as_py() is True + assert pa.scalar(0, type=pa.bool8()).as_py() is False + assert pa.scalar(1, type=pa.bool8()).as_py() is True + assert pa.scalar(2, type=pa.bool8()).as_py() is True + assert pa.scalar(None, type=pa.bool8()).as_py() is None diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index 9a55a38177fc8..5d3471c7c35db 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -250,6 +250,9 @@ def test_set_timezone_db_path_non_windows(): pa.OpaqueArray, pa.OpaqueScalar, pa.OpaqueType, + pa.Bool8Array, + pa.Bool8Scalar, + pa.Bool8Type, ]) def test_extension_type_constructor_errors(klass): # ARROW-2638: prevent calling extension class constructors directly diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index dcd2b61c33411..563782f0c2643 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1837,6 +1837,37 @@ cdef class FixedShapeTensorType(BaseExtensionType): return FixedShapeTensorScalar +cdef class Bool8Type(BaseExtensionType): + """ + Concrete class for bool8 extension 
type. + + Bool8 is an alternate representation for boolean + arrays using 8 bits instead of 1 bit per value. The underlying + storage type is int8. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> pa.bool8() + Bool8Type(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.bool8_ext_type = type.get() + + def __arrow_ext_class__(self): + return Bool8Array + + def __reduce__(self): + return bool8, () + + def __arrow_ext_scalar_class__(self): + return Bool8Scalar + + cdef class OpaqueType(BaseExtensionType): """ Concrete class for opaque extension type. @@ -5278,6 +5309,49 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N return out +def bool8(): + """ + Create instance of bool8 extension type. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> type = pa.bool8() + >>> type + Bool8Type(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(int8) + + Create a table with a bool8 array: + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[-1,0,1,2,null]] + + Returns + ------- + type : Bool8Type + """ + + cdef Bool8Type out = Bool8Type.__new__(Bool8Type) + + c_type = GetResultValue(CBool8Type.Make()) + + out.init(c_type) + + return out + + def opaque(DataType storage_type, str type_name not None, str vendor_name not None): """ Create instance of opaque extension type. From 27c22389579dd773d9701f5d3c743bbfca3bdb8e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:38:12 +0900 Subject: [PATCH 04/32] MINOR: [Java] Bump org.codehaus.mojo:exec-maven-plugin from 3.3.0 to 3.4.1 in /java (#43692) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [org.codehaus.mojo:exec-maven-plugin](https://github.com/mojohaus/exec-maven-plugin) from 3.3.0 to 3.4.1.
Release notes (sourced from org.codehaus.mojo:exec-maven-plugin's releases):

3.4.1: bug fixes, dependency updates, maintenance, and build changes.

3.4.0: new features and improvements
  • Allow <includePluginDependencies> to be specified for the exec:exec goal (#432) @sebthom
plus bug fixes, dependency updates, maintenance, and build changes.

Commits
  • 7b0be2c [maven-release-plugin] prepare release 3.4.1
  • 5ac4f80 Environment variable Path should be used as case-insensitive
  • cfb3a9f Use Maven4 enabled with GH Action
  • d0ded48 Use shared release drafter GH Action
  • 4c22954 Bump org.codehaus.mojo:mojo-parent from 84 to 85
  • a8c4f94 fix: NPE because declared MavenSession field hides field of superclass
  • a2b735f Remove redundant spotless configuration
  • 8e0e83c [maven-release-plugin] prepare for next development iteration
  • 6c4996f [maven-release-plugin] prepare release 3.4.0
  • c7ad671 Remove Log4j 1.2.x from ITs
  • Additional commits viewable in compare view

Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 1524dc3257997..0f3e5760f2b82 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -504,7 +504,7 @@ under the License. org.codehaus.mojo exec-maven-plugin - 3.3.0 + 3.4.1 org.codehaus.mojo From 4af1e491df7ac22217656668b65c3e8d55f5b5ab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:56:44 +0900 Subject: [PATCH 05/32] MINOR: [Java] Bump io.grpc:grpc-bom from 1.65.0 to 1.66.0 in /java (#43657) Bumps [io.grpc:grpc-bom](https://github.com/grpc/grpc-java) from 1.65.0 to 1.66.0.
Release notes (sourced from io.grpc:grpc-bom's releases):

v1.65.1: What's Changed
  • netty: Restore old behavior of NettyAdaptiveCumulator, but avoid using that class if Netty is on version 4.1.111 or later

Commits
  • cf78406 Bump version to 1.66.0
  • 33af0a7 Update README etc to reference 1.66.0
  • 19c9b99 xds: XdsClient should unsubscribe on last resource (#11264)
  • 752a045 Revert "Start 1.67.0 development cycle (#11416)" (#11428)
  • ef09d94 Revert "Introduce onResult2 in NameResolver Listener2 that returns Status (#1...
  • c37fb18 Start 1.67.0 development cycle
  • 9ba2f9d Introduce onResult2 in NameResolver Listener2 that returns Status (#11313)
  • 786523d xds: WRR rr_fallback should trigger with one endpoint weight
  • b108ed3 api: Give instruments a toString() including their name
  • eb4cdf7 Update MAINTAINERS.md (#11241)
  • Additional commits viewable in compare view

Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 0f3e5760f2b82..a73453df68fd2 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -97,7 +97,7 @@ under the License. 2.0.13 33.2.1-jre 4.1.112.Final - 1.65.0 + 1.66.0 3.25.4 2.17.2 3.4.0 From 9fc03015463a8f1cb616b088342b104fbc767a0c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 21 Aug 2024 09:22:53 +0200 Subject: [PATCH 06/32] GH-43069: [Python] Use Py_IsFinalizing from pythoncapi_compat.h (#43767) ### Rationale for this change https://github.com/apache/arrow/pull/43540 already vendored `pythoncapi_compat.h`, so closing https://github.com/apache/arrow/issues/43069 by using this as well for `Py_IsFinalizing` (which was added in https://github.com/apache/arrow/pull/42034, and for which we opened that follow-up issue to use `pythoncapi_compat.h` instead) Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/src/arrow/python/udf.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 2c1e97c3ea03d..74f16899c47eb 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -24,14 +24,11 @@ #include "arrow/compute/kernel.h" #include "arrow/compute/row/grouper.h" #include "arrow/python/common.h" +#include "arrow/python/vendored/pythoncapi_compat.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -// Py_IsFinalizing added in Python 3.13.0a4 -#if PY_VERSION_HEX < 0x030D00A4 -#define Py_IsFinalizing() _Py_IsFinalizing() -#endif namespace arrow { using compute::ExecSpan; using compute::Grouper; From e1e7c501019ac26c896d61fa0c129eee83da9b55 Mon Sep 17 00:00:00 2001 From: Oliver Layer Date: Wed, 21 Aug 2024 13:22:57 +0200 Subject: [PATCH 07/32] GH-40036: [C++] Azure file system write buffering & async writes (#43096) ### Rationale for this change See #40036. ### What changes are included in this PR? Write buffering and async writes (similar to what the S3 file system does) in the `ObjectAppendStream` for the Azure file system. With write buffering and async writes, the input scenario creation runtime in the tests (which uses the `ObjectAppendStream` against Azurite) decreased from ~25s (see [here](https://github.com/apache/arrow/issues/40036)) to ~800ms: ``` [ RUN ] TestAzuriteFileSystem.OpenInputFileMixedReadVsReadAt [ OK ] TestAzuriteFileSystem.OpenInputFileMixedReadVsReadAt (787 ms) ``` ### Are these changes tested? Added some tests with background writes enabled and disabled (some were taken from the S3 tests). Everything changed should be covered. ### Are there any user-facing changes? `AzureOptions` now allows for `background_writes` to be set (default: true). No breaking changes. ### Notes - The code in `DoWrite` is very similar to [the code in the S3 FS](https://github.com/apache/arrow/blob/edfa343eeca008513f0300924380e1b187cc976b/cpp/src/arrow/filesystem/s3fs.cc#L1753). Maybe this could be unified? I didn't see this in the scope of the PR though. 
* GitHub Issue: #40036 Lead-authored-by: Oliver Layer Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/filesystem/azurefs.cc | 276 ++++++++++++++++++++--- cpp/src/arrow/filesystem/azurefs.h | 3 + cpp/src/arrow/filesystem/azurefs_test.cc | 264 ++++++++++++++++++---- 3 files changed, 471 insertions(+), 72 deletions(-) diff --git a/cpp/src/arrow/filesystem/azurefs.cc b/cpp/src/arrow/filesystem/azurefs.cc index 9b3c0c0c1d703..0bad856339729 100644 --- a/cpp/src/arrow/filesystem/azurefs.cc +++ b/cpp/src/arrow/filesystem/azurefs.cc @@ -22,6 +22,7 @@ #include "arrow/filesystem/azurefs.h" #include "arrow/filesystem/azurefs_internal.h" +#include "arrow/io/memory.h" // idenfity.hpp triggers -Wattributes warnings cause -Werror builds to fail, // so disable it for this file with pragmas. @@ -144,6 +145,9 @@ Status AzureOptions::ExtractFromUriQuery(const Uri& uri) { blob_storage_scheme = "http"; dfs_storage_scheme = "http"; } + } else if (kv.first == "background_writes") { + ARROW_ASSIGN_OR_RAISE(background_writes, + ::arrow::internal::ParseBoolean(kv.second)); } else { return Status::Invalid( "Unexpected query parameter in Azure Blob File System URI: '", kv.first, "'"); @@ -937,8 +941,8 @@ Status CommitBlockList(std::shared_ptr block_bl const std::vector& block_ids, const Blobs::CommitBlockListOptions& options) { try { - // CommitBlockList puts all block_ids in the latest element. That means in the case of - // overlapping block_ids the newly staged block ids will always replace the + // CommitBlockList puts all block_ids in the latest element. That means in the case + // of overlapping block_ids the newly staged block ids will always replace the // previously committed blocks. // https://learn.microsoft.com/en-us/rest/api/storageservices/put-block-list?tabs=microsoft-entra-id#request-body block_blob_client->CommitBlockList(block_ids, options); @@ -950,7 +954,34 @@ Status CommitBlockList(std::shared_ptr block_bl return Status::OK(); } +Status StageBlock(Blobs::BlockBlobClient* block_blob_client, const std::string& id, + Core::IO::MemoryBodyStream& content) { + try { + block_blob_client->StageBlock(id, content); + } catch (const Storage::StorageException& exception) { + return ExceptionToStatus( + exception, "StageBlock failed for '", block_blob_client->GetUrl(), + "' new_block_id: '", id, + "'. Staging new blocks is fundamental to streaming writes to blob storage."); + } + + return Status::OK(); +} + +/// Writes will be buffered up to this size (in bytes) before actually uploading them. +static constexpr int64_t kBlockUploadSizeBytes = 10 * 1024 * 1024; +/// The maximum size of a block in Azure Blob (as per docs). +static constexpr int64_t kMaxBlockSizeBytes = 4UL * 1024 * 1024 * 1024; + +/// This output stream, similar to other arrow OutputStreams, is not thread-safe. 
class ObjectAppendStream final : public io::OutputStream { + private: + struct UploadState; + + std::shared_ptr Self() { + return std::dynamic_pointer_cast(shared_from_this()); + } + public: ObjectAppendStream(std::shared_ptr block_blob_client, const io::IOContext& io_context, const AzureLocation& location, @@ -958,7 +989,8 @@ class ObjectAppendStream final : public io::OutputStream { const AzureOptions& options) : block_blob_client_(std::move(block_blob_client)), io_context_(io_context), - location_(location) { + location_(location), + background_writes_(options.background_writes) { if (metadata && metadata->size() != 0) { ArrowMetadataToCommitBlockListOptions(metadata, commit_block_list_options_); } else if (options.default_metadata && options.default_metadata->size() != 0) { @@ -1008,10 +1040,13 @@ class ObjectAppendStream final : public io::OutputStream { content_length_ = 0; } } + + upload_state_ = std::make_shared(); + if (content_length_ > 0) { ARROW_ASSIGN_OR_RAISE(auto block_list, GetBlockList(block_blob_client_)); for (auto block : block_list.CommittedBlocks) { - block_ids_.push_back(block.Name); + upload_state_->block_ids.push_back(block.Name); } } initialised_ = true; @@ -1031,12 +1066,34 @@ class ObjectAppendStream final : public io::OutputStream { if (closed_) { return Status::OK(); } + + if (current_block_) { + // Upload remaining buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + RETURN_NOT_OK(Flush()); block_blob_client_ = nullptr; closed_ = true; return Status::OK(); } + Future<> CloseAsync() override { + if (closed_) { + return Status::OK(); + } + + if (current_block_) { + // Upload remaining buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + + return FlushAsync().Then([self = Self()]() { + self->block_blob_client_ = nullptr; + self->closed_ = true; + }); + } + bool closed() const override { return closed_; } Status CheckClosed(const char* action) const { @@ -1052,11 +1109,11 @@ class ObjectAppendStream final : public io::OutputStream { } Status Write(const std::shared_ptr& buffer) override { - return DoAppend(buffer->data(), buffer->size(), buffer); + return DoWrite(buffer->data(), buffer->size(), buffer); } Status Write(const void* data, int64_t nbytes) override { - return DoAppend(data, nbytes); + return DoWrite(data, nbytes); } Status Flush() override { @@ -1066,20 +1123,111 @@ class ObjectAppendStream final : public io::OutputStream { // flush. This also avoids some unhandled errors when flushing in the destructor. return Status::OK(); } - return CommitBlockList(block_blob_client_, block_ids_, commit_block_list_options_); + + Future<> pending_blocks_completed; + { + std::unique_lock lock(upload_state_->mutex); + pending_blocks_completed = upload_state_->pending_blocks_completed; + } + + RETURN_NOT_OK(pending_blocks_completed.status()); + std::unique_lock lock(upload_state_->mutex); + return CommitBlockList(block_blob_client_, upload_state_->block_ids, + commit_block_list_options_); } - private: - Status DoAppend(const void* data, int64_t nbytes, - std::shared_ptr owned_buffer = nullptr) { - RETURN_NOT_OK(CheckClosed("append")); - auto append_data = reinterpret_cast(data); - Core::IO::MemoryBodyStream block_content(append_data, nbytes); - if (block_content.Length() == 0) { + Future<> FlushAsync() { + RETURN_NOT_OK(CheckClosed("flush async")); + if (!initialised_) { + // If the stream has not been successfully initialized then there is nothing to + // flush. This also avoids some unhandled errors when flushing in the destructor. 
return Status::OK(); } - const auto n_block_ids = block_ids_.size(); + Future<> pending_blocks_completed; + { + std::unique_lock lock(upload_state_->mutex); + pending_blocks_completed = upload_state_->pending_blocks_completed; + } + + return pending_blocks_completed.Then([self = Self()] { + std::unique_lock lock(self->upload_state_->mutex); + return CommitBlockList(self->block_blob_client_, self->upload_state_->block_ids, + self->commit_block_list_options_); + }); + } + + private: + Status AppendCurrentBlock() { + ARROW_ASSIGN_OR_RAISE(auto buf, current_block_->Finish()); + current_block_.reset(); + current_block_size_ = 0; + return AppendBlock(buf); + } + + Status DoWrite(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + if (closed_) { + return Status::Invalid("Operation on closed stream"); + } + + const auto* data_ptr = reinterpret_cast(data); + auto advance_ptr = [this, &data_ptr, &nbytes](const int64_t offset) { + data_ptr += offset; + nbytes -= offset; + pos_ += offset; + content_length_ += offset; + }; + + // Handle case where we have some bytes buffered from prior calls. + if (current_block_size_ > 0) { + // Try to fill current buffer + const int64_t to_copy = + std::min(nbytes, kBlockUploadSizeBytes - current_block_size_); + RETURN_NOT_OK(current_block_->Write(data_ptr, to_copy)); + current_block_size_ += to_copy; + advance_ptr(to_copy); + + // If buffer isn't full, break + if (current_block_size_ < kBlockUploadSizeBytes) { + return Status::OK(); + } + + // Upload current buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + + // We can upload chunks without copying them into a buffer + while (nbytes >= kBlockUploadSizeBytes) { + const auto upload_size = std::min(nbytes, kMaxBlockSizeBytes); + RETURN_NOT_OK(AppendBlock(data_ptr, upload_size)); + advance_ptr(upload_size); + } + + // Buffer remaining bytes + if (nbytes > 0) { + current_block_size_ = nbytes; + + if (current_block_ == nullptr) { + ARROW_ASSIGN_OR_RAISE( + current_block_, + io::BufferOutputStream::Create(kBlockUploadSizeBytes, io_context_.pool())); + } else { + // Re-use the allocation from before. + RETURN_NOT_OK(current_block_->Reset(kBlockUploadSizeBytes, io_context_.pool())); + } + + RETURN_NOT_OK(current_block_->Write(data_ptr, current_block_size_)); + pos_ += current_block_size_; + content_length_ += current_block_size_; + } + + return Status::OK(); + } + + std::string CreateBlock() { + std::unique_lock lock(upload_state_->mutex); + const auto n_block_ids = upload_state_->block_ids.size(); // New block ID must always be distinct from the existing block IDs. Otherwise we // will accidentally replace the content of existing blocks, causing corruption. @@ -1093,36 +1241,106 @@ class ObjectAppendStream final : public io::OutputStream { new_block_id.insert(0, required_padding_digits, '0'); // There is a small risk when appending to a blob created by another client that // `new_block_id` may overlapping with an existing block id. Adding the `-arrow` - // suffix significantly reduces the risk, but does not 100% eliminate it. For example - // if the blob was previously created with one block, with id `00001-arrow` then the - // next block we append will conflict with that, and cause corruption. + // suffix significantly reduces the risk, but does not 100% eliminate it. For + // example if the blob was previously created with one block, with id `00001-arrow` + // then the next block we append will conflict with that, and cause corruption. 
new_block_id += "-arrow"; new_block_id = Core::Convert::Base64Encode( std::vector(new_block_id.begin(), new_block_id.end())); - try { - block_blob_client_->StageBlock(new_block_id, block_content); - } catch (const Storage::StorageException& exception) { - return ExceptionToStatus( - exception, "StageBlock failed for '", block_blob_client_->GetUrl(), - "' new_block_id: '", new_block_id, - "'. Staging new blocks is fundamental to streaming writes to blob storage."); + upload_state_->block_ids.push_back(new_block_id); + + // We only use the future if we have background writes enabled. Without background + // writes the future is initialized as finished and not mutated any more. + if (background_writes_ && upload_state_->blocks_in_progress++ == 0) { + upload_state_->pending_blocks_completed = Future<>::Make(); } - block_ids_.push_back(new_block_id); - pos_ += nbytes; - content_length_ += nbytes; + + return new_block_id; + } + + Status AppendBlock(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + RETURN_NOT_OK(CheckClosed("append")); + + if (nbytes == 0) { + return Status::OK(); + } + + const auto block_id = CreateBlock(); + + if (background_writes_) { + if (owned_buffer == nullptr) { + ARROW_ASSIGN_OR_RAISE(owned_buffer, AllocateBuffer(nbytes, io_context_.pool())); + memcpy(owned_buffer->mutable_data(), data, nbytes); + } else { + DCHECK_EQ(data, owned_buffer->data()); + DCHECK_EQ(nbytes, owned_buffer->size()); + } + + // The closure keeps the buffer and the upload state alive + auto deferred = [owned_buffer, block_id, block_blob_client = block_blob_client_, + state = upload_state_]() mutable -> Status { + Core::IO::MemoryBodyStream block_content(owned_buffer->data(), + owned_buffer->size()); + + auto status = StageBlock(block_blob_client.get(), block_id, block_content); + HandleUploadOutcome(state, status); + return Status::OK(); + }; + RETURN_NOT_OK(io::internal::SubmitIO(io_context_, std::move(deferred))); + } else { + auto append_data = reinterpret_cast(data); + Core::IO::MemoryBodyStream block_content(append_data, nbytes); + + RETURN_NOT_OK(StageBlock(block_blob_client_.get(), block_id, block_content)); + } + return Status::OK(); } + Status AppendBlock(std::shared_ptr buffer) { + return AppendBlock(buffer->data(), buffer->size(), buffer); + } + + static void HandleUploadOutcome(const std::shared_ptr& state, + const Status& status) { + std::unique_lock lock(state->mutex); + if (!status.ok()) { + state->status &= status; + } + // Notify completion + if (--state->blocks_in_progress == 0) { + auto fut = state->pending_blocks_completed; + lock.unlock(); + fut.MarkFinished(state->status); + } + } + std::shared_ptr block_blob_client_; const io::IOContext io_context_; const AzureLocation location_; + const bool background_writes_; int64_t content_length_ = kNoSize; + std::shared_ptr current_block_; + int64_t current_block_size_ = 0; + bool closed_ = false; bool initialised_ = false; int64_t pos_ = 0; - std::vector block_ids_; + + // This struct is kept alive through background writes to avoid problems + // in the completion handler. 
+ struct UploadState { + std::mutex mutex; + std::vector block_ids; + int64_t blocks_in_progress = 0; + Status status; + Future<> pending_blocks_completed = Future<>::MakeFinished(Status::OK()); + }; + std::shared_ptr upload_state_; + Blobs::CommitBlockListOptions commit_block_list_options_; }; diff --git a/cpp/src/arrow/filesystem/azurefs.h b/cpp/src/arrow/filesystem/azurefs.h index 072b061eeb2a9..ebbe00c4ee784 100644 --- a/cpp/src/arrow/filesystem/azurefs.h +++ b/cpp/src/arrow/filesystem/azurefs.h @@ -112,6 +112,9 @@ struct ARROW_EXPORT AzureOptions { /// This will be ignored if non-empty metadata is passed to OpenOutputStream. std::shared_ptr default_metadata; + /// Whether OutputStream writes will be issued in the background, without blocking. + bool background_writes = true; + private: enum class CredentialKind { kDefault, diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 5ff241b17ff58..9d437d1f83aac 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -39,6 +39,7 @@ #include #include #include +#include #include #include @@ -53,6 +54,7 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/future.h" #include "arrow/util/io_util.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" @@ -566,6 +568,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, default_options.dfs_storage_scheme); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kDefault); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriDfsStorage() { @@ -582,6 +585,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, default_options.dfs_storage_scheme); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kDefault); ASSERT_EQ(path, "file_system/dir/file"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriAbfs() { @@ -597,6 +601,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "https"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriAbfss() { @@ -612,6 +617,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "https"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriEnableTls() { @@ -628,6 +634,17 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "http"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); + } + + void TestFromUriDisableBackgroundWrites() { + std::string path; + ASSERT_OK_AND_ASSIGN(auto options, + AzureOptions::FromUri( + "abfs://account:password@127.0.0.1:10000/container/dir/blob?" 
+ "background_writes=false", + &path)); + ASSERT_EQ(options.background_writes, false); } void TestFromUriCredentialDefault() { @@ -773,6 +790,9 @@ TEST_F(TestAzureOptions, FromUriDfsStorage) { TestFromUriDfsStorage(); } TEST_F(TestAzureOptions, FromUriAbfs) { TestFromUriAbfs(); } TEST_F(TestAzureOptions, FromUriAbfss) { TestFromUriAbfss(); } TEST_F(TestAzureOptions, FromUriEnableTls) { TestFromUriEnableTls(); } +TEST_F(TestAzureOptions, FromUriDisableBackgroundWrites) { + TestFromUriDisableBackgroundWrites(); +} TEST_F(TestAzureOptions, FromUriCredentialDefault) { TestFromUriCredentialDefault(); } TEST_F(TestAzureOptions, FromUriCredentialAnonymous) { TestFromUriCredentialAnonymous(); } TEST_F(TestAzureOptions, FromUriCredentialStorageSharedKey) { @@ -929,8 +949,9 @@ class TestAzureFileSystem : public ::testing::Test { void UploadLines(const std::vector& lines, const std::string& path, int total_size) { ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - const auto all_lines = std::accumulate(lines.begin(), lines.end(), std::string("")); - ASSERT_OK(output->Write(all_lines)); + for (auto const& line : lines) { + ASSERT_OK(output->Write(line.data(), line.size())); + } ASSERT_OK(output->Close()); } @@ -1474,6 +1495,162 @@ class TestAzureFileSystem : public ::testing::Test { arrow::fs::AssertFileInfo(fs(), data.Path("dir/file0"), FileType::File); } + void AssertObjectContents(AzureFileSystem* fs, std::string_view path, + std::string_view expected) { + ASSERT_OK_AND_ASSIGN(auto input, fs->OpenInputStream(std::string{path})); + std::string contents; + std::shared_ptr buffer; + do { + ASSERT_OK_AND_ASSIGN(buffer, input->Read(128 * 1024)); + contents.append(buffer->ToString()); + } while (buffer->size() != 0); + + EXPECT_EQ(expected, contents); + } + + void TestOpenOutputStreamSmall() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + const std::string_view expected(PreexistingData::kLoremIpsum); + ASSERT_OK(output->Write(expected)); + ASSERT_OK(output->Close()); + + // Verify we can read the object back. 
+ AssertObjectContents(fs.get(), path, expected); + } + + void TestOpenOutputStreamLarge() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + + // Upload 5 MB, 4 MB and 2 MB and a very small write to test varying sizes + std::vector sizes{5 * 1024 * 1024, 4 * 1024 * 1024, 2 * 1024 * 1024, + 2000}; + + std::vector buffers{}; + char current_char = 'A'; + for (const auto size : sizes) { + buffers.emplace_back(size, current_char++); + } + + auto expected_size = std::int64_t{0}; + for (size_t i = 0; i < buffers.size(); ++i) { + ASSERT_OK(output->Write(buffers[i])); + expected_size += sizes[i]; + ASSERT_EQ(expected_size, output->Tell()); + } + ASSERT_OK(output->Close()); + + AssertObjectContents(fs.get(), path, + buffers[0] + buffers[1] + buffers[2] + buffers[3]); + } + + void TestOpenOutputStreamLargeSingleWrite() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + + constexpr std::int64_t size{12 * 1024 * 1024}; + const std::string large_string(size, 'X'); + + ASSERT_OK(output->Write(large_string)); + ASSERT_EQ(size, output->Tell()); + ASSERT_OK(output->Close()); + + AssertObjectContents(fs.get(), path, large_string); + } + + void TestOpenOutputStreamCloseAsync() { +#if defined(ADDRESS_SANITIZER) || defined(ARROW_VALGRIND) + // This false positive leak is similar to the one pinpointed in the + // have_false_positive_memory_leak_with_generator() comments above, + // though the stack trace is different. It happens when a block list + // is committed from a background thread. + // + // clang-format off + // Direct leak of 968 byte(s) in 1 object(s) allocated from: + // #0 calloc + // #1 (/lib/x86_64-linux-gnu/libxml2.so.2+0xe25a4) + // #2 __xmlDefaultBufferSize + // #3 xmlBufferCreate + // #4 Azure::Storage::_internal::XmlWriter::XmlWriter() + // #5 Azure::Storage::Blobs::_detail::BlockBlobClient::CommitBlockList + // #6 Azure::Storage::Blobs::BlockBlobClient::CommitBlockList + // #7 arrow::fs::(anonymous namespace)::CommitBlockList + // #8 arrow::fs::(anonymous namespace)::ObjectAppendStream::FlushAsync()::'lambda' + // clang-format on + // + // TODO perhaps remove this skip once we can rely on + // https://github.com/Azure/azure-sdk-for-cpp/pull/5767 + // + // Also note that ClickHouse has a workaround for a similar issue: + // https://github.com/ClickHouse/ClickHouse/pull/45796 + if (options_.background_writes) { + GTEST_SKIP() << "False positive memory leak in libxml2 with CloseAsync"; + } +#endif + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + constexpr auto payload = PreexistingData::kLoremIpsum; + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + auto close_fut = stream->CloseAsync(); + + ASSERT_OK(close_fut.MoveResult()); + + AssertObjectContents(fs.get(), path, payload); + } + + void TestOpenOutputStreamCloseAsyncDestructor() { +#if defined(ADDRESS_SANITIZER) || defined(ARROW_VALGRIND) + // See above.
+ if (options_.background_writes) { + GTEST_SKIP() << "False positive memory leak in libxml2 with CloseAsync"; + } +#endif + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + constexpr auto payload = PreexistingData::kLoremIpsum; + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + // Destructor implicitly closes stream and completes the upload. + // Testing it doesn't matter whether flush is triggered asynchronously + // after CloseAsync or synchronously after stream.reset() since we're just + // checking that the future keeps the stream alive until completion + // rather than segfaulting on a dangling stream. + auto close_fut = stream->CloseAsync(); + stream.reset(); + ASSERT_OK(close_fut.MoveResult()); + + AssertObjectContents(fs.get(), path, payload); + } + + void TestOpenOutputStreamDestructor() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + constexpr auto* payload = "new data"; + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + // Destructor implicitly closes stream and completes the multipart upload. + stream.reset(); + + AssertObjectContents(fs.get(), path, payload); + } + private: using StringMatcher = ::testing::PolymorphicMatcher<::testing::internal::HasSubstrMatcher>; @@ -2704,53 +2881,27 @@ TEST_F(TestAzuriteFileSystem, WriteMetadataHttpHeaders) { ASSERT_EQ("text/plain", content_type); } -TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmall) { - auto data = SetUpPreexistingData(); - const auto path = data.ContainerPath("test-write-object"); - ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - const std::string_view expected(PreexistingData::kLoremIpsum); - ASSERT_OK(output->Write(expected)); - ASSERT_OK(output->Close()); - - // Verify we can read the object back. - ASSERT_OK_AND_ASSIGN(auto input, fs()->OpenInputStream(path)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmallNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamSmall(); +} - std::array inbuf{}; - ASSERT_OK_AND_ASSIGN(auto size, input->Read(inbuf.size(), inbuf.data())); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmall) { TestOpenOutputStreamSmall(); } - EXPECT_EQ(expected, std::string_view(inbuf.data(), size)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamLarge(); } -TEST_F(TestAzuriteFileSystem, OpenOutputStreamLarge) { - auto data = SetUpPreexistingData(); - const auto path = data.ContainerPath("test-write-object"); - ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - std::array sizes{257 * 1024, 258 * 1024, 259 * 1024}; - std::array buffers{ - std::string(sizes[0], 'A'), - std::string(sizes[1], 'B'), - std::string(sizes[2], 'C'), - }; - auto expected = std::int64_t{0}; - for (auto i = 0; i != 3; ++i) { - ASSERT_OK(output->Write(buffers[i])); - expected += sizes[i]; - ASSERT_EQ(expected, output->Tell()); - } - ASSERT_OK(output->Close()); - - // Verify we can read the object back. 
- ASSERT_OK_AND_ASSIGN(auto input, fs()->OpenInputStream(path)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLarge) { TestOpenOutputStreamLarge(); } - std::string contents; - std::shared_ptr buffer; - do { - ASSERT_OK_AND_ASSIGN(buffer, input->Read(128 * 1024)); - ASSERT_TRUE(buffer); - contents.append(buffer->ToString()); - } while (buffer->size() != 0); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeSingleWriteNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamLargeSingleWrite(); +} - EXPECT_EQ(contents, buffers[0] + buffers[1] + buffers[2]); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeSingleWrite) { + TestOpenOutputStreamLargeSingleWrite(); } TEST_F(TestAzuriteFileSystem, OpenOutputStreamTruncatesExistingFile) { @@ -2820,6 +2971,33 @@ TEST_F(TestAzuriteFileSystem, OpenOutputStreamClosed) { ASSERT_RAISES(Invalid, output->Tell()); } +TEST_F(TestAzuriteFileSystem, OpenOutputStreamCloseAsync) { + TestOpenOutputStreamCloseAsync(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamCloseAsyncNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamCloseAsync(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamAsyncDestructor) { + TestOpenOutputStreamCloseAsyncDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamAsyncDestructorNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamCloseAsyncDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamDestructor) { + TestOpenOutputStreamDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamDestructorNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamDestructor(); +} + TEST_F(TestAzuriteFileSystem, OpenOutputStreamUri) { auto data = SetUpPreexistingData(); const auto path = data.ContainerPath("open-output-stream-uri.txt"); From ffee537d88ab6d26614e2a1e85d4d18152695020 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 21 Aug 2024 14:18:45 +0200 Subject: [PATCH 08/32] GH-42222: [Python] Add bindings for CopyTo on RecordBatch and Array classes (#42223) ### Rationale for this change We have added bindings for the Device and MemoryManager classes (https://github.com/apache/arrow/issues/41126), and as a next step we can expose the functionality to copy a full Array or RecordBatch to a specific memory manager. ### What changes are included in this PR? This adds a `copy_to` method on pyarrow Array and RecordBatch. ### Are these changes tested? Yes * GitHub Issue: #42222 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/array.pxi | 36 ++++++++++++ python/pyarrow/device.pxi | 6 ++ python/pyarrow/includes/libarrow.pxd | 4 ++ python/pyarrow/lib.pxd | 4 ++ python/pyarrow/table.pxi | 35 ++++++++++++ python/pyarrow/tests/test_cuda.py | 82 +++++++++++----------------- python/pyarrow/tests/test_device.py | 26 +++++++++ 7 files changed, 143 insertions(+), 50 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 4c3eb93232634..77d6c9c06d2de 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1702,6 +1702,42 @@ cdef class Array(_PandasConvertible): _append_array_buffers(self.sp_array.get().data().get(), res) return res + def copy_to(self, destination): + """ + Construct a copy of the array with all buffers on destination + device. + + This method recursively copies the array's buffers and those of its + children onto the destination MemoryManager device and returns the + new Array. 
+ + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + Array + """ + cdef: + shared_ptr[CArray] c_array + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_array = GetResultValue(self.ap.CopyTo(c_memory_manager)) + return pyarrow_wrap_array(c_array) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/device.pxi b/python/pyarrow/device.pxi index 6e6034752085a..26256de62093e 100644 --- a/python/pyarrow/device.pxi +++ b/python/pyarrow/device.pxi @@ -64,6 +64,9 @@ cdef class Device(_Weakrefable): self.init(device) return self + cdef inline shared_ptr[CDevice] unwrap(self) nogil: + return self.device + def __eq__(self, other): if not isinstance(other, Device): return False @@ -130,6 +133,9 @@ cdef class MemoryManager(_Weakrefable): self.init(mm) return self + cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil: + return self.memory_manager + def __repr__(self): return "".format( frombytes(self.memory_manager.get().device().get().ToString()) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a54a1db292f70..6f510cfc0c06c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -234,7 +234,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CStatus Validate() const CStatus ValidateFull() const CResult[shared_ptr[CArray]] View(const shared_ptr[CDataType]& type) + CDeviceAllocationType device_type() + CResult[shared_ptr[CArray]] CopyTo(const shared_ptr[CMemoryManager]& to) const shared_ptr[CArray] MakeArray(const shared_ptr[CArrayData]& data) CResult[shared_ptr[CArray]] MakeArrayOfNull( @@ -1027,6 +1029,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: shared_ptr[CRecordBatch] Slice(int64_t offset) shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length) + CResult[shared_ptr[CRecordBatch]] CopyTo(const shared_ptr[CMemoryManager]& to) const + CResult[shared_ptr[CTensor]] ToTensor(c_bool null_to_nan, c_bool row_major, CMemoryPool* pool) const diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index e3625c1815274..a7c3b496a0045 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -542,6 +542,8 @@ cdef class Device(_Weakrefable): @staticmethod cdef wrap(const shared_ptr[CDevice]& device) + cdef inline shared_ptr[CDevice] unwrap(self) nogil + cdef class MemoryManager(_Weakrefable): cdef: @@ -552,6 +554,8 @@ cdef class MemoryManager(_Weakrefable): @staticmethod cdef wrap(const shared_ptr[CMemoryManager]& mm) + cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil + cdef class Buffer(_Weakrefable): cdef: diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 8f7c44e55dc8d..6d34c71c9df40 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3569,6 +3569,41 @@ cdef class RecordBatch(_Tabular): row_major, pool)) return pyarrow_wrap_tensor(c_tensor) + def copy_to(self, destination): + """ + Copy the entire RecordBatch to destination device. 
+ + This copies each column of the record batch to create + a new record batch where all underlying buffers for the columns have + been copied to the destination MemoryManager. + + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] c_batch + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_batch = GetResultValue(self.batch.CopyTo(c_memory_manager)) + return pyarrow_wrap_batch(c_batch) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/tests/test_cuda.py b/python/pyarrow/tests/test_cuda.py index 36b97a6206463..d55be651b1571 100644 --- a/python/pyarrow/tests/test_cuda.py +++ b/python/pyarrow/tests/test_cuda.py @@ -827,21 +827,29 @@ def test_IPC(size): assert p.exitcode == 0 -def _arr_copy_to_host(carr): - # TODO replace below with copy to device when exposed in python - buffers = [] - for cbuf in carr.buffers(): - if cbuf is None: - buffers.append(None) - else: - buf = global_context.foreign_buffer( - cbuf.address, cbuf.size, cbuf - ).copy_to_host() - buffers.append(buf) - - child = pa.Array.from_buffers(carr.type.value_type, 3, buffers[2:]) - new = pa.Array.from_buffers(carr.type, 2, buffers[:2], children=[child]) - return new +def test_copy_to(): + _, buf = make_random_buffer(size=10, target='device') + mm_cuda = buf.memory_manager + + for dest in [mm_cuda, mm_cuda.device]: + arr = pa.array([0, 1, 2]) + arr_cuda = arr.copy_to(dest) + assert not arr_cuda.buffers()[1].is_cpu + assert arr_cuda.buffers()[1].device_type == pa.DeviceAllocationType.CUDA + assert arr_cuda.buffers()[1].device == mm_cuda.device + + arr_roundtrip = arr_cuda.copy_to(pa.default_cpu_memory_manager()) + assert arr_roundtrip.equals(arr) + + batch = pa.record_batch({"col": arr}) + batch_cuda = batch.copy_to(dest) + buf_cuda = batch_cuda["col"].buffers()[1] + assert not buf_cuda.is_cpu + assert buf_cuda.device_type == pa.DeviceAllocationType.CUDA + assert buf_cuda.device == mm_cuda.device + + batch_roundtrip = batch_cuda.copy_to(pa.default_cpu_memory_manager()) + assert batch_roundtrip.equals(batch) def test_device_interface_array(): @@ -856,19 +864,10 @@ def test_device_interface_array(): typ = pa.list_(pa.int32()) arr = pa.array([[1], [2, 42]], type=typ) - # TODO replace below with copy to device when exposed in python - cbuffers = [] - for buf in arr.buffers(): - if buf is None: - cbuffers.append(None) - else: - cbuf = global_context.new_buffer(buf.size) - cbuf.copy_from_host(buf, position=0, nbytes=buf.size) - cbuffers.append(cbuf) - - carr = pa.Array.from_buffers(typ, 2, cbuffers[:2], children=[ - pa.Array.from_buffers(typ.value_type, 3, cbuffers[2:]) - ]) + # copy to device + _, buf = make_random_buffer(size=10, target='device') + mm_cuda = buf.memory_manager + carr = arr.copy_to(mm_cuda) # Type is known up front carr._export_to_c_device(ptr_array) @@ -882,7 +881,7 @@ def test_device_interface_array(): del carr carr_new = pa.Array._import_from_c_device(ptr_array, typ) assert carr_new.type == pa.list_(pa.int32()) - 
arr_new = _arr_copy_to_host(carr_new) + arr_new = carr_new.copy_to(pa.default_cpu_memory_manager()) assert arr_new.equals(arr) del carr_new @@ -891,15 +890,13 @@ def test_device_interface_array(): pa.Array._import_from_c_device(ptr_array, typ) # Schema is exported and imported at the same time - carr = pa.Array.from_buffers(typ, 2, cbuffers[:2], children=[ - pa.Array.from_buffers(typ.value_type, 3, cbuffers[2:]) - ]) + carr = arr.copy_to(mm_cuda) carr._export_to_c_device(ptr_array, ptr_schema) # Delete and recreate C++ objects from exported pointers del carr carr_new = pa.Array._import_from_c_device(ptr_array, ptr_schema) assert carr_new.type == pa.list_(pa.int32()) - arr_new = _arr_copy_to_host(carr_new) + arr_new = carr_new.copy_to(pa.default_cpu_memory_manager()) assert arr_new.equals(arr) del carr_new @@ -908,21 +905,6 @@ def test_device_interface_array(): pa.Array._import_from_c_device(ptr_array, ptr_schema) -def _batch_copy_to_host(cbatch): - # TODO replace below with copy to device when exposed in python - arrs = [] - for col in cbatch.columns: - buffers = [ - global_context.foreign_buffer(buf.address, buf.size, buf).copy_to_host() - if buf is not None else None - for buf in col.buffers() - ] - new = pa.Array.from_buffers(col.type, len(col), buffers) - arrs.append(new) - - return pa.RecordBatch.from_arrays(arrs, schema=cbatch.schema) - - def test_device_interface_batch_array(): cffi = pytest.importorskip("pyarrow.cffi") ffi = cffi.ffi @@ -949,7 +931,7 @@ def test_device_interface_batch_array(): del cbatch cbatch_new = pa.RecordBatch._import_from_c_device(ptr_array, schema) assert cbatch_new.schema == schema - batch_new = _batch_copy_to_host(cbatch_new) + batch_new = cbatch_new.copy_to(pa.default_cpu_memory_manager()) assert batch_new.equals(batch) del cbatch_new @@ -964,7 +946,7 @@ def test_device_interface_batch_array(): del cbatch cbatch_new = pa.RecordBatch._import_from_c_device(ptr_array, ptr_schema) assert cbatch_new.schema == schema - batch_new = _batch_copy_to_host(cbatch_new) + batch_new = cbatch_new.copy_to(pa.default_cpu_memory_manager()) assert batch_new.equals(batch) del cbatch_new diff --git a/python/pyarrow/tests/test_device.py b/python/pyarrow/tests/test_device.py index 6bdb015be1a95..dc1a51e6d0092 100644 --- a/python/pyarrow/tests/test_device.py +++ b/python/pyarrow/tests/test_device.py @@ -17,6 +17,8 @@ import pyarrow as pa +import pytest + def test_device_memory_manager(): mm = pa.default_cpu_memory_manager() @@ -41,3 +43,27 @@ def test_buffer_device(): assert buf.device.is_cpu assert buf.device == pa.default_cpu_memory_manager().device assert buf.memory_manager.is_cpu + + +def test_copy_to(): + mm = pa.default_cpu_memory_manager() + + arr = pa.array([0, 1, 2]) + batch = pa.record_batch({"col": arr}) + + for dest in [mm, mm.device]: + arr_copied = arr.copy_to(dest) + assert arr_copied.equals(arr) + assert arr_copied.buffers()[1].device == mm.device + assert arr_copied.buffers()[1].address != arr.buffers()[1].address + + batch_copied = batch.copy_to(dest) + assert batch_copied.equals(batch) + assert batch_copied["col"].buffers()[1].device == mm.device + assert batch_copied["col"].buffers()[1].address != arr.buffers()[1].address + + with pytest.raises(TypeError, match="Argument 'destination' has incorrect type"): + arr.copy_to(mm.device.device_type) + + with pytest.raises(TypeError, match="Argument 'destination' has incorrect type"): + batch.copy_to(mm.device.device_type) From f9911ee2ffc62fa946b2e1198bcdd13a757181fe Mon Sep 17 00:00:00 2001 From: Antoine Pitrou 
Date: Wed, 21 Aug 2024 14:37:47 +0200 Subject: [PATCH 09/32] GH-43776: [C++] Add chunked Take benchmarks with a small selection factor (#43772) This should help exercise the performance of chunked Take implementation on more use cases. * GitHub Issue: #43776 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- .../kernels/vector_selection_benchmark.cc | 91 ++++++++++++++++--- 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc index c2a27dfe43488..75affd32560f0 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc @@ -17,6 +17,7 @@ #include "benchmark/benchmark.h" +#include #include #include @@ -42,6 +43,9 @@ struct FilterParams { const double filter_null_proportion; }; +constexpr double kDefaultTakeSelectionFactor = 1.0; +constexpr double kSmallTakeSelectionFactor = 0.05; + std::vector g_data_sizes = {kL2Size}; // The benchmark state parameter references this vector of cases. Test high and @@ -104,14 +108,21 @@ struct TakeBenchmark { benchmark::State& state; RegressionArgs args; random::RandomArrayGenerator rand; + double selection_factor; bool indices_have_nulls; bool monotonic_indices = false; TakeBenchmark(benchmark::State& state, bool indices_have_nulls, bool monotonic_indices = false) + : TakeBenchmark(state, /*selection_factor=*/kDefaultTakeSelectionFactor, + indices_have_nulls, monotonic_indices) {} + + TakeBenchmark(benchmark::State& state, double selection_factor, bool indices_have_nulls, + bool monotonic_indices = false) : state(state), args(state, /*size_is_bytes=*/false), rand(kSeed), + selection_factor(selection_factor), indices_have_nulls(indices_have_nulls), monotonic_indices(monotonic_indices) {} @@ -185,10 +196,10 @@ struct TakeBenchmark { } void Bench(const std::shared_ptr& values) { - double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0; - auto indices = - rand.Int32(values->length(), 0, static_cast(values->length() - 1), - indices_null_proportion); + const double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0; + const int64_t num_indices = static_cast(selection_factor * values->length()); + auto indices = rand.Int32(num_indices, 0, static_cast(values->length() - 1), + indices_null_proportion); if (monotonic_indices) { auto arg_sorter = *SortIndices(*indices); @@ -198,14 +209,15 @@ struct TakeBenchmark { for (auto _ : state) { ABORT_NOT_OK(Take(values, indices).status()); } - state.SetItemsProcessed(state.iterations() * values->length()); + state.SetItemsProcessed(state.iterations() * num_indices); + state.counters["selection_factor"] = selection_factor; } void BenchChunked(const std::shared_ptr& values, bool chunk_indices_too) { double indices_null_proportion = indices_have_nulls ? 
args.null_proportion : 0; - auto indices = - rand.Int32(values->length(), 0, static_cast(values->length() - 1), - indices_null_proportion); + const int64_t num_indices = static_cast(selection_factor * values->length()); + auto indices = rand.Int32(num_indices, 0, static_cast(values->length() - 1), + indices_null_proportion); if (monotonic_indices) { auto arg_sorter = *SortIndices(*indices); @@ -213,14 +225,26 @@ struct TakeBenchmark { } std::shared_ptr chunked_indices; if (chunk_indices_too) { + // Here we choose for indices chunks to have roughly the same length + // as values chunks, but there may be fewer of them if selection_factor < 1.0. + // The alternative is to have the same number of chunks, but with a potentially + // much smaller (and unrealistic) length. std::vector> indices_chunks; + // Make sure there are at least two chunks of indices + const auto max_chunk_length = indices->length() / 2 + 1; int64_t offset = 0; for (int i = 0; i < values->num_chunks(); ++i) { - auto chunk = indices->Slice(offset, values->chunk(i)->length()); + const auto chunk_length = std::min(max_chunk_length, values->chunk(i)->length()); + auto chunk = indices->Slice(offset, chunk_length); indices_chunks.push_back(std::move(chunk)); - offset += values->chunk(i)->length(); + offset += chunk_length; + if (offset >= indices->length()) { + break; + } } chunked_indices = std::make_shared(std::move(indices_chunks)); + ARROW_CHECK_EQ(chunked_indices->length(), num_indices); + ARROW_CHECK_GT(chunked_indices->num_chunks(), 1); } if (chunk_indices_too) { @@ -232,7 +256,8 @@ struct TakeBenchmark { ABORT_NOT_OK(Take(values, indices).status()); } } - state.SetItemsProcessed(state.iterations() * values->length()); + state.SetItemsProcessed(state.iterations() * num_indices); + state.counters["selection_factor"] = selection_factor; } }; @@ -432,12 +457,25 @@ static void TakeChunkedChunkedInt64RandomIndicesWithNulls(benchmark::State& stat .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedInt64FewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedInt64MonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) .ChunkedInt64( /*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedInt64FewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedFSBRandomIndicesNoNulls(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false) .ChunkedFSB(/*num_chunks=*/100, /*chunk_indices_too=*/true); @@ -463,11 +501,23 @@ static void TakeChunkedChunkedStringRandomIndicesWithNulls(benchmark::State& sta .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedStringFewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedStringMonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true)
.ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedStringFewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedFlatInt64RandomIndicesNoNulls(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false) .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); @@ -478,12 +528,25 @@ static void TakeChunkedFlatInt64RandomIndicesWithNulls(benchmark::State& state) .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); } +static void TakeChunkedFlatInt64FewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); +} + static void TakeChunkedFlatInt64MonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) .ChunkedInt64( /*num_chunks=*/100, /*chunk_indices_too=*/false); } +static void TakeChunkedFlatInt64FewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/false); +} + void FilterSetArgs(benchmark::internal::Benchmark* bench) { for (int64_t size : g_data_sizes) { for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { @@ -560,18 +623,24 @@ BENCHMARK(TakeStringMonotonicIndices)->Apply(TakeSetArgs); // Chunked values x Chunked indices BENCHMARK(TakeChunkedChunkedInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64FewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64FewMonotonicIndices)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedFSBRandomIndicesNoNulls)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedFSBRandomIndicesWithNulls)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedFSBMonotonicIndices)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedStringRandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedStringRandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringFewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedStringMonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringFewMonotonicIndices)->Apply(TakeSetArgs); // Chunked values x Flat indices BENCHMARK(TakeChunkedFlatInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedFlatInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64FewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedFlatInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64FewMonotonicIndices)->Apply(TakeSetArgs); } // namespace compute } // namespace arrow From f078942ce2df68de8f48c3b4233132133601ca53 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 22 Aug 2024 02:59:04 +1200 Subject: [PATCH 10/32] GH-43141: [C++][Parquet] Replace use of int with int32_t in the internal Parquet encryption APIs (#43413) ### Rationale for this change See #43141 ### What changes are included in this PR? 
* Changes uses of int to int32_t in the Encryptor and Decryptor APIs, except where interfacing with OpenSSL. * Also change RandBytes to use size_t instead of int and check for overflow. * Check the return code from OpenSSL's Rand_bytes in case there is a failure generating random bytes ### Are these changes tested? Yes, this doesn't change behaviour and is covered by existing tests. ### Are there any user-facing changes? No * GitHub Issue: #43141 Authored-by: Adam Reeve Signed-off-by: Antoine Pitrou --- cpp/src/parquet/column_reader.cc | 4 +- cpp/src/parquet/encryption/crypto_factory.cc | 6 +- .../parquet/encryption/encryption_internal.cc | 251 ++++++++++-------- .../parquet/encryption/encryption_internal.h | 46 ++-- .../encryption/encryption_internal_nossl.cc | 47 ++-- .../encryption/encryption_internal_test.cc | 22 +- .../parquet/encryption/file_key_wrapper.cc | 4 +- .../encryption/internal_file_decryptor.cc | 12 +- .../encryption/internal_file_decryptor.h | 8 +- .../encryption/internal_file_encryptor.cc | 10 +- .../encryption/internal_file_encryptor.h | 6 +- .../encryption/key_toolkit_internal.cc | 2 +- cpp/src/parquet/metadata.cc | 6 +- cpp/src/parquet/thrift_internal.h | 2 +- 14 files changed, 233 insertions(+), 193 deletions(-) diff --git a/cpp/src/parquet/column_reader.cc b/cpp/src/parquet/column_reader.cc index 05ee6a16c5448..60a8a2176b0a8 100644 --- a/cpp/src/parquet/column_reader.cc +++ b/cpp/src/parquet/column_reader.cc @@ -468,8 +468,8 @@ std::shared_ptr SerializedPageReader::NextPage() { // Advance the stream offset PARQUET_THROW_NOT_OK(stream_->Advance(header_size)); - int compressed_len = current_page_header_.compressed_page_size; - int uncompressed_len = current_page_header_.uncompressed_page_size; + int32_t compressed_len = current_page_header_.compressed_page_size; + int32_t uncompressed_len = current_page_header_.uncompressed_page_size; if (compressed_len < 0 || uncompressed_len < 0) { throw ParquetException("Invalid page header"); } diff --git a/cpp/src/parquet/encryption/crypto_factory.cc b/cpp/src/parquet/encryption/crypto_factory.cc index 72506bdc014b6..56069d559771c 100644 --- a/cpp/src/parquet/encryption/crypto_factory.cc +++ b/cpp/src/parquet/encryption/crypto_factory.cc @@ -72,8 +72,7 @@ std::shared_ptr CryptoFactory::GetFileEncryptionProper int dek_length = dek_length_bits / 8; std::string footer_key(dek_length, '\0'); - RandBytes(reinterpret_cast(&footer_key[0]), - static_cast(footer_key.size())); + RandBytes(reinterpret_cast(footer_key.data()), footer_key.size()); std::string footer_key_metadata = key_wrapper.GetEncryptionKeyMetadata(footer_key, footer_key_id, true); @@ -148,8 +147,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties } std::string column_key(dek_length, '\0'); - RandBytes(reinterpret_cast(&column_key[0]), - static_cast(column_key.size())); + RandBytes(reinterpret_cast(column_key.data()), column_key.size()); std::string column_key_key_metadata = key_wrapper->GetEncryptionKeyMetadata(column_key, column_key_id, false); diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index 99d1707f4a8d4..a0d9367b619c6 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -18,6 +18,7 @@ #include "parquet/encryption/encryption_internal.h" #include +#include #include #include @@ -36,10 +37,10 @@ using parquet::ParquetException; namespace parquet::encryption { -constexpr int kGcmMode = 0; -constexpr int 
kCtrMode = 1; -constexpr int kCtrIvLength = 16; -constexpr int kBufferSizeLength = 4; +constexpr int32_t kGcmMode = 0; +constexpr int32_t kCtrMode = 1; +constexpr int32_t kCtrIvLength = 16; +constexpr int32_t kBufferSizeLength = 4; #define ENCRYPT_INIT(CTX, ALG) \ if (1 != EVP_EncryptInit_ex(CTX, ALG, nullptr, nullptr, nullptr)) { \ @@ -53,17 +54,17 @@ constexpr int kBufferSizeLength = 4; class AesEncryptor::AesEncryptorImpl { public: - explicit AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesEncryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length); ~AesEncryptorImpl() { WipeOut(); } - int Encrypt(span plaintext, span key, - span aad, span ciphertext); + int32_t Encrypt(span plaintext, span key, + span aad, span ciphertext); - int SignedFooterEncrypt(span footer, span key, - span aad, span nonce, - span encrypted_footer); + int32_t SignedFooterEncrypt(span footer, span key, + span aad, span nonce, + span encrypted_footer); void WipeOut() { if (nullptr != ctx_) { EVP_CIPHER_CTX_free(ctx_); @@ -89,21 +90,22 @@ class AesEncryptor::AesEncryptorImpl { private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; + int32_t aes_mode_; + int32_t key_length_; + int32_t ciphertext_size_delta_; + int32_t length_buffer_length_; - int GcmEncrypt(span plaintext, span key, - span nonce, span aad, - span ciphertext); + int32_t GcmEncrypt(span plaintext, span key, + span nonce, span aad, + span ciphertext); - int CtrEncrypt(span plaintext, span key, - span nonce, span ciphertext); + int32_t CtrEncrypt(span plaintext, span key, + span nonce, span ciphertext); }; -AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { openssl::EnsureInitialized(); ctx_ = nullptr; @@ -151,11 +153,9 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int } } -int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(span footer, - span key, - span aad, - span nonce, - span encrypted_footer) { +int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( + span footer, span key, span aad, + span nonce, span encrypted_footer) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -176,10 +176,10 @@ int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(span foot return GcmEncrypt(footer, key, nonce, aad, encrypted_footer); } -int AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, - span key, - span aad, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, + span key, + span aad, + span ciphertext) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". 
Should be " << key_length_; @@ -205,13 +205,13 @@ int AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, return CtrEncrypt(plaintext, key, nonce, ciphertext); } -int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, - span key, - span nonce, - span aad, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, + span key, + span nonce, + span aad, + span ciphertext) { int len; - int ciphertext_len; + int32_t ciphertext_len; std::array tag{}; @@ -227,12 +227,22 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, } // Setting additional authenticated data + if (aad.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "AAD size " << aad.size() << " overflows int"; + throw ParquetException(ss.str()); + } if ((!aad.empty()) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad.data(), static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } // Encryption + if (plaintext.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Plaintext size " << plaintext.size() << " overflows int"; + throw ParquetException(ss.str()); + } if (1 != EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, &len, plaintext.data(), static_cast(plaintext.size()))) { @@ -256,7 +266,7 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, } // Copying the buffer size, nonce and tag to ciphertext - int buffer_size = kNonceLength + ciphertext_len + kGcmTagLength; + int32_t buffer_size = kNonceLength + ciphertext_len + kGcmTagLength; if (length_buffer_length_ > 0) { ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); @@ -271,12 +281,12 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, return length_buffer_length_ + buffer_size; } -int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, - span key, - span nonce, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, + span key, + span nonce, + span ciphertext) { int len; - int ciphertext_len; + int32_t ciphertext_len; if (nonce.size() != static_cast(kNonceLength)) { std::stringstream ss; @@ -298,6 +308,11 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, } // Encryption + if (plaintext.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Plaintext size " << plaintext.size() << " overflows int"; + throw ParquetException(ss.str()); + } if (1 != EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, &len, plaintext.data(), static_cast(plaintext.size()))) { @@ -316,7 +331,7 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, ciphertext_len += len; // Copying the buffer size and nonce to ciphertext - int buffer_size = kNonceLength + ciphertext_len; + int32_t buffer_size = kNonceLength + ciphertext_len; if (length_buffer_length_ > 0) { ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); @@ -331,9 +346,11 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(span footer, span key, - span aad, span nonce, - span encrypted_footer) { +int32_t AesEncryptor::SignedFooterEncrypt(span footer, + span key, + span aad, + span nonce, + span encrypted_footer) { return impl_->SignedFooterEncrypt(footer, key, aad, nonce, encrypted_footer); } @@ -343,25 +360,25 @@ int32_t 
AesEncryptor::CiphertextLength(int64_t plaintext_len) const { return impl_->CiphertextLength(plaintext_len); } -int AesEncryptor::Encrypt(span plaintext, span key, - span aad, span ciphertext) { +int32_t AesEncryptor::Encrypt(span plaintext, span key, + span aad, span ciphertext) { return impl_->Encrypt(plaintext, key, aad, ciphertext); } -AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length) : impl_{std::unique_ptr( new AesEncryptorImpl(alg_id, key_len, metadata, write_length))} {} class AesDecryptor::AesDecryptorImpl { public: - explicit AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesDecryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length); ~AesDecryptorImpl() { WipeOut(); } - int Decrypt(span ciphertext, span key, - span aad, span plaintext); + int32_t Decrypt(span ciphertext, span key, + span aad, span plaintext); void WipeOut() { if (nullptr != ctx_) { @@ -370,7 +387,7 @@ class AesDecryptor::AesDecryptorImpl { } } - [[nodiscard]] int PlaintextLength(int ciphertext_len) const { + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const { if (ciphertext_len < ciphertext_size_delta_) { std::stringstream ss; ss << "Ciphertext length " << ciphertext_len << " is invalid, expected at least " @@ -380,12 +397,13 @@ class AesDecryptor::AesDecryptorImpl { return ciphertext_len - ciphertext_size_delta_; } - [[nodiscard]] int CiphertextLength(int plaintext_len) const { + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const { if (plaintext_len < 0) { std::stringstream ss; ss << "Negative plaintext length " << plaintext_len; throw ParquetException(ss.str()); - } else if (plaintext_len > std::numeric_limits::max() - ciphertext_size_delta_) { + } else if (plaintext_len > + std::numeric_limits::max() - ciphertext_size_delta_) { std::stringstream ss; ss << "Plaintext length " << plaintext_len << " plus ciphertext size delta " << ciphertext_size_delta_ << " overflows int32"; @@ -396,24 +414,24 @@ class AesDecryptor::AesDecryptorImpl { private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; + int32_t aes_mode_; + int32_t key_length_; + int32_t ciphertext_size_delta_; + int32_t length_buffer_length_; /// Get the actual ciphertext length, inclusive of the length buffer length, /// and validate that the provided buffer size is large enough. 
- [[nodiscard]] int GetCiphertextLength(span ciphertext) const; + [[nodiscard]] int32_t GetCiphertextLength(span ciphertext) const; - int GcmDecrypt(span ciphertext, span key, - span aad, span plaintext); + int32_t GcmDecrypt(span ciphertext, span key, + span aad, span plaintext); - int CtrDecrypt(span ciphertext, span key, - span plaintext); + int32_t CtrDecrypt(span ciphertext, span key, + span plaintext); }; -int AesDecryptor::Decrypt(span ciphertext, span key, - span aad, span plaintext) { +int32_t AesDecryptor::Decrypt(span ciphertext, span key, + span aad, span plaintext) { return impl_->Decrypt(ciphertext, key, aad, plaintext); } @@ -421,8 +439,9 @@ void AesDecryptor::WipeOut() { impl_->WipeOut(); } AesDecryptor::~AesDecryptor() {} -AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, - bool metadata, bool contains_length) { +AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool contains_length) { openssl::EnsureInitialized(); ctx_ = nullptr; @@ -469,13 +488,14 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int } } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata) { return Make(alg_id, key_len, metadata, true /*write_length*/); } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) { std::stringstream ss; ss << "Crypto algorithm " << alg_id << " is not supported"; @@ -485,13 +505,13 @@ std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int return std::make_unique(alg_id, key_len, metadata, write_length); } -AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length) : impl_{std::unique_ptr( new AesDecryptorImpl(alg_id, key_len, metadata, contains_length))} {} std::shared_ptr AesDecryptor::Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors) { if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) { std::stringstream ss; @@ -506,15 +526,15 @@ std::shared_ptr AesDecryptor::Make( return decryptor; } -int AesDecryptor::PlaintextLength(int ciphertext_len) const { +int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const { return impl_->PlaintextLength(ciphertext_len); } -int AesDecryptor::CiphertextLength(int plaintext_len) const { +int32_t AesDecryptor::CiphertextLength(int32_t plaintext_len) const { return impl_->CiphertextLength(plaintext_len); } -int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( +int32_t AesDecryptor::AesDecryptorImpl::GetCiphertextLength( span ciphertext) const { if (length_buffer_length_ > 0) { // Note: length_buffer_length_ must be either 0 or kBufferSizeLength @@ -533,10 +553,11 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( (static_cast(ciphertext[0])); if (written_ciphertext_len > - static_cast(std::numeric_limits::max() - length_buffer_length_)) { + static_cast(std::numeric_limits::max() - + length_buffer_length_)) { 
std::stringstream ss; ss << "Written ciphertext length " << written_ciphertext_len - << " plus length buffer length " << length_buffer_length_ << " overflows int"; + << " plus length buffer length " << length_buffer_length_ << " overflows int32"; throw ParquetException(ss.str()); } else if (ciphertext.size() < static_cast(written_ciphertext_len) + length_buffer_length_) { @@ -548,28 +569,28 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( throw ParquetException(ss.str()); } - return static_cast(written_ciphertext_len) + length_buffer_length_; + return static_cast(written_ciphertext_len) + length_buffer_length_; } else { - if (ciphertext.size() > static_cast(std::numeric_limits::max())) { + if (ciphertext.size() > static_cast(std::numeric_limits::max())) { std::stringstream ss; - ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int"; + ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int32"; throw ParquetException(ss.str()); } - return static_cast(ciphertext.size()); + return static_cast(ciphertext.size()); } } -int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, - span key, - span aad, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, + span key, + span aad, + span plaintext) { int len; - int plaintext_len; + int32_t plaintext_len; std::array tag{}; std::array nonce{}; - int ciphertext_len = GetCiphertextLength(ciphertext); + int32_t ciphertext_len = GetCiphertextLength(ciphertext); if (plaintext.size() < static_cast(ciphertext_len) - ciphertext_size_delta_) { std::stringstream ss; @@ -597,16 +618,22 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, } // Setting additional authenticated data + if (aad.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "AAD size " << aad.size() << " overflows int"; + throw ParquetException(ss.str()); + } if ((!aad.empty()) && (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad.data(), static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } // Decryption - if (!EVP_DecryptUpdate( - ctx_, plaintext.data(), &len, - ciphertext.data() + length_buffer_length_ + kNonceLength, - ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength)) { + int decryption_length = + ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength; + if (!EVP_DecryptUpdate(ctx_, plaintext.data(), &len, + ciphertext.data() + length_buffer_length_ + kNonceLength, + decryption_length)) { throw ParquetException("Failed decryption update"); } @@ -626,15 +653,15 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, return plaintext_len; } -int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, - span key, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, + span key, + span plaintext) { int len; - int plaintext_len; + int32_t plaintext_len; std::array iv{}; - int ciphertext_len = GetCiphertextLength(ciphertext); + int32_t ciphertext_len = GetCiphertextLength(ciphertext); if (plaintext.size() < static_cast(ciphertext_len) - ciphertext_size_delta_) { std::stringstream ss; @@ -665,9 +692,10 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, } // Decryption + int decryption_length = ciphertext_len - length_buffer_length_ - kNonceLength; if (!EVP_DecryptUpdate(ctx_, plaintext.data(), &len, ciphertext.data() + length_buffer_length_ + kNonceLength, - ciphertext_len - length_buffer_length_ - kNonceLength)) { + decryption_length)) { throw 
ParquetException("Failed decryption update"); } @@ -682,10 +710,10 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, return plaintext_len; } -int AesDecryptor::AesDecryptorImpl::Decrypt(span ciphertext, - span key, - span aad, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::Decrypt(span ciphertext, + span key, + span aad, + span plaintext) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -758,9 +786,22 @@ void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) { std::memcpy(AAD->data() + AAD->length() - 2, page_ordinal_bytes.data(), 2); } -void RandBytes(unsigned char* buf, int num) { +void RandBytes(unsigned char* buf, size_t num) { + if (num > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Length " << num << " for RandBytes overflows int"; + throw ParquetException(ss.str()); + } openssl::EnsureInitialized(); - RAND_bytes(buf, num); + int status = RAND_bytes(buf, static_cast(num)); + if (status != 1) { + const auto error_code = ERR_get_error(); + char buffer[256]; + ERR_error_string_n(error_code, buffer, sizeof(buffer)); + std::stringstream ss; + ss << "Failed to generate random bytes: " << buffer; + throw ParquetException(ss.str()); + } } void EnsureBackendInitialized() { openssl::EnsureInitialized(); } diff --git a/cpp/src/parquet/encryption/encryption_internal.h b/cpp/src/parquet/encryption/encryption_internal.h index c874b137ad1ad..d79ff56ad49be 100644 --- a/cpp/src/parquet/encryption/encryption_internal.h +++ b/cpp/src/parquet/encryption/encryption_internal.h @@ -29,8 +29,8 @@ using parquet::ParquetCipher; namespace parquet::encryption { -constexpr int kGcmTagLength = 16; -constexpr int kNonceLength = 12; +constexpr int32_t kGcmTagLength = 16; +constexpr int32_t kNonceLength = 12; // Module types constexpr int8_t kFooter = 0; @@ -49,13 +49,13 @@ class PARQUET_EXPORT AesEncryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. /// If write_length is true, prepend ciphertext length to the ciphertext - explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length = true); - static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + static std::unique_ptr Make(ParquetCipher::type alg_id, int32_t key_len, bool metadata); - static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + static std::unique_ptr Make(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length); ~AesEncryptor(); @@ -65,17 +65,17 @@ class PARQUET_EXPORT AesEncryptor { /// Encrypts plaintext with the key and aad. Key length is passed only for validation. /// If different from value in constructor, exception will be thrown. - int Encrypt(::arrow::util::span plaintext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span ciphertext); + int32_t Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext); /// Encrypts plaintext footer, in order to compute footer signature (tag). 
- int SignedFooterEncrypt(::arrow::util::span footer, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span nonce, - ::arrow::util::span encrypted_footer); + int32_t SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer); void WipeOut(); @@ -90,7 +90,7 @@ class PARQUET_EXPORT AesDecryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. /// If contains_length is true, expect ciphertext length prepended to the ciphertext - explicit AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length = true); /// \brief Factory function to create an AesDecryptor @@ -102,26 +102,26 @@ class PARQUET_EXPORT AesDecryptor { /// out when decryption is finished /// \return shared pointer to a new AesDecryptor static std::shared_ptr Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors); ~AesDecryptor(); void WipeOut(); /// The size of the plaintext, for this cipher and the specified ciphertext length. - [[nodiscard]] int PlaintextLength(int ciphertext_len) const; + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const; /// The size of the ciphertext, for this cipher and the specified plaintext length. - [[nodiscard]] int CiphertextLength(int plaintext_len) const; + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const; /// Decrypts ciphertext with the key and aad. Key length is passed only for /// validation. If different from value in constructor, exception will be thrown. /// The caller is responsible for ensuring that the plaintext buffer is at least as /// large as PlaintextLength(ciphertext_len). - int Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span plaintext); + int32_t Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span plaintext); private: // PIMPL Idiom @@ -139,7 +139,7 @@ std::string CreateFooterAad(const std::string& aad_prefix_bytes); void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD); // Wraps OpenSSL RAND_bytes function -void RandBytes(unsigned char* buf, int num); +void RandBytes(unsigned char* buf, size_t num); // Ensure OpenSSL is initialized. 
// diff --git a/cpp/src/parquet/encryption/encryption_internal_nossl.cc b/cpp/src/parquet/encryption/encryption_internal_nossl.cc index 2cce83915d7e5..2a8162ed3964b 100644 --- a/cpp/src/parquet/encryption/encryption_internal_nossl.cc +++ b/cpp/src/parquet/encryption/encryption_internal_nossl.cc @@ -29,11 +29,11 @@ class AesEncryptor::AesEncryptorImpl {}; AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(::arrow::util::span footer, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span nonce, - ::arrow::util::span encrypted_footer) { +int32_t AesEncryptor::SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer) { ThrowOpenSSLRequiredException(); return -1; } @@ -45,25 +45,25 @@ int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const { return -1; } -int AesEncryptor::Encrypt(::arrow::util::span plaintext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span ciphertext) { +int32_t AesEncryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext) { ThrowOpenSSLRequiredException(); return -1; } -AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length) { ThrowOpenSSLRequiredException(); } class AesDecryptor::AesDecryptorImpl {}; -int AesDecryptor::Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span plaintext) { +int32_t AesDecryptor::Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span plaintext) { ThrowOpenSSLRequiredException(); return -1; } @@ -72,36 +72,37 @@ void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); } AesDecryptor::~AesDecryptor() {} -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata) { ThrowOpenSSLRequiredException(); return NULLPTR; } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { ThrowOpenSSLRequiredException(); return NULLPTR; } -AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length) { ThrowOpenSSLRequiredException(); } std::shared_ptr AesDecryptor::Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors) { ThrowOpenSSLRequiredException(); return NULLPTR; } -int AesDecryptor::PlaintextLength(int ciphertext_len) const { +int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const { ThrowOpenSSLRequiredException(); return -1; } -int AesDecryptor::CiphertextLength(int plaintext_len) const { +int32_t AesDecryptor::CiphertextLength(int32_t plaintext_len) const { ThrowOpenSSLRequiredException(); return -1; } @@ -122,7 +123,7 @@ void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) { ThrowOpenSSLRequiredException(); } -void RandBytes(unsigned char* buf, int num) { 
ThrowOpenSSLRequiredException(); } +void RandBytes(unsigned char* buf, size_t num) { ThrowOpenSSLRequiredException(); } void EnsureBackendInitialized() {} diff --git a/cpp/src/parquet/encryption/encryption_internal_test.cc b/cpp/src/parquet/encryption/encryption_internal_test.cc index 22e14663ea81f..bf6607e32877d 100644 --- a/cpp/src/parquet/encryption/encryption_internal_test.cc +++ b/cpp/src/parquet/encryption/encryption_internal_test.cc @@ -41,22 +41,22 @@ class TestAesEncryption : public ::testing::Test { encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), - str2span(aad_), ciphertext); + int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); ASSERT_EQ(ciphertext_length, expected_ciphertext_len); AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); - int plaintext_length = + int32_t plaintext_length = decryptor.Decrypt(ciphertext, str2span(key_), str2span(aad_), decrypted_text); std::string decrypted_text_str(decrypted_text.begin(), decrypted_text.end()); - ASSERT_EQ(plaintext_length, static_cast(plain_text_.size())); + ASSERT_EQ(plaintext_length, static_cast(plain_text_.size())); ASSERT_EQ(plaintext_length, expected_plaintext_length); ASSERT_EQ(decrypted_text_str, plain_text_); } @@ -68,10 +68,10 @@ class TestAesEncryption : public ::testing::Test { AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); // Create ciphertext of all zeros, so the ciphertext length will be read as zero - const int ciphertext_length = 100; + constexpr int32_t ciphertext_length = 100; std::vector ciphertext(ciphertext_length, '\0'); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); EXPECT_THROW( @@ -89,12 +89,12 @@ class TestAesEncryption : public ::testing::Test { encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), - str2span(aad_), ciphertext); + int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); ::arrow::util::span truncated_ciphertext(ciphertext.data(), @@ -105,7 +105,7 @@ class TestAesEncryption : public ::testing::Test { } private: - int key_length_ = 0; + int32_t key_length_ = 0; std::string key_; std::string aad_; std::string plain_text_; diff --git a/cpp/src/parquet/encryption/file_key_wrapper.cc b/cpp/src/parquet/encryption/file_key_wrapper.cc index 032ae45821a68..8ce563e60d752 100644 --- a/cpp/src/parquet/encryption/file_key_wrapper.cc +++ b/cpp/src/parquet/encryption/file_key_wrapper.cc @@ -112,10 +112,10 @@ std::string 
FileKeyWrapper::GetEncryptionKeyMetadata(const std::string& data_key KeyEncryptionKey FileKeyWrapper::CreateKeyEncryptionKey( const std::string& master_key_id) { std::string kek_bytes(kKeyEncryptionKeyLength, '\0'); - RandBytes(reinterpret_cast(&kek_bytes[0]), kKeyEncryptionKeyLength); + RandBytes(reinterpret_cast(kek_bytes.data()), kKeyEncryptionKeyLength); std::string kek_id(kKeyEncryptionKeyIdLength, '\0'); - RandBytes(reinterpret_cast(&kek_id[0]), kKeyEncryptionKeyIdLength); + RandBytes(reinterpret_cast(kek_id.data()), kKeyEncryptionKeyIdLength); // Encrypt KEK with Master key std::string encoded_wrapped_kek = kms_client_->WrapKey(kek_bytes, master_key_id); diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.cc b/cpp/src/parquet/encryption/internal_file_decryptor.cc index fae5ce1f7a809..53a2f8c02168b 100644 --- a/cpp/src/parquet/encryption/internal_file_decryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_decryptor.cc @@ -33,16 +33,16 @@ Decryptor::Decryptor(std::shared_ptr aes_decryptor, aad_(aad), pool_(pool) {} -int Decryptor::PlaintextLength(int ciphertext_len) const { +int32_t Decryptor::PlaintextLength(int32_t ciphertext_len) const { return aes_decryptor_->PlaintextLength(ciphertext_len); } -int Decryptor::CiphertextLength(int plaintext_len) const { +int32_t Decryptor::CiphertextLength(int32_t plaintext_len) const { return aes_decryptor_->CiphertextLength(plaintext_len); } -int Decryptor::Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span plaintext) { +int32_t Decryptor::Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span plaintext) { return aes_decryptor_->Decrypt(ciphertext, str2span(key_), str2span(aad_), plaintext); } @@ -143,7 +143,7 @@ std::shared_ptr InternalFileDecryptor::GetFooterDecryptor( // Create both data and metadata decryptors to avoid redundant retrieval of key // from the key_retriever. 
- int key_len = static_cast(footer_key.size()); + auto key_len = static_cast(footer_key.size()); std::shared_ptr aes_metadata_decryptor; std::shared_ptr aes_data_decryptor; @@ -197,7 +197,7 @@ std::shared_ptr InternalFileDecryptor::GetColumnDecryptor( throw HiddenColumnException("HiddenColumnException, path=" + column_path); } - int key_len = static_cast(column_key.size()); + auto key_len = static_cast(column_key.size()); std::lock_guard lock(mutex_); auto aes_decryptor = encryption::AesDecryptor::Make(algorithm_, key_len, metadata, &all_decryptors_); diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.h b/cpp/src/parquet/encryption/internal_file_decryptor.h index 8af3587acf884..08423de7fe920 100644 --- a/cpp/src/parquet/encryption/internal_file_decryptor.h +++ b/cpp/src/parquet/encryption/internal_file_decryptor.h @@ -45,10 +45,10 @@ class PARQUET_EXPORT Decryptor { void UpdateAad(const std::string& aad) { aad_ = aad; } ::arrow::MemoryPool* pool() { return pool_; } - [[nodiscard]] int PlaintextLength(int ciphertext_len) const; - [[nodiscard]] int CiphertextLength(int plaintext_len) const; - int Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span plaintext); + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const; + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const; + int32_t Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span plaintext); private: std::shared_ptr aes_decryptor_; diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.cc b/cpp/src/parquet/encryption/internal_file_encryptor.cc index 285c2100be813..94094e6aca228 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_encryptor.cc @@ -35,8 +35,8 @@ int32_t Encryptor::CiphertextLength(int64_t plaintext_len) const { return aes_encryptor_->CiphertextLength(plaintext_len); } -int Encryptor::Encrypt(::arrow::util::span plaintext, - ::arrow::util::span ciphertext) { +int32_t Encryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext) { return aes_encryptor_->Encrypt(plaintext, str2span(key_), str2span(aad_), ciphertext); } @@ -143,7 +143,7 @@ InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor( return encryptor; } -int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const { +int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int32_t key_len) const { if (key_len == 16) return 0; else if (key_len == 24) @@ -155,7 +155,7 @@ int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const { encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( ParquetCipher::type algorithm, size_t key_size) { - int key_len = static_cast(key_size); + auto key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (meta_encryptor_[index] == nullptr) { meta_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, true); @@ -165,7 +165,7 @@ encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor( ParquetCipher::type algorithm, size_t key_size) { - int key_len = static_cast(key_size); + auto key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (data_encryptor_[index] == nullptr) { data_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, false); diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.h b/cpp/src/parquet/encryption/internal_file_encryptor.h 
index 91b6e9fe5aa2f..5a3d743ce5365 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.h +++ b/cpp/src/parquet/encryption/internal_file_encryptor.h @@ -45,8 +45,8 @@ class PARQUET_EXPORT Encryptor { [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const; - int Encrypt(::arrow::util::span plaintext, - ::arrow::util::span ciphertext); + int32_t Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext); bool EncryptColumnMetaData( bool encrypted_footer, @@ -103,7 +103,7 @@ class InternalFileEncryptor { encryption::AesEncryptor* GetDataAesEncryptor(ParquetCipher::type algorithm, size_t key_len); - int MapKeyLenToEncryptorArrayIndex(int key_len) const; + int MapKeyLenToEncryptorArrayIndex(int32_t key_len) const; }; } // namespace parquet diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc index 5d7925aa0318f..89a52a2bcd632 100644 --- a/cpp/src/parquet/encryption/key_toolkit_internal.cc +++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc @@ -53,7 +53,7 @@ std::string DecryptKeyLocally(const std::string& encoded_encrypted_key, static_cast(master_key.size()), false, false /*contains_length*/); - int decrypted_key_len = + int32_t decrypted_key_len = key_decryptor.PlaintextLength(static_cast(encrypted_key.size())); std::string decrypted_key(decrypted_key_len, '\0'); ::arrow::util::span decrypted_key_span( diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 4f2aa6e37328c..423154f8641e5 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -751,7 +751,7 @@ class FileMetaData::FileMetaDataImpl { std::shared_ptr encrypted_buffer = AllocateBuffer( file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len)); - uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( + int32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( serialized_data_span, str2span(key), str2span(aad), nonce, encrypted_buffer->mutable_span_as()); // Delete AES encryptor object. It was created only to verify the footer signature. 
@@ -799,7 +799,7 @@ class FileMetaData::FileMetaDataImpl { // encrypt the footer key std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); - int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); + int32_t encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); // write unencrypted footer PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len)); @@ -1672,7 +1672,7 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { serialized_len); std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); - int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); + int32_t encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); const char* temp = const_cast(reinterpret_cast(encrypted_data.data())); diff --git a/cpp/src/parquet/thrift_internal.h b/cpp/src/parquet/thrift_internal.h index b21b0e07afba2..e7bfd434c81a8 100644 --- a/cpp/src/parquet/thrift_internal.h +++ b/cpp/src/parquet/thrift_internal.h @@ -530,7 +530,7 @@ class ThriftSerializer { auto cipher_buffer = AllocateBuffer(encryptor->pool(), encryptor->CiphertextLength(out_length)); ::arrow::util::span out_span(out_buffer, out_length); - int cipher_buffer_len = + int32_t cipher_buffer_len = encryptor->Encrypt(out_span, cipher_buffer->mutable_span_as()); PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len)); From 6a1d0520974355a749557c993841732d4fcf894c Mon Sep 17 00:00:00 2001 From: Devin Smith Date: Wed, 21 Aug 2024 18:12:45 -0700 Subject: [PATCH 11/32] GH-43717: [Java][FlightSQL] Add all ActionTypes to FlightSqlUtils.FLIGHT_SQL_ACTIONS (#43718) This adds all of the FlightSQL ActionTypes to FlightSqlUtils.FLIGHT_SQL_ACTIONS * GitHub Issue: #43717 Authored-by: Devin Smith Signed-off-by: David Li --- .../org/apache/arrow/flight/sql/FlightSqlUtils.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index 9bb95047691ae..9e13e57d66c65 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -82,7 +82,15 @@ public final class FlightSqlUtils { + "Response Message: N/A"); public static final List FLIGHT_SQL_ACTIONS = - ImmutableList.of(FLIGHT_SQL_CREATE_PREPARED_STATEMENT, FLIGHT_SQL_CLOSE_PREPARED_STATEMENT); + ImmutableList.of( + FLIGHT_SQL_BEGIN_SAVEPOINT, + FLIGHT_SQL_BEGIN_TRANSACTION, + FLIGHT_SQL_CREATE_PREPARED_STATEMENT, + FLIGHT_SQL_CLOSE_PREPARED_STATEMENT, + FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN, + FLIGHT_SQL_CANCEL_QUERY, + FLIGHT_SQL_END_SAVEPOINT, + FLIGHT_SQL_END_TRANSACTION); /** * Helper to parse {@link com.google.protobuf.Any} objects to the specific protobuf object. From 2e83aa63d95a6fa380efdd5e5cb720a3154f9c93 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 22 Aug 2024 09:57:02 +0200 Subject: [PATCH 12/32] GH-43690: [Python][CI] Simplify python/requirements-wheel-test.txt file (#43691) ### Rationale for this change The current [requirements-wheel-test.txt](https://github.com/apache/arrow/blob/7c8909a144f2e8d593dc8ad363ac95b2865b04ca/python/requirements-wheel-test.txt) file has quite complex and detailed version pinning, varying per architecture. 
I think this can be simplified because we just want to test with some older version of numpy and pandas (and the exact version isn't that important). * GitHub Issue: #43690 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/requirements-wheel-test.txt | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index 46bedc13ba1a7..c7ff63e339575 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -5,22 +5,12 @@ pytest pytz tzdata; sys_platform == 'win32' -numpy==1.21.3; platform_system == "Linux" and platform_machine == "aarch64" and python_version < "3.11" -numpy==1.23.4; python_version == "3.11" -numpy==1.26.0; python_version >= "3.12" -numpy==1.19.5; platform_system == "Linux" and platform_machine != "aarch64" and python_version < "3.9" -numpy==1.21.3; platform_system == "Linux" and platform_machine != "aarch64" and python_version >= "3.9" and python_version < "3.11" -numpy==1.21.3; platform_system == "Darwin" and platform_machine == "arm64" and python_version < "3.11" -numpy==1.19.5; platform_system == "Darwin" and platform_machine != "arm64" and python_version < "3.9" -numpy==1.21.3; platform_system == "Darwin" and platform_machine != "arm64" and python_version >= "3.9" and python_version < "3.11" -numpy==1.19.5; platform_system == "Windows" and python_version < "3.9" -numpy==1.21.3; platform_system == "Windows" and python_version >= "3.9" and python_version < "3.11" +# We generally test with the oldest numpy version that supports a given Python +# version. However, there is no need to make this strictly the oldest version, +# so it can be broadened to have a single version specification across platforms. +# (`~=x.y.z` specifies a compatible release as `>=x.y.z, == x.y.*`) +numpy~=1.21.3; python_version < "3.11" +numpy~=1.23.2; python_version == "3.11" +numpy~=1.26.0; python_version == "3.12" -pandas<1.1.0; platform_system == "Linux" and platform_machine != "aarch64" and python_version < "3.8" -pandas; platform_system == "Linux" and platform_machine != "aarch64" and python_version >= "3.8" -pandas; platform_system == "Linux" and platform_machine == "aarch64" -pandas<1.1.0; platform_system == "Darwin" and platform_machine != "arm64" and python_version < "3.8" -pandas; platform_system == "Darwin" and platform_machine != "arm64" and python_version >= "3.8" -pandas; platform_system == "Darwin" and platform_machine == "arm64" -pandas<1.1.0; platform_system == "Windows" and python_version < "3.8" -pandas; platform_system == "Windows" and python_version >= "3.8" +pandas From fc54eadb72791288fc9681bbcc6c8a9d8d6fff1d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 22 Aug 2024 11:28:01 +0200 Subject: [PATCH 13/32] GH-43785: [Python][CI] Correct PARQUET_TEST_DATA path in wheel tests (#43786) ### Rationale for this change Starting with https://github.com/apache/arrow/pull/41580, the pyarrow tests now also rely on a file in the parquet-testing submodule. And the path to that directory is controlled by `PARQUET_TEST_DATA`, which appears to be set wrongly in the wheel test scripts, causing all wheel builds to fail at the moment. 
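For anyone running the wheel tests outside CI, the corrected setting points at the submodule checkout under `cpp/` (the paths below are illustrative; substitute the location of your own Arrow clone):

```
export ARROW_TEST_DATA=/path/to/arrow/testing/data
export PARQUET_TEST_DATA=/path/to/arrow/cpp/submodules/parquet-testing/data
```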
* GitHub Issue: #43785 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- ci/scripts/python_wheel_unix_test.sh | 2 +- ci/scripts/python_wheel_windows_test.bat | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/scripts/python_wheel_unix_test.sh b/ci/scripts/python_wheel_unix_test.sh index a25e5c51bddbc..cf87a17056783 100755 --- a/ci/scripts/python_wheel_unix_test.sh +++ b/ci/scripts/python_wheel_unix_test.sh @@ -54,7 +54,7 @@ export PYARROW_TEST_S3=${ARROW_S3} export PYARROW_TEST_TENSORFLOW=ON export ARROW_TEST_DATA=${source_dir}/testing/data -export PARQUET_TEST_DATA=${source_dir}/submodules/parquet-testing/data +export PARQUET_TEST_DATA=${source_dir}/cpp/submodules/parquet-testing/data if [ "${INSTALL_PYARROW}" == "ON" ]; then # Install the built wheels diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat index a928c3571d0cb..87c0bb1252024 100755 --- a/ci/scripts/python_wheel_windows_test.bat +++ b/ci/scripts/python_wheel_windows_test.bat @@ -35,7 +35,7 @@ set PYARROW_TEST_TENSORFLOW=ON @REM set PYARROW_TEST_PANDAS=ON set ARROW_TEST_DATA=C:\arrow\testing\data -set PARQUET_TEST_DATA=C:\arrow\submodules\parquet-testing\data +set PARQUET_TEST_DATA=C:\arrow\cpp\submodules\parquet-testing\data @REM Install testing dependencies pip install -r C:\arrow\python\requirements-wheel-test.txt || exit /B 1 From b4f7efe5bdc2218bb595b130b4f65237caecfa76 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Thu, 22 Aug 2024 14:45:00 +0200 Subject: [PATCH 14/32] GH-43787: [C++] Register the new Opaque extension type by default (#43788) This is to resolve #43787 > The Opaque extension type implementation for C++ (plus python bindings) was added in https://github.com/apache/arrow/pull/43458, but it was not registered by default, which we should do for canonical extension types (see https://github.com/apache/arrow/pull/43458#issuecomment-2302551404) Additionally, this adds `bool8` extension type builds with `ARROW_JSON=false` as discussed [here](https://github.com/apache/arrow/commit/525881987d0b9b4f464c3e3593a9a7b4e3c767d0#r145613657) ### Rationale for this change Canonical types should be registered by default if possible (except e.g. if they can't be compiled due to `ARROW_JSON=false`). ### What changes are included in this PR? This adds default registration for `opaque`, changes when `bool8` is built and moves all canonical tests under the same test target. ### Are these changes tested? Changes are tested by previously existing tests. ### Are there any user-facing changes? `opaue` will now be registered by default and `bool8` will be present in case `ARROW_JSON=false` at build time. 
* GitHub Issue: #43787 Authored-by: Rok Mihevc Signed-off-by: Rok Mihevc --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/extension/CMakeLists.txt | 18 ++++++----------- cpp/src/arrow/extension/bool8.h | 2 ++ cpp/src/arrow/extension/bool8_test.cc | 1 - cpp/src/arrow/extension/fixed_shape_tensor.h | 2 ++ cpp/src/arrow/extension/opaque.h | 2 ++ cpp/src/arrow/extension/opaque_test.cc | 2 -- cpp/src/arrow/extension_type.cc | 21 ++++++++++++-------- python/pyarrow/tests/test_extension_type.py | 5 ++--- 9 files changed, 28 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index fb7253b6fd69d..89f28ee416ede 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -374,6 +374,7 @@ set(ARROW_SRCS datum.cc device.cc extension_type.cc + extension/bool8.cc pretty_print.cc record_batch.cc result.cc @@ -906,7 +907,6 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON - extension/bool8.cc extension/fixed_shape_tensor.cc extension/opaque.cc json/options.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index fcd5fa529ab56..5cb4bc77af2a4 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,22 +15,16 @@ # specific language governing permissions and limitations # under the License. -add_arrow_test(test - SOURCES - bool8_test.cc - PREFIX - "arrow-extension-bool8") +set(CANONICAL_EXTENSION_TESTS bool8_test.cc) -add_arrow_test(test - SOURCES - fixed_shape_tensor_test.cc - PREFIX - "arrow-fixed-shape-tensor") +if(ARROW_JSON) + list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc) +endif() add_arrow_test(test SOURCES - opaque_test.cc + ${CANONICAL_EXTENSION_TESTS} PREFIX - "arrow-extension-opaque") + "arrow-canonical-extensions") arrow_install_all_headers("arrow/extension") diff --git a/cpp/src/arrow/extension/bool8.h b/cpp/src/arrow/extension/bool8.h index 02e629b28a867..fbb507639e272 100644 --- a/cpp/src/arrow/extension/bool8.h +++ b/cpp/src/arrow/extension/bool8.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include "arrow/extension_type.h" namespace arrow::extension { diff --git a/cpp/src/arrow/extension/bool8_test.cc b/cpp/src/arrow/extension/bool8_test.cc index eabcfcf62d32c..ee77332bc3257 100644 --- a/cpp/src/arrow/extension/bool8_test.cc +++ b/cpp/src/arrow/extension/bool8_test.cc @@ -19,7 +19,6 @@ #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" -#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" namespace arrow { diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h index 20ec20a64c2d4..80a602021c60b 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.h +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include "arrow/extension_type.h" namespace arrow { diff --git a/cpp/src/arrow/extension/opaque.h b/cpp/src/arrow/extension/opaque.h index 9814b391cbad6..5d3411798f88d 100644 --- a/cpp/src/arrow/extension/opaque.h +++ b/cpp/src/arrow/extension/opaque.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#pragma once + #include "arrow/extension_type.h" #include "arrow/type.h" diff --git a/cpp/src/arrow/extension/opaque_test.cc b/cpp/src/arrow/extension/opaque_test.cc index 1629cdb39651c..16fcba3fa6bb0 100644 --- a/cpp/src/arrow/extension/opaque_test.cc +++ b/cpp/src/arrow/extension/opaque_test.cc @@ -25,7 +25,6 @@ #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" -#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/type_fwd.h" #include "arrow/util/checked_cast.h" @@ -169,7 +168,6 @@ TEST(OpaqueType, MetadataRoundTrip) { TEST(OpaqueType, BatchRoundTrip) { auto type = internal::checked_pointer_cast( extension::opaque(binary(), "geometry", "adbc.postgresql")); - ExtensionTypeGuard guard(type); auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); auto array = ExtensionType::WrapArray(type, storage); diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 685018f7de7b8..83c7ebed4f319 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -27,9 +27,10 @@ #include "arrow/array/util.h" #include "arrow/chunked_array.h" #include "arrow/config.h" -#ifdef ARROW_JSON #include "arrow/extension/bool8.h" +#ifdef ARROW_JSON #include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/opaque.h" #endif #include "arrow/status.h" #include "arrow/type.h" @@ -143,17 +144,21 @@ static std::once_flag registry_initialized; namespace internal { static void CreateGlobalRegistry() { + // Register canonical extension types + g_registry = std::make_shared(); + std::vector> ext_types{extension::bool8()}; #ifdef ARROW_JSON - // Register canonical extension types - auto fst_ext_type = - checked_pointer_cast(extension::fixed_shape_tensor(int64(), {})); - ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type)); - - auto bool8_ext_type = checked_pointer_cast(extension::bool8()); - ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type)); + ext_types.push_back(extension::fixed_shape_tensor(int64(), {})); + ext_types.push_back(extension::opaque(null(), "", "")); #endif + + // Register canonical extension types + for (const auto& ext_type : ext_types) { + ARROW_CHECK_OK( + g_registry->RegisterType(checked_pointer_cast(ext_type))); + } } } // namespace internal diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index b04ee85ec99ad..0d50c467e96bd 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1693,9 +1693,8 @@ def test_opaque_type(pickle_module, storage_type, storage): arr = pa.ExtensionArray.from_storage(opaque_type, storage) assert isinstance(arr, opaque_arr_class) - with registered_extension_type(opaque_type): - buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) - batch = ipc_read_batch(buf) + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) assert batch.column(0).type.extension_name == "arrow.opaque" assert isinstance(batch.column(0), opaque_arr_class) From 3e9384bbf4162ea060e867a753bce464b31e5e1c Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Thu, 22 Aug 2024 15:27:40 +0200 Subject: [PATCH 15/32] GH-43519: [Python] Set up wheel building for Python 3.13 (#43539) ### Rationale for this change Like #43519 mentionies, now that the first `rc` is out, it's probably time to add CI coverage for Python 3.13 (and also start building wheels). ### What changes are included in this PR? 
I'm fairly new to the build/CI processes of the project, but I tried to follow the same template as #37901. I'll follow up afterwards with adding CI coverage for the free-threaded build as well. * GitHub Issue: #43519 Lead-authored-by: Lysandros Nikolaou Co-authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- .env | 2 +- ci/docker/python-wheel-manylinux-test.dockerfile | 7 ++++--- ci/docker/python-wheel-manylinux.dockerfile | 2 +- .../python-wheel-windows-test-vs2019.dockerfile | 7 ++++--- ci/docker/python-wheel-windows-vs2019.dockerfile | 7 ++++--- ci/scripts/install_gcs_testbench.sh | 10 +++++++--- ci/scripts/install_python.sh | 14 +++++++++++--- ci/scripts/python_wheel_macos_build.sh | 2 -- dev/release/verify-release-candidate.sh | 6 +++--- dev/tasks/python-wheels/github.linux.yml | 5 +++++ dev/tasks/python-wheels/github.osx.yml | 2 +- dev/tasks/tasks.yml | 3 ++- docker-compose.yml | 9 ++++++--- python/pyproject.toml | 1 + python/requirements-wheel-build.txt | 5 +++++ python/requirements-wheel-test.txt | 7 +++++++ 16 files changed, 62 insertions(+), 27 deletions(-) diff --git a/.env b/.env index 1358aafe824a6..21f904c3208f6 100644 --- a/.env +++ b/.env @@ -95,7 +95,7 @@ VCPKG="943c5ef1c8f6b5e6ced092b242c8299caae2ff01" # 2024.04.26 Release # ci/docker/python-wheel-windows-vs2019.dockerfile. # This is a workaround for our CI problem that "archery docker build" doesn't # use pulled built images in dev/tasks/python-wheels/github.windows.yml. -PYTHON_WHEEL_WINDOWS_IMAGE_REVISION=2024-06-18 +PYTHON_WHEEL_WINDOWS_IMAGE_REVISION=2024-08-06 # Use conanio/${CONAN_BASE}:{CONAN_VERSION} for "docker-compose run --rm conan". # See https://github.com/conan-io/conan-docker-tools#readme and diff --git a/ci/docker/python-wheel-manylinux-test.dockerfile b/ci/docker/python-wheel-manylinux-test.dockerfile index cdd0ae3ced756..443ff9c53cbcb 100644 --- a/ci/docker/python-wheel-manylinux-test.dockerfile +++ b/ci/docker/python-wheel-manylinux-test.dockerfile @@ -16,8 +16,8 @@ # under the License. 
ARG arch -ARG python -FROM ${arch}/python:${python} +ARG python_image_tag +FROM ${arch}/python:${python_image_tag} # RUN pip install --upgrade pip @@ -27,4 +27,5 @@ COPY python/requirements-wheel-test.txt /arrow/python/ RUN pip install -r /arrow/python/requirements-wheel-test.txt COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ -RUN PYTHON=python /arrow/ci/scripts/install_gcs_testbench.sh default +ARG python +RUN PYTHON_VERSION=${python} /arrow/ci/scripts/install_gcs_testbench.sh default diff --git a/ci/docker/python-wheel-manylinux.dockerfile b/ci/docker/python-wheel-manylinux.dockerfile index cb39667af1e10..42f088fd8a22a 100644 --- a/ci/docker/python-wheel-manylinux.dockerfile +++ b/ci/docker/python-wheel-manylinux.dockerfile @@ -103,7 +103,7 @@ RUN vcpkg install \ # Configure Python for applications running in the bash shell of this Dockerfile ARG python=3.8 ENV PYTHON_VERSION=${python} -RUN PYTHON_ROOT=$(find /opt/python -name cp${PYTHON_VERSION/./}-*) && \ +RUN PYTHON_ROOT=$(find /opt/python -name cp${PYTHON_VERSION/./}-cp${PYTHON_VERSION/./}) && \ echo "export PATH=$PYTHON_ROOT/bin:\$PATH" >> /etc/profile.d/python.sh SHELL ["/bin/bash", "-i", "-c"] diff --git a/ci/docker/python-wheel-windows-test-vs2019.dockerfile b/ci/docker/python-wheel-windows-test-vs2019.dockerfile index 32bbb55e82689..5f488a4c285ff 100644 --- a/ci/docker/python-wheel-windows-test-vs2019.dockerfile +++ b/ci/docker/python-wheel-windows-test-vs2019.dockerfile @@ -40,10 +40,11 @@ ARG python=3.8 RUN (if "%python%"=="3.8" setx PYTHON_VERSION "3.8.10" && setx PATH "%PATH%;C:\Python38;C:\Python38\Scripts") & \ (if "%python%"=="3.9" setx PYTHON_VERSION "3.9.13" && setx PATH "%PATH%;C:\Python39;C:\Python39\Scripts") & \ (if "%python%"=="3.10" setx PYTHON_VERSION "3.10.11" && setx PATH "%PATH%;C:\Python310;C:\Python310\Scripts") & \ - (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.5" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ - (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.0" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") + (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.9" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ + (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.4" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") & \ + (if "%python%"=="3.13" setx PYTHON_VERSION "3.13.0-rc1" && setx PATH "%PATH%;C:\Python313;C:\Python313\Scripts") # Install archiver to extract xz archives -RUN choco install -r -y --no-progress python --version=%PYTHON_VERSION% & \ +RUN choco install -r -y --pre --no-progress python --version=%PYTHON_VERSION% & \ python -m pip install --no-cache-dir -U pip setuptools & \ choco install --no-progress -r -y archiver diff --git a/ci/docker/python-wheel-windows-vs2019.dockerfile b/ci/docker/python-wheel-windows-vs2019.dockerfile index ff42de939d91f..5a17e3e4c52c2 100644 --- a/ci/docker/python-wheel-windows-vs2019.dockerfile +++ b/ci/docker/python-wheel-windows-vs2019.dockerfile @@ -83,9 +83,10 @@ ARG python=3.8 RUN (if "%python%"=="3.8" setx PYTHON_VERSION "3.8.10" && setx PATH "%PATH%;C:\Python38;C:\Python38\Scripts") & \ (if "%python%"=="3.9" setx PYTHON_VERSION "3.9.13" && setx PATH "%PATH%;C:\Python39;C:\Python39\Scripts") & \ (if "%python%"=="3.10" setx PYTHON_VERSION "3.10.11" && setx PATH "%PATH%;C:\Python310;C:\Python310\Scripts") & \ - (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.5" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ - (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.0" && setx PATH 
"%PATH%;C:\Python312;C:\Python312\Scripts") -RUN choco install -r -y --no-progress python --version=%PYTHON_VERSION% + (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.9" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ + (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.4" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") & \ + (if "%python%"=="3.13" setx PYTHON_VERSION "3.13.0-rc1" && setx PATH "%PATH%;C:\Python313;C:\Python313\Scripts") +RUN choco install -r -y --pre --no-progress python --version=%PYTHON_VERSION% RUN python -m pip install -U pip setuptools COPY python/requirements-wheel-build.txt arrow/python/ diff --git a/ci/scripts/install_gcs_testbench.sh b/ci/scripts/install_gcs_testbench.sh index 2090290c99322..5471b3cc238ca 100755 --- a/ci/scripts/install_gcs_testbench.sh +++ b/ci/scripts/install_gcs_testbench.sh @@ -41,8 +41,12 @@ version=$1 if [[ "${version}" -eq "default" ]]; then version="v0.39.0" # Latests versions of Testbench require newer setuptools - ${PYTHON:-python3} -m pip install --upgrade setuptools + python3 -m pip install --upgrade setuptools fi -${PYTHON:-python3} -m pip install \ - "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz" +# This script is run with PYTHON undefined in some places, +# but those only use older pythons. +if [[ -z "${PYTHON_VERSION}" ]] || [[ "${PYTHON_VERSION}" != "3.13" ]]; then + python3 -m pip install \ + "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz" +fi diff --git a/ci/scripts/install_python.sh b/ci/scripts/install_python.sh index 5f962f02b911b..42d0e9ca179fb 100755 --- a/ci/scripts/install_python.sh +++ b/ci/scripts/install_python.sh @@ -28,8 +28,9 @@ declare -A versions versions=([3.8]=3.8.10 [3.9]=3.9.13 [3.10]=3.10.11 - [3.11]=3.11.5 - [3.12]=3.12.0) + [3.11]=3.11.9 + [3.12]=3.12.4 + [3.13]=3.13.0) if [ "$#" -ne 2 ]; then echo "Usage: $0 " @@ -46,7 +47,14 @@ full_version=${versions[$2]} if [ $platform = "macOS" ]; then echo "Downloading Python installer..." 
- if [ "$(uname -m)" = "arm64" ] || [ "$version" = "3.10" ] || [ "$version" = "3.11" ] || [ "$version" = "3.12" ]; then + if [ "$version" = "3.13" ]; + then + fname="python-${full_version}rc1-macos11.pkg" + elif [ "$(uname -m)" = "arm64" ] || \ + [ "$version" = "3.10" ] || \ + [ "$version" = "3.11" ] || \ + [ "$version" = "3.12" ]; + then fname="python-${full_version}-macos11.pkg" else fname="python-${full_version}-macosx10.9.pkg" diff --git a/ci/scripts/python_wheel_macos_build.sh b/ci/scripts/python_wheel_macos_build.sh index 3ed9d5d8dd12f..d5430f26748eb 100755 --- a/ci/scripts/python_wheel_macos_build.sh +++ b/ci/scripts/python_wheel_macos_build.sh @@ -48,13 +48,11 @@ fi echo "=== (${PYTHON_VERSION}) Install Python build dependencies ===" export PIP_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') -export PIP_TARGET_PLATFORM="macosx_${MACOSX_DEPLOYMENT_TARGET//./_}_${arch}" pip install \ --upgrade \ --only-binary=:all: \ --target $PIP_SITE_PACKAGES \ - --platform $PIP_TARGET_PLATFORM \ -r ${source_dir}/python/requirements-wheel-build.txt pip install "delocate>=0.10.3" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 6a36109dc2fc1..07e765a759ea0 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -1146,7 +1146,7 @@ test_linux_wheels() { local arch="x86_64" fi - local python_versions="${TEST_PYTHON_VERSIONS:-3.8 3.9 3.10 3.11 3.12}" + local python_versions="${TEST_PYTHON_VERSIONS:-3.8 3.9 3.10 3.11 3.12 3.13}" local platform_tags="${TEST_WHEEL_PLATFORM_TAGS:-manylinux_2_17_${arch}.manylinux2014_${arch} manylinux_2_28_${arch}}" for python in ${python_versions}; do @@ -1170,11 +1170,11 @@ test_macos_wheels() { # apple silicon processor if [ "$(uname -m)" = "arm64" ]; then - local python_versions="3.8 3.9 3.10 3.11 3.12" + local python_versions="3.8 3.9 3.10 3.11 3.12 3.13" local platform_tags="macosx_11_0_arm64" local check_flight=OFF else - local python_versions="3.8 3.9 3.10 3.11 3.12" + local python_versions="3.8 3.9 3.10 3.11 3.12 3.13" local platform_tags="macosx_10_15_x86_64" fi diff --git a/dev/tasks/python-wheels/github.linux.yml b/dev/tasks/python-wheels/github.linux.yml index 968c5da21897b..2854d4349fb7c 100644 --- a/dev/tasks/python-wheels/github.linux.yml +++ b/dev/tasks/python-wheels/github.linux.yml @@ -36,6 +36,11 @@ jobs: ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 {% endif %} PYTHON: "{{ python_version }}" + {% if python_version == "3.13" %} + PYTHON_IMAGE_TAG: "3.13-rc" + {% else %} + PYTHON_IMAGE_TAG: "{{ python_version }}" + {% endif %} steps: {{ macros.github_checkout_arrow()|indent }} diff --git a/dev/tasks/python-wheels/github.osx.yml b/dev/tasks/python-wheels/github.osx.yml index 8ceb468af89dd..b26aeba32b79b 100644 --- a/dev/tasks/python-wheels/github.osx.yml +++ b/dev/tasks/python-wheels/github.osx.yml @@ -121,7 +121,7 @@ jobs: source test-env/bin/activate pip install --upgrade pip wheel arch -{{ arch }} pip install -r arrow/python/requirements-wheel-test.txt - PYTHON=python arch -{{ arch }} arrow/ci/scripts/install_gcs_testbench.sh default + PYTHON_VERSION={{ python_version }} arch -{{ arch }} arrow/ci/scripts/install_gcs_testbench.sh default arch -{{ arch }} arrow/ci/scripts/python_wheel_unix_test.sh $(pwd)/arrow {{ macros.github_upload_releases("arrow/python/repaired_wheels/*.whl")|indent }} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index fe02fe9ce68b2..60114d6930878 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ 
-389,7 +389,8 @@ tasks: ("3.9", "cp39", "cp39"), ("3.10", "cp310", "cp310"), ("3.11", "cp311", "cp311"), - ("3.12", "cp312", "cp312")] %} + ("3.12", "cp312", "cp312"), + ("3.13", "cp313", "cp313")] %} {############################## Wheel Linux ##################################} diff --git a/docker-compose.yml b/docker-compose.yml index 14eeeeee6e5ef..3045cf015bc26 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1096,9 +1096,10 @@ services: args: arch: ${ARCH} arch_short: ${ARCH_SHORT} - base: quay.io/pypa/manylinux2014_${ARCH_ALIAS}:2024-02-04-ea37246 + base: quay.io/pypa/manylinux2014_${ARCH_ALIAS}:2024-08-03-32dfa47 vcpkg: ${VCPKG} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} manylinux: 2014 context: . dockerfile: ci/docker/python-wheel-manylinux.dockerfile @@ -1119,9 +1120,10 @@ services: args: arch: ${ARCH} arch_short: ${ARCH_SHORT} - base: quay.io/pypa/manylinux_2_28_${ARCH_ALIAS}:2024-02-04-ea37246 + base: quay.io/pypa/manylinux_2_28_${ARCH_ALIAS}:2024-08-03-32dfa47 vcpkg: ${VCPKG} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} manylinux: 2_28 context: . dockerfile: ci/docker/python-wheel-manylinux.dockerfile @@ -1135,7 +1137,7 @@ services: command: /arrow/ci/scripts/python_wheel_manylinux_build.sh python-wheel-manylinux-test-imports: - image: ${ARCH}/python:${PYTHON} + image: ${ARCH}/python:${PYTHON_IMAGE_TAG} shm_size: 2G volumes: - .:/arrow:delegated @@ -1151,6 +1153,7 @@ services: args: arch: ${ARCH} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} context: . dockerfile: ci/docker/python-wheel-manylinux-test.dockerfile cache_from: diff --git a/python/pyproject.toml b/python/pyproject.toml index d863bb3e5f0ac..8ece65dd467bb 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -48,6 +48,7 @@ classifiers = [ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ] maintainers = [ {name = "Apache Arrow Developers", email = "dev@arrow.apache.org"} diff --git a/python/requirements-wheel-build.txt b/python/requirements-wheel-build.txt index faa078d3d7fe7..2d448004768ce 100644 --- a/python/requirements-wheel-build.txt +++ b/python/requirements-wheel-build.txt @@ -1,3 +1,8 @@ +# Remove pre and extra index url once there's NumPy and Cython wheels for 3.13 +# on PyPI +--pre +--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" + cython>=0.29.31 oldest-supported-numpy>=0.14; python_version<'3.9' numpy>=2.0.0; python_version>='3.9' diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index c7ff63e339575..98ec2bd4fd4e4 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -1,3 +1,9 @@ +# Remove pre and extra index url once there's NumPy and Cython wheels for 3.13 +# on PyPI +--pre +--prefer-binary +--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" + cffi cython hypothesis @@ -12,5 +18,6 @@ tzdata; sys_platform == 'win32' numpy~=1.21.3; python_version < "3.11" numpy~=1.23.2; python_version == "3.11" numpy~=1.26.0; python_version == "3.12" +numpy~=2.1.0; python_version >= "3.13" pandas From 88d57cf41fde20adf14adca02e02d2cb92c83443 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 22 Aug 2024 08:45:19 -0500 Subject: [PATCH 16/32] MINOR: [CI][R] Undo #43636 now that the action is approved (#43730) Undo the pinning in #43636 now that INFRA has approved the quarto-dev action 
Authored-by: Jonathan Keane Signed-off-by: Antoine Pitrou --- .github/workflows/r.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index bf7eb99e7e990..2820d42470bca 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -86,19 +86,18 @@ jobs: run: | sudo apt-get install devscripts - # replace the SHA with v2 once INFRA-26031 is resolved - - uses: r-lib/actions/setup-r@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true install-r: false - - uses: r-lib/actions/setup-r-dependencies@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: any::rcmdcheck needs: check working-directory: src/r - - uses: r-lib/actions/check-r-package@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/check-r-package@v2 with: working-directory: src/r env: @@ -341,11 +340,11 @@ jobs: cd r/windows ls *.zip | xargs -n 1 unzip -uo rm -rf *.zip - - uses: r-lib/actions/setup-r@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r@v2 with: r-version: ${{ matrix.config.rversion }} Ncpus: 2 - - uses: r-lib/actions/setup-r-dependencies@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r-dependencies@v2 env: GITHUB_PAT: "${{ github.token }}" with: From 2e33e98f583035cd686455870e9cbf5fb6dc9966 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 22 Aug 2024 08:26:37 -0800 Subject: [PATCH 17/32] MINOR: [GO] fixup test case name in cast_test.go (#43780) --- go/arrow/compute/cast_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go index 2e748a2fee9c2..fa08467dd3946 100644 --- a/go/arrow/compute/cast_test.go +++ b/go/arrow/compute/cast_test.go @@ -2636,7 +2636,7 @@ func (c *CastSuite) TestStructToDifferentNullabilityStruct() { defer dest3Nullable.Release() checkCast(c.T(), srcNonNull, dest3Nullable, *compute.DefaultCastOptions(true)) }) - c.Run("non-nullable to nullable", func() { + c.Run("nullable to non-nullable", func() { fieldsSrcNullable := []arrow.Field{ {Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, {Name: "b", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, From 76e0f6254b75509d83e44fe8997bd14007907c4f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 22 Aug 2024 15:37:09 -0400 Subject: [PATCH 18/32] GH-43764: [Go][FlightSQL] Add NewPreparedStatement function (#43781) ### Rationale for this change Allowing creation of the prepared statement object outside of the client allows for logging, proxying, and handing off prepared statements if necessary. ### Are these changes tested? Yes * GitHub Issue: #43764 Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/flight/flightsql/client.go | 9 +++++++++ go/arrow/flight/flightsql/client_test.go | 21 +++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go index 4a600e5253e9b..4c9dc50135108 100644 --- a/go/arrow/flight/flightsql/client.go +++ b/go/arrow/flight/flightsql/client.go @@ -1102,6 +1102,15 @@ type PreparedStatement struct { closed bool } +// NewPreparedStatement creates a prepared statement object bound to the provided +// client using the given handle. In general, it should be sufficient to use the +// Prepare function a client and this wouldn't be needed. 
But this can be used +// to propagate a prepared statement from one client to another if needed or if +// proxying requests. +func NewPreparedStatement(client *Client, handle []byte) *PreparedStatement { + return &PreparedStatement{client: client, handle: handle} +} + // Execute executes the prepared statement on the server and returns a FlightInfo // indicating where to retrieve the response. If SetParameters has been called // then the parameter bindings will be sent before execution. diff --git a/go/arrow/flight/flightsql/client_test.go b/go/arrow/flight/flightsql/client_test.go index 7604b554cbc6c..d060161f94f0f 100644 --- a/go/arrow/flight/flightsql/client_test.go +++ b/go/arrow/flight/flightsql/client_test.go @@ -378,8 +378,10 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { createRsp := &mockDoActionClient{} defer createRsp.AssertExpectations(s.T()) createRsp.On("Recv").Return(&pb.Result{Body: data}, nil).Once() - createRsp.On("Recv").Return(&pb.Result{}, io.EOF) - createRsp.On("CloseSend").Return(nil) + createRsp.On("Recv").Return(&pb.Result{}, io.EOF).Once() + createRsp.On("Recv").Return(&pb.Result{Body: data}, nil).Once() + createRsp.On("Recv").Return(&pb.Result{}, io.EOF).Once() + createRsp.On("CloseSend").Return(nil).Twice() closeRsp := &mockDoActionClient{} defer closeRsp.AssertExpectations(s.T()) @@ -387,13 +389,13 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { closeRsp.On("CloseSend").Return(nil) s.mockClient.On("DoAction", flightsql.CreatePreparedStatementActionType, action.Body, s.callOpts). - Return(createRsp, nil) + Return(createRsp, nil).Twice() s.mockClient.On("DoAction", flightsql.ClosePreparedStatementActionType, closeAct.Body, s.callOpts). Return(closeRsp, nil) infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)} desc := getDesc(infoCmd) - s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil) + s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil).Twice() prepared, err := s.sqlClient.Prepare(context.TODO(), query, s.callOpts...) s.NoError(err) @@ -404,6 +406,17 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { info, err := prepared.Execute(context.TODO(), s.callOpts...) s.NoError(err) s.Equal(&emptyFlightInfo, info) + + prepared, err = s.sqlClient.Prepare(context.TODO(), query, s.callOpts...) + s.NoError(err) + + secondPrepare := flightsql.NewPreparedStatement(&s.sqlClient, prepared.Handle()) + s.Equal(string(secondPrepare.Handle()), "query") + defer secondPrepare.Close(context.TODO(), s.callOpts...) + + info, err = secondPrepare.Execute(context.TODO(), s.callOpts...) + s.NoError(err) + s.Equal(&emptyFlightInfo, info) } func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() { From d47b305bbce037af18ce65dc968074fe1681b4d4 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:04:59 -0400 Subject: [PATCH 19/32] GH-43624: [Go] Add JSON/UUID extension types, extend arrow -> parquet logical type mapping (#43679) ### Rationale for this change - Missing `JSON` extension type implementation. - Current precedent in C++ (and thereby PyArrow) is that canonical extension types do not require manual registration. - Issues like #43640 and #43624 suggest that we need to expose ways of configuring parquet types written from arrow records, but casting the underlying data presents challenges for a generalized approach. 
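As a rough sketch of how the new hook is meant to be used (the type, package, and names below are hypothetical and only assume the hook is a single `ParquetLogicalType()` method, as implemented by the new JSON type in this PR), an extension type can declare the Parquet logical type its storage column should be written with:

```go
package myext

import (
	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/parquet/schema"
)

// JSONishType is a hypothetical string-backed extension type. Only the
// Parquet-mapping hook is shown; the remaining arrow.ExtensionType methods
// (ArrayType, Serialize, Deserialize, ExtensionEquals, ...) are omitted.
type JSONishType struct {
	arrow.ExtensionBase
}

func (*JSONishType) ExtensionName() string { return "myorg.jsonish" }

// ParquetLogicalType asks pqarrow to write this column with the Parquet JSON
// logical type instead of the default mapping of its storage type.
func (*JSONishType) ParquetLogicalType() schema.LogicalType {
	return schema.JSONLogicalType{}
}
```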
### What changes are included in this PR? - Move `UUIDType` from `internal` to `arrow/extensions` - Implement `JSON` canonical extension type - Automatically register all canonical extension types at initialization - remove register/unregister from various locations these extension types are used - Add new `CustomParquetType` interface so extension types can specify their target `LogicalType` in Parquet - Refactor parquet `fieldToNode` to split up `PrimitiveNode` type mapping for leaves from `GroupNode` composition - Simplify parquet `LogicalType` to use only value receivers ### Are these changes tested? Yes ### Are there any user-facing changes? - `UUID` and `JSON` extension types are available to end users. - Canonical extension types will automatically be recognized in IPC without registration. - Users with their own extension type implementations may use the `CustomParquetType` interface to control Parquet conversion without needing to fork or upstream the change. * GitHub Issue: #43624 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- docs/source/status.rst | 6 + go/arrow/array/array_test.go | 4 +- go/arrow/array/diff_test.go | 4 +- go/arrow/array/extension_test.go | 10 - go/arrow/avro/reader_types.go | 4 +- go/arrow/avro/schema.go | 4 +- go/arrow/compute/exec/span_test.go | 6 +- go/arrow/csv/reader_test.go | 4 +- go/arrow/csv/writer_test.go | 6 +- go/arrow/datatype_extension_test.go | 18 +- go/arrow/extensions/bool8_test.go | 3 - go/arrow/extensions/extensions.go | 36 +++ go/arrow/extensions/json.go | 148 ++++++++++ go/arrow/extensions/json_test.go | 268 ++++++++++++++++++ go/arrow/extensions/opaque_test.go | 3 - go/arrow/extensions/uuid.go | 265 +++++++++++++++++ go/arrow/extensions/uuid_test.go | 257 +++++++++++++++++ .../internal/flight_integration/scenario.go | 4 - .../cmd/arrow-json-integration-test/main.go | 4 - go/arrow/ipc/metadata_test.go | 11 +- go/internal/types/extension_types.go | 227 +-------------- go/internal/types/extension_types_test.go | 95 ------- go/parquet/cmd/parquet_reader/main.go | 2 +- go/parquet/metadata/app_version.go | 2 +- go/parquet/pqarrow/encode_arrow_test.go | 82 ++++-- go/parquet/pqarrow/path_builder_test.go | 6 +- go/parquet/pqarrow/schema.go | 228 +++++++-------- go/parquet/pqarrow/schema_test.go | 15 +- go/parquet/schema/converted_types.go | 8 +- go/parquet/schema/logical_types.go | 30 +- go/parquet/schema/logical_types_test.go | 40 +-- go/parquet/schema/schema_element_test.go | 4 +- 32 files changed, 1221 insertions(+), 583 deletions(-) create mode 100644 go/arrow/extensions/extensions.go create mode 100644 go/arrow/extensions/json.go create mode 100644 go/arrow/extensions/json_test.go create mode 100644 go/arrow/extensions/uuid.go create mode 100644 go/arrow/extensions/uuid_test.go delete mode 100644 go/internal/types/extension_types_test.go diff --git a/docs/source/status.rst b/docs/source/status.rst index c232aa280befb..5e2c2cc19c890 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -119,6 +119,12 @@ Data Types +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Variable shape tensor | | | | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| JSON | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| UUID | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| 8-bit Boolean | ✓ | | ✓ | 
| | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: diff --git a/go/arrow/array/array_test.go b/go/arrow/array/array_test.go index 4d83766b4fa3e..4f0627c600078 100644 --- a/go/arrow/array/array_test.go +++ b/go/arrow/array/array_test.go @@ -21,9 +21,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/testing/tools" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" ) @@ -122,7 +122,7 @@ func TestMakeFromData(t *testing.T) { {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint64, ValueType: &testDataType{arrow.TIMESTAMP}}, dict: array.NewData(&testDataType{arrow.TIMESTAMP}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)}, {name: "extension", d: &testDataType{arrow.EXTENSION}, expPanic: true, expError: "arrow/array: DataType for ExtensionArray must implement arrow.ExtensionType"}, - {name: "extension", d: types.NewUUIDType()}, + {name: "extension", d: extensions.NewUUIDType()}, {name: "run end encoded", d: arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Int64), child: []arrow.ArrayData{ array.NewData(&testDataType{arrow.INT64}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */), diff --git a/go/arrow/array/diff_test.go b/go/arrow/array/diff_test.go index 65d212be11838..9c9ce6a53aed0 100644 --- a/go/arrow/array/diff_test.go +++ b/go/arrow/array/diff_test.go @@ -25,9 +25,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" ) type diffTestCase struct { @@ -861,7 +861,7 @@ func TestEdits_UnifiedDiff(t *testing.T) { }, { name: "extensions", - dataType: types.NewUUIDType(), + dataType: extensions.NewUUIDType(), baseJSON: `["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001"]`, targetJSON: `["00000000-0000-0000-0000-000000000001", "00000000-0000-0000-0000-000000000002"]`, want: `@@ -0, +0 @@ diff --git a/go/arrow/array/extension_test.go b/go/arrow/array/extension_test.go index 71ea9f105af7c..26245cf015dec 100644 --- a/go/arrow/array/extension_test.go +++ b/go/arrow/array/extension_test.go @@ -30,16 +30,6 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestParametricEquals() { p1Type := types.NewParametric1Type(6) p2Type := types.NewParametric1Type(6) diff --git a/go/arrow/avro/reader_types.go b/go/arrow/avro/reader_types.go index e07cd380d511f..dab2b33dce601 100644 --- a/go/arrow/avro/reader_types.go +++ b/go/arrow/avro/reader_types.go @@ -27,8 +27,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" 
"github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" ) type dataLoader struct { @@ -436,7 +436,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } return nil } - case *types.UUIDBuilder: + case *extensions.UUIDBuilder: f.appendFunc = func(data interface{}) error { switch dt := data.(type) { case nil: diff --git a/go/arrow/avro/schema.go b/go/arrow/avro/schema.go index 007dad06c19cd..a6de3718d3ccf 100644 --- a/go/arrow/avro/schema.go +++ b/go/arrow/avro/schema.go @@ -24,7 +24,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/internal/utils" avro "github.com/hamba/avro/v2" ) @@ -349,7 +349,7 @@ func avroLogicalToArrowField(n *schemaNode) { // The uuid logical type represents a random generated universally unique identifier (UUID). // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 case "uuid": - dt = types.NewUUIDType() + dt = extensions.NewUUIDType() // The date logical type represents a date within the calendar, with no reference to a particular // time zone or time of day. diff --git a/go/arrow/compute/exec/span_test.go b/go/arrow/compute/exec/span_test.go index f5beb45ee1494..018fbb7d623d9 100644 --- a/go/arrow/compute/exec/span_test.go +++ b/go/arrow/compute/exec/span_test.go @@ -29,6 +29,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/compute/exec" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/endian" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/arrow/scalar" "github.com/apache/arrow/go/v18/internal/types" @@ -192,9 +193,6 @@ func TestArraySpan_NumBuffers(t *testing.T) { Children []exec.ArraySpan } - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - tests := []struct { name string fields fields @@ -207,7 +205,7 @@ func TestArraySpan_NumBuffers(t *testing.T) { {"large binary", fields{Type: arrow.BinaryTypes.LargeBinary}, 3}, {"string", fields{Type: arrow.BinaryTypes.String}, 3}, {"large string", fields{Type: arrow.BinaryTypes.LargeString}, 3}, - {"extension", fields{Type: types.NewUUIDType()}, 2}, + {"extension", fields{Type: extensions.NewUUIDType()}, 2}, {"int32", fields{Type: arrow.PrimitiveTypes.Int32}, 2}, } for _, tt := range tests { diff --git a/go/arrow/csv/reader_test.go b/go/arrow/csv/reader_test.go index b0775b9b11a96..6a89d49704298 100644 --- a/go/arrow/csv/reader_test.go +++ b/go/arrow/csv/reader_test.go @@ -30,8 +30,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -356,7 +356,7 @@ func testCSVReader(t *testing.T, filepath string, withHeader bool, stringsCanBeN {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: "large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "date32", Type: 
arrow.PrimitiveTypes.Date32}, {Name: "date64", Type: arrow.PrimitiveTypes.Date64}, }, diff --git a/go/arrow/csv/writer_test.go b/go/arrow/csv/writer_test.go index be9ab961c3ef7..2ae01a6d49071 100644 --- a/go/arrow/csv/writer_test.go +++ b/go/arrow/csv/writer_test.go @@ -31,9 +31,9 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/float16" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" ) @@ -230,7 +230,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: "large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "null", Type: arrow.Null}, }, nil, @@ -285,7 +285,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo b.Field(22).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(23).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(24).(*array.FixedSizeBinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) - b.Field(25).(*types.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) + b.Field(25).(*extensions.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) b.Field(26).(*array.NullBuilder).AppendEmptyValues(3) for _, field := range b.Fields() { diff --git a/go/arrow/datatype_extension_test.go b/go/arrow/datatype_extension_test.go index c3e595f523e57..7244d377bd285 100644 --- a/go/arrow/datatype_extension_test.go +++ b/go/arrow/datatype_extension_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -50,24 +50,14 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestExtensionType() { e.Nil(arrow.GetExtensionType("uuid-unknown")) - e.NotNil(arrow.GetExtensionType("uuid")) + e.NotNil(arrow.GetExtensionType("arrow.uuid")) - e.Error(arrow.RegisterExtensionType(types.NewUUIDType())) + e.Error(arrow.RegisterExtensionType(extensions.NewUUIDType())) e.Error(arrow.UnregisterExtensionType("uuid-unknown")) - typ := types.NewUUIDType() + typ := extensions.NewUUIDType() e.Implements((*arrow.ExtensionType)(nil), typ) e.Equal(arrow.EXTENSION, typ.ID()) e.Equal("extension", typ.Name()) diff --git a/go/arrow/extensions/bool8_test.go b/go/arrow/extensions/bool8_test.go index 9f7365d1555fb..ff129e24bc8f0 100644 --- a/go/arrow/extensions/bool8_test.go +++ 
b/go/arrow/extensions/bool8_test.go @@ -178,9 +178,6 @@ func TestReinterpretStorageEqualToValues(t *testing.T) { func TestBool8TypeBatchIPCRoundTrip(t *testing.T) { typ := extensions.NewBool8Type() - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int8, strings.NewReader(`[-1, 0, 1, 2, null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/extensions.go b/go/arrow/extensions/extensions.go new file mode 100644 index 0000000000000..03c6923e95f4f --- /dev/null +++ b/go/arrow/extensions/extensions.go @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "github.com/apache/arrow/go/v18/arrow" +) + +var canonicalExtensionTypes = []arrow.ExtensionType{ + &Bool8Type{}, + &UUIDType{}, + &OpaqueType{}, + &JSONType{}, +} + +func init() { + for _, extType := range canonicalExtensionTypes { + if err := arrow.RegisterExtensionType(extType); err != nil { + panic(err) + } + } +} diff --git a/go/arrow/extensions/json.go b/go/arrow/extensions/json.go new file mode 100644 index 0000000000000..12c49f9c0a76d --- /dev/null +++ b/go/arrow/extensions/json.go @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "reflect" + "slices" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" +) + +var jsonSupportedStorageTypes = []arrow.DataType{ + arrow.BinaryTypes.String, + arrow.BinaryTypes.LargeString, + arrow.BinaryTypes.StringView, +} + +// JSONType represents a UTF-8 encoded JSON string as specified in RFC8259. +type JSONType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. 
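The new `extensions.go` above registers every canonical extension type in an `init()` function, which is why this patch can drop the per-test `RegisterExtensionType`/`UnregisterExtensionType` calls. A minimal sketch of the resulting behavior (not part of the patch; it only uses names that appear in the diff, and assumes a blank import is enough to trigger registration):

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v18/arrow"
	// The blank import runs the package init(), which registers the canonical
	// extension types "arrow.bool8", "arrow.uuid", "arrow.opaque" and "arrow.json".
	_ "github.com/apache/arrow/go/v18/arrow/extensions"
)

func main() {
	// No explicit RegisterExtensionType call is needed anymore.
	if et := arrow.GetExtensionType("arrow.uuid"); et != nil {
		fmt.Println(et.ExtensionName()) // arrow.uuid
	}
}
```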
+func (b *JSONType) ParquetLogicalType() schema.LogicalType { + return schema.JSONLogicalType{} +} + +// NewJSONType creates a new JSONType with the specified storage type. +// storageType must be one of String, LargeString, StringView. +func NewJSONType(storageType arrow.DataType) (*JSONType, error) { + if !slices.Contains(jsonSupportedStorageTypes, storageType) { + return nil, fmt.Errorf("unsupported storage type for JSON extension type: %s", storageType) + } + return &JSONType{ExtensionBase: arrow.ExtensionBase{Storage: storageType}}, nil +} + +func (b *JSONType) ArrayType() reflect.Type { return reflect.TypeOf(JSONArray{}) } + +func (b *JSONType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !(data == "" || data == "{}") { + return nil, fmt.Errorf("serialized metadata for JSON extension type must be '' or '{}', found: %s", data) + } + return NewJSONType(storageType) +} + +func (b *JSONType) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() && arrow.TypeEqual(b.Storage, other.StorageType()) +} + +func (b *JSONType) ExtensionName() string { return "arrow.json" } + +func (b *JSONType) Serialize() string { return "" } + +func (b *JSONType) String() string { + return fmt.Sprintf("extension<%s[storage_type=%s]>", b.ExtensionName(), b.Storage) +} + +// JSONArray is logically an array of UTF-8 encoded JSON strings. +// Its values are unmarshaled to native Go values. +type JSONArray struct { + array.ExtensionArrayBase +} + +func (a *JSONArray) String() string { + b, err := a.MarshalJSON() + if err != nil { + panic(fmt.Sprintf("failed marshal JSONArray: %s", err)) + } + + return string(b) +} + +func (a *JSONArray) Value(i int) any { + val := a.ValueBytes(i) + + var res any + if err := json.Unmarshal(val, &res); err != nil { + panic(err) + } + + return res +} + +func (a *JSONArray) ValueStr(i int) string { + return string(a.ValueBytes(i)) +} + +func (a *JSONArray) ValueBytes(i int) []byte { + // convert to json.RawMessage, set to nil if elem isNull. + val := a.ValueJSON(i) + + // simply returns wrapped bytes, or null if val is nil. + b, err := val.MarshalJSON() + if err != nil { + panic(err) + } + + return b +} + +// ValueJSON wraps the underlying string value as a json.RawMessage, +// or returns nil if the array value is null. +func (a *JSONArray) ValueJSON(i int) json.RawMessage { + var val json.RawMessage + if a.IsValid(i) { + val = json.RawMessage(a.Storage().(array.StringLike).Value(i)) + } + return val +} + +// MarshalJSON implements json.Marshaler. +// Marshaling json.RawMessage is a no-op, except that nil values will +// be marshaled as a JSON null. +func (a *JSONArray) MarshalJSON() ([]byte, error) { + values := make([]json.RawMessage, a.Len()) + for i := 0; i < a.Len(); i++ { + values[i] = a.ValueJSON(i) + } + return json.Marshal(values) +} + +// GetOneForMarshal implements arrow.Array. +func (a *JSONArray) GetOneForMarshal(i int) interface{} { + return a.ValueJSON(i) +} + +var ( + _ arrow.ExtensionType = (*JSONType)(nil) + _ array.ExtensionArray = (*JSONArray)(nil) +) diff --git a/go/arrow/extensions/json_test.go b/go/arrow/extensions/json_test.go new file mode 100644 index 0000000000000..21acc58f93949 --- /dev/null +++ b/go/arrow/extensions/json_test.go @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONTypeBasics(t *testing.T) { + typ, err := extensions.NewJSONType(arrow.BinaryTypes.String) + require.NoError(t, err) + + typLarge, err := extensions.NewJSONType(arrow.BinaryTypes.LargeString) + require.NoError(t, err) + + typView, err := extensions.NewJSONType(arrow.BinaryTypes.StringView) + require.NoError(t, err) + + assert.Equal(t, "arrow.json", typ.ExtensionName()) + assert.Equal(t, "arrow.json", typLarge.ExtensionName()) + assert.Equal(t, "arrow.json", typView.ExtensionName()) + + assert.True(t, typ.ExtensionEquals(typ)) + assert.True(t, typLarge.ExtensionEquals(typLarge)) + assert.True(t, typView.ExtensionEquals(typView)) + + assert.False(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ)) + assert.False(t, arrow.TypeEqual(typ, typLarge)) + assert.False(t, arrow.TypeEqual(typ, typView)) + assert.False(t, arrow.TypeEqual(typLarge, typView)) + + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.LargeString, typLarge.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.StringView, typView.StorageType())) + + assert.Equal(t, "extension", typ.String()) + assert.Equal(t, "extension", typLarge.String()) + assert.Equal(t, "extension", typView.String()) +} + +var jsonTestCases = []struct { + Name string + StorageType arrow.DataType + StorageBuilder func(mem memory.Allocator) array.Builder +}{ + { + Name: "string", + StorageType: arrow.BinaryTypes.String, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringBuilder(mem) }, + }, + { + Name: "large_string", + StorageType: arrow.BinaryTypes.LargeString, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewLargeStringBuilder(mem) }, + }, + { + Name: "string_view", + StorageType: arrow.BinaryTypes.StringView, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringViewBuilder(mem) }, + }, +} + +func TestJSONTypeCreateFromArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := 
array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + require.Equal(t, "foobar", jsonArr.Value(0)) + require.Equal(t, nil, jsonArr.Value(1)) + require.Equal(t, map[string]any{"foo": "bar"}, jsonArr.Value(2)) + require.Equal(t, float64(42), jsonArr.Value(3)) + require.Equal(t, true, jsonArr.Value(4)) + require.Equal(t, []any{float64(1), true, "3", nil, map[string]any{"five": float64(5)}}, jsonArr.Value(5)) + }) + } +} + +func TestJSONTypeBatchIPCRoundTrip(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) + }) + } +} + +func TestMarshallJSONArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + b, err := jsonArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := `["foobar",null,{"foo":"bar"},42,true,[1,true,"3",null,{"five":5}]]` + require.Equal(t, expectedJSON, string(b)) + require.Equal(t, expectedJSON, jsonArr.String()) + }) + } +} + +func TestJSONRecordToJSON(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + 
bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "json", Type: typ, Nullable: true}}, nil), []arrow.Array{jsonArr}, 6) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"json":"foobar"} + {"json":null} + {"json":{"foo":"bar"}} + {"json":42} + {"json":true} + {"json":[1,true,"3",null,{"five":5}]} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } + }) + } +} diff --git a/go/arrow/extensions/opaque_test.go b/go/arrow/extensions/opaque_test.go index b6686e97bc027..a0fc8962ce5e4 100644 --- a/go/arrow/extensions/opaque_test.go +++ b/go/arrow/extensions/opaque_test.go @@ -161,9 +161,6 @@ func TestOpaqueTypeMetadataRoundTrip(t *testing.T) { func TestOpaqueTypeBatchRoundTrip(t *testing.T) { typ := extensions.NewOpaqueType(arrow.BinaryTypes.String, "geometry", "adbc.postgresql") - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.BinaryTypes.String, strings.NewReader(`["foobar", null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/uuid.go b/go/arrow/extensions/uuid.go new file mode 100644 index 0000000000000..422b9ea118800 --- /dev/null +++ b/go/arrow/extensions/uuid.go @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "bytes" + "fmt" + "reflect" + "strings" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" + "github.com/google/uuid" +) + +type UUIDBuilder struct { + *array.ExtensionBuilder +} + +// NewUUIDBuilder creates a new UUIDBuilder, exposing a convenient and efficient interface +// for writing uuid.UUID (or [16]byte) values to the underlying FixedSizeBinary storage array. 
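Pulling the pieces of the new `arrow.json` type together, here is a hedged sketch of how a JSON extension array might be built and read outside the tests; it only strings together constructors and methods that appear in the hunks above (`NewJSONType`, `array.NewExtensionArrayWithStorage`, `JSONArray.Value`), so treat it as illustrative rather than canonical:

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/arrow/array"
	"github.com/apache/arrow/go/v18/arrow/extensions"
	"github.com/apache/arrow/go/v18/arrow/memory"
)

func main() {
	// arrow.json may wrap String, LargeString or StringView storage.
	typ, err := extensions.NewJSONType(arrow.BinaryTypes.String)
	if err != nil {
		panic(err)
	}

	// Build the storage array, then wrap it in the extension type.
	sb := array.NewStringBuilder(memory.DefaultAllocator)
	defer sb.Release()
	sb.Append(`{"foo": "bar"}`)
	sb.AppendNull()
	sb.Append(`[1, 2, 3]`)

	storage := sb.NewArray()
	defer storage.Release()

	arr := array.NewExtensionArrayWithStorage(typ, storage).(*extensions.JSONArray)
	defer arr.Release()

	fmt.Println(arr.Value(0))    // map[foo:bar], unmarshaled to a native Go value
	fmt.Println(arr.IsNull(1))   // true
	fmt.Println(arr.ValueStr(2)) // [1, 2, 3]
}
```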
+func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { + return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} +} + +func (b *UUIDBuilder) Append(v uuid.UUID) { + b.AppendBytes(v) +} + +func (b *UUIDBuilder) AppendBytes(v [16]byte) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) +} + +func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) +} + +func (b *UUIDBuilder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + uid, err := uuid.Parse(s) + if err != nil { + return err + } + + b.Append(uid) + return nil +} + +func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { + if len(v) != len(valid) && len(valid) != 0 { + panic("len(v) != len(valid) && len(valid) != 0") + } + + data := make([][]byte, len(v)) + for i := range v { + if len(valid) > 0 && !valid[i] { + continue + } + data[i] = v[i][:] + } + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) +} + +func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + var val uuid.UUID + switch v := t.(type) { + case string: + val, err = uuid.Parse(v) + if err != nil { + return err + } + case []byte: + val, err = uuid.ParseBytes(v) + if err != nil { + return err + } + case nil: + b.AppendNull() + return nil + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeOf([]byte{}), + Offset: dec.InputOffset(), + Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), + } + } + + b.Append(val) + return nil +} + +func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if err != nil { + return err + } + + if delim, ok := t.(json.Delim); !ok || delim != '[' { + return fmt.Errorf("uuid builder must unpack from json array, found %s", delim) + } + + return b.Unmarshal(dec) +} + +// UUIDArray is a simple array which is a FixedSizeBinary(16) +type UUIDArray struct { + array.ExtensionArrayBase +} + +func (a *UUIDArray) String() string { + arr := a.Storage().(*array.FixedSizeBinary) + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < arr.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(o, "%q", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *UUIDArray) Value(i int) uuid.UUID { + if a.IsNull(i) { + return uuid.Nil + } + return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) +} + +func (a *UUIDArray) Values() []uuid.UUID { + values := make([]uuid.UUID, a.Len()) + for i := range values { + values[i] = a.Value(i) + } + return values +} + +func (a *UUIDArray) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return a.Value(i).String() + } +} + +func (a *UUIDArray) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range vals { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (a *UUIDArray) GetOneForMarshal(i int) interface{} { + if a.IsValid(i) { + return a.Value(i) + } + return nil +} + +// UUIDType is a simple extension type that represents a 
FixedSizeBinary(16) +// to be used for representing UUIDs +type UUIDType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. +func (e *UUIDType) ParquetLogicalType() schema.LogicalType { + return schema.UUIDLogicalType{} +} + +// NewUUIDType is a convenience function to create an instance of UUIDType +// with the correct storage type +func NewUUIDType() *UUIDType { + return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} +} + +// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays +func (*UUIDType) ArrayType() reflect.Type { + return reflect.TypeOf(UUIDArray{}) +} + +func (*UUIDType) ExtensionName() string { + return "arrow.uuid" +} + +func (e *UUIDType) String() string { + return fmt.Sprintf("extension<%s>", e.ExtensionName()) +} + +func (e *UUIDType) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil +} + +func (*UUIDType) Serialize() string { + return "" +} + +// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} +func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { + return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) + } + return NewUUIDType(), nil +} + +// ExtensionEquals returns true if both extensions have the same name +func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { + return e.ExtensionName() == other.ExtensionName() +} + +func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { + return NewUUIDBuilder(mem) +} + +var ( + _ arrow.ExtensionType = (*UUIDType)(nil) + _ array.CustomExtensionBuilder = (*UUIDType)(nil) + _ array.ExtensionArray = (*UUIDArray)(nil) + _ array.Builder = (*UUIDBuilder)(nil) +) diff --git a/go/arrow/extensions/uuid_test.go b/go/arrow/extensions/uuid_test.go new file mode 100644 index 0000000000000..80c621db2a0d5 --- /dev/null +++ b/go/arrow/extensions/uuid_test.go @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
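With `UUIDType` and `UUIDBuilder` now living in `arrow/extensions` and registered under the name `arrow.uuid`, typical usage looks roughly like the sketch below. It is not taken from the patch; it simply combines APIs shown in the hunks above (`NewUUIDBuilder`, `Append`/`AppendNull`, `UUIDArray.Value`):

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v18/arrow/extensions"
	"github.com/apache/arrow/go/v18/arrow/memory"
	"github.com/google/uuid"
)

func main() {
	bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator)
	defer bldr.Release()

	// Values go in as uuid.UUID (or as [16]byte via AppendBytes).
	bldr.Append(uuid.MustParse("00000000-0000-0000-0000-000000000001"))
	bldr.AppendNull()

	arr := bldr.NewArray().(*extensions.UUIDArray)
	defer arr.Release()

	fmt.Println(arr.Value(0))  // 00000000-0000-0000-0000-000000000001
	fmt.Println(arr.IsNull(1)) // true
}
```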
+ +package extensions_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testUUID = uuid.New() + +func TestUUIDExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + builder := extensions.NewUUIDBuilder(mem) + builder.Append(testUUID) + builder.AppendNull() + builder.AppendBytes(testUUID) + arr := builder.NewArray() + defer arr.Release() + arrStr := arr.String() + assert.Equal(t, fmt.Sprintf(`["%[1]s" (null) "%[1]s"]`, testUUID), arrStr) + jsonStr, err := json.Marshal(arr) + assert.NoError(t, err) + + arr1, _, err := array.FromJSON(mem, extensions.NewUUIDType(), bytes.NewReader(jsonStr)) + defer arr1.Release() + assert.NoError(t, err) + assert.True(t, array.Equal(arr1, arr)) + + require.NoError(t, json.Unmarshal(jsonStr, builder)) + arr2 := builder.NewArray() + defer arr2.Release() + assert.True(t, array.Equal(arr2, arr)) +} + +func TestUUIDExtensionRecordBuilder(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + }, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + builder.Field(0).(*extensions.UUIDBuilder).AppendNull() + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + record := builder.NewRecord() + b, err := record.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n,{\"uuid\":null}\n,{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) + record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) + require.NoError(t, err) + require.Equal(t, record, record1) +} + +func TestUUIDStringRoundTrip(t *testing.T) { + // 1. create array + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + b := extensions.NewUUIDBuilder(mem) + b.Append(uuid.Nil) + b.AppendNull() + b.Append(uuid.NameSpaceURL) + b.AppendNull() + b.Append(testUUID) + + arr := b.NewArray() + defer arr.Release() + + // 2. 
create array via AppendValueFromString + b1 := extensions.NewUUIDBuilder(mem) + defer b1.Release() + + for i := 0; i < arr.Len(); i++ { + assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + } + + arr1 := b1.NewArray() + defer arr1.Release() + + assert.True(t, array.Equal(arr, arr1)) +} + +func TestUUIDTypeBasics(t *testing.T) { + typ := extensions.NewUUIDType() + + assert.Equal(t, "arrow.uuid", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.False(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ)) + assert.True(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func TestUUIDTypeCreateFromArray(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := array.NewFixedSizeBinaryBuilder(memory.DefaultAllocator, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + defer bldr.Release() + + bldr.Append(testUUID[:]) + bldr.AppendNull() + bldr.Append(testUUID[:]) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + require.Equal(t, testUUID, uuidArr.Value(0)) + require.Equal(t, uuid.Nil, uuidArr.Value(1)) + require.Equal(t, testUUID, uuidArr.Value(2)) +} + +func TestUUIDTypeBatchIPCRoundTrip(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) +} + +func TestMarshallUUIDArray(t *testing.T) { + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + b, err := uuidArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := fmt.Sprintf(`["%[1]s",null,"%[1]s"]`, testUUID) + require.Equal(t, expectedJSON, string(b)) +} + +func TestUUIDRecordToJSON(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + uuid1 := uuid.MustParse("8c607ed4-07b2-4b9c-b5eb-c0387357f9ae") + + bldr.Append(uuid1) + bldr.AppendNull() + + // c5f2cbd9-7094-491a-b267-167bb62efe02 + bldr.AppendBytes([16]byte{197, 242, 203, 217, 112, 148, 73, 26, 178, 103, 22, 123, 182, 46, 254, 2}) + + arr := bldr.NewArray() + defer arr.Release() 
+ + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "uuid", Type: typ, Nullable: true}}, nil), []arrow.Array{uuidArr}, 3) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"uuid":"8c607ed4-07b2-4b9c-b5eb-c0387357f9ae"} + {"uuid":null} + {"uuid":"c5f2cbd9-7094-491a-b267-167bb62efe02"} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } +} diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go index 1528bb05d9daa..b9535002a0a17 100644 --- a/go/arrow/internal/flight_integration/scenario.go +++ b/go/arrow/internal/flight_integration/scenario.go @@ -40,7 +40,6 @@ import ( "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -161,9 +160,6 @@ func (s *defaultIntegrationTester) RunClient(addr string, opts ...grpc.DialOptio ctx := context.Background() - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - descr := &flight.FlightDescriptor{ Type: flight.DescriptorPATH, Path: []string{s.path}, diff --git a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go index b3e1dcac14119..c47a091268be9 100644 --- a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go +++ b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go @@ -22,12 +22,10 @@ import ( "log" "os" - "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/arrio" "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" - "github.com/apache/arrow/go/v18/internal/types" ) func main() { @@ -50,8 +48,6 @@ func main() { } func runCommand(jsonName, arrowName, mode string, verbose bool) error { - arrow.RegisterExtensionType(types.NewUUIDType()) - if jsonName == "" { return fmt.Errorf("must specify json file name") } diff --git a/go/arrow/ipc/metadata_test.go b/go/arrow/ipc/metadata_test.go index 33bc63c2a0068..14b8da2cf7cf7 100644 --- a/go/arrow/ipc/metadata_test.go +++ b/go/arrow/ipc/metadata_test.go @@ -23,10 +23,10 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/dictutils" "github.com/apache/arrow/go/v18/arrow/internal/flatbuf" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -169,7 +169,7 @@ func TestRWFooter(t *testing.T) { } func exampleUUID(mem memory.Allocator) arrow.Array { - extType := types.NewUUIDType() + extType := extensions.NewUUIDType() bldr := array.NewExtensionBuilder(mem, extType) defer 
bldr.Release() @@ -184,9 +184,6 @@ func TestUnrecognizedExtensionType(t *testing.T) { pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer pool.AssertSize(t, 0) - // register the uuid type - assert.NoError(t, arrow.RegisterExtensionType(types.NewUUIDType())) - extArr := exampleUUID(pool) defer extArr.Release() @@ -205,7 +202,9 @@ func TestUnrecognizedExtensionType(t *testing.T) { // unregister the uuid type before we read back the buffer so it is // unrecognized when reading back the record batch. - assert.NoError(t, arrow.UnregisterExtensionType("uuid")) + assert.NoError(t, arrow.UnregisterExtensionType("arrow.uuid")) + // re-register once the test is complete + defer arrow.RegisterExtensionType(extensions.NewUUIDType()) rdr, err := NewReader(&buf, WithAllocator(pool)) defer rdr.Release() diff --git a/go/internal/types/extension_types.go b/go/internal/types/extension_types.go index 85c64d86bffcb..33ada2d488f71 100644 --- a/go/internal/types/extension_types.go +++ b/go/internal/types/extension_types.go @@ -18,238 +18,15 @@ package types import ( - "bytes" "encoding/binary" "fmt" "reflect" - "strings" "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/google/uuid" "golang.org/x/xerrors" ) -var UUID = NewUUIDType() - -type UUIDBuilder struct { - *array.ExtensionBuilder -} - -func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { - return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} -} - -func (b *UUIDBuilder) Append(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) -} - -func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) -} - -func (b *UUIDBuilder) AppendValueFromString(s string) error { - if s == array.NullValueStr { - b.AppendNull() - return nil - } - - uid, err := uuid.Parse(s) - if err != nil { - return err - } - - b.Append(uid) - return nil -} - -func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("len(v) != len(valid) && len(valid) != 0") - } - - data := make([][]byte, len(v)) - for i := range v { - if len(valid) > 0 && !valid[i] { - continue - } - data[i] = v[i][:] - } - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) -} - -func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - var val uuid.UUID - switch v := t.(type) { - case string: - val, err = uuid.Parse(v) - if err != nil { - return err - } - case []byte: - val, err = uuid.ParseBytes(v) - if err != nil { - return err - } - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf([]byte{}), - Offset: dec.InputOffset(), - Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), - } - } - - b.Append(val) - return nil -} - -func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("uuid builder must unpack from json array, found %s", delim) - } - - 
return b.Unmarshal(dec) -} - -// UUIDArray is a simple array which is a FixedSizeBinary(16) -type UUIDArray struct { - array.ExtensionArrayBase -} - -func (a *UUIDArray) String() string { - arr := a.Storage().(*array.FixedSizeBinary) - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < arr.Len(); i++ { - if i > 0 { - o.WriteString(" ") - } - switch { - case a.IsNull(i): - o.WriteString(array.NullValueStr) - default: - fmt.Fprintf(o, "%q", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *UUIDArray) Value(i int) uuid.UUID { - if a.IsNull(i) { - return uuid.Nil - } - return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) -} - -func (a *UUIDArray) ValueStr(i int) string { - switch { - case a.IsNull(i): - return array.NullValueStr - default: - return a.Value(i).String() - } -} - -func (a *UUIDArray) MarshalJSON() ([]byte, error) { - arr := a.Storage().(*array.FixedSizeBinary) - values := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - if a.IsValid(i) { - values[i] = uuid.Must(uuid.FromBytes(arr.Value(i))).String() - } - } - return json.Marshal(values) -} - -func (a *UUIDArray) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.Value(i) -} - -// UUIDType is a simple extension type that represents a FixedSizeBinary(16) -// to be used for representing UUIDs -type UUIDType struct { - arrow.ExtensionBase -} - -// NewUUIDType is a convenience function to create an instance of UUIDType -// with the correct storage type -func NewUUIDType() *UUIDType { - return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} -} - -// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays -func (*UUIDType) ArrayType() reflect.Type { - return reflect.TypeOf(UUIDArray{}) -} - -func (*UUIDType) ExtensionName() string { - return "uuid" -} - -func (e *UUIDType) String() string { - return fmt.Sprintf("extension_type", e.Storage) -} - -func (e *UUIDType) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil -} - -// Serialize returns "uuid-serialized" for testing proper metadata passing -func (*UUIDType) Serialize() string { - return "uuid-serialized" -} - -// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} and the data to be -// "uuid-serialized" in order to correctly create a UUIDType for testing deserialize. -func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "uuid-serialized" { - return nil, fmt.Errorf("type identifier did not match: '%s'", data) - } - if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { - return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) - } - return NewUUIDType(), nil -} - -// ExtensionEquals returns true if both extensions have the same name -func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { - return e.ExtensionName() == other.ExtensionName() -} - -func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { - return NewUUIDBuilder(mem) -} - // Parametric1Array is a simple int32 array for use with the Parametric1Type // in testing a parameterized user-defined extension type. 
type Parametric1Array struct { @@ -518,14 +295,14 @@ func (SmallintType) ArrayType() reflect.Type { return reflect.TypeOf(SmallintArr func (SmallintType) ExtensionName() string { return "smallint" } -func (SmallintType) Serialize() string { return "smallint" } +func (SmallintType) Serialize() string { return "smallint-serialized" } func (s *SmallintType) ExtensionEquals(other arrow.ExtensionType) bool { return s.Name() == other.Name() } func (SmallintType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "smallint" { + if data != "smallint-serialized" { return nil, fmt.Errorf("type identifier did not match: '%s'", data) } if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int16) { diff --git a/go/internal/types/extension_types_test.go b/go/internal/types/extension_types_test.go deleted file mode 100644 index 65f6353d01be1..0000000000000 --- a/go/internal/types/extension_types_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types_test - -import ( - "bytes" - "testing" - - "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var testUUID = uuid.New() - -func TestUUIDExtensionBuilder(t *testing.T) { - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - builder := types.NewUUIDBuilder(mem) - builder.Append(testUUID) - arr := builder.NewArray() - defer arr.Release() - arrStr := arr.String() - assert.Equal(t, "[\""+testUUID.String()+"\"]", arrStr) - jsonStr, err := json.Marshal(arr) - assert.NoError(t, err) - - arr1, _, err := array.FromJSON(mem, types.NewUUIDType(), bytes.NewReader(jsonStr)) - defer arr1.Release() - assert.NoError(t, err) - assert.Equal(t, arr, arr1) -} - -func TestUUIDExtensionRecordBuilder(t *testing.T) { - schema := arrow.NewSchema([]arrow.Field{ - {Name: "uuid", Type: types.NewUUIDType()}, - }, nil) - builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) - builder.Field(0).(*types.UUIDBuilder).Append(testUUID) - record := builder.NewRecord() - b, err := record.MarshalJSON() - require.NoError(t, err) - require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) - record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) - require.NoError(t, err) - require.Equal(t, record, record1) -} - -func TestUUIDStringRoundTrip(t *testing.T) { - // 1. 
create array - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - - b := types.NewUUIDBuilder(mem) - b.Append(uuid.Nil) - b.AppendNull() - b.Append(uuid.NameSpaceURL) - b.AppendNull() - b.Append(testUUID) - - arr := b.NewArray() - defer arr.Release() - - // 2. create array via AppendValueFromString - b1 := types.NewUUIDBuilder(mem) - defer b1.Release() - - for i := 0; i < arr.Len(); i++ { - assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) - } - - arr1 := b1.NewArray() - defer arr1.Release() - - assert.True(t, array.Equal(arr, arr1)) -} diff --git a/go/parquet/cmd/parquet_reader/main.go b/go/parquet/cmd/parquet_reader/main.go index 6e04f4254f9fa..4e480aeb8660b 100644 --- a/go/parquet/cmd/parquet_reader/main.go +++ b/go/parquet/cmd/parquet_reader/main.go @@ -154,7 +154,7 @@ func main() { if descr.ConvertedType() != schema.ConvertedTypes.None { fmt.Printf("/%s", descr.ConvertedType()) if descr.ConvertedType() == schema.ConvertedTypes.Decimal { - dec := descr.LogicalType().(*schema.DecimalLogicalType) + dec := descr.LogicalType().(schema.DecimalLogicalType) fmt.Printf("(%d,%d)", dec.Precision(), dec.Scale()) } } diff --git a/go/parquet/metadata/app_version.go b/go/parquet/metadata/app_version.go index 887ed79343a42..345e9d440a1ca 100644 --- a/go/parquet/metadata/app_version.go +++ b/go/parquet/metadata/app_version.go @@ -164,7 +164,7 @@ func (v AppVersion) HasCorrectStatistics(coltype parquet.Type, logicalType schem // parquet-cpp-arrow version 4.0.0 fixed Decimal comparisons for creating min/max stats // parquet-cpp also becomes parquet-cpp-arrow as of version 4.0.0 if v.App == "parquet-cpp" || (v.App == "parquet-cpp-arrow" && v.LessThan(parquet1655FixedVersion)) { - if _, ok := logicalType.(*schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { + if _, ok := logicalType.(schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { return false } } diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 16282173a685c..a238a78133e55 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/bitutil" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/types" @@ -715,16 +716,6 @@ type ParquetIOTestSuite struct { suite.Suite } -func (ps *ParquetIOTestSuite) SetupTest() { - ps.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (ps *ParquetIOTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.Repetition) *schema.GroupNode { byteWidth := int32(-1) @@ -2053,7 +2044,7 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(ps.T(), 0) - builder := types.NewUUIDBuilder(mem) + builder := extensions.NewUUIDBuilder(mem) builder.Append(uuid.New()) arr := builder.NewArray() defer arr.Release() @@ -2076,22 +2067,23 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `written` table with the extension type registered. 
- extType := types.NewUUIDType() + extType := types.NewSmallintType() bldr := array.NewExtensionBuilder(mem, extType) defer bldr.Release() - bldr.Builder.(*array.FixedSizeBinaryBuilder).AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + bldr.Builder.(*array.Int16Builder).AppendValues( + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) + if arrow.GetExtensionType("smallint") != nil { + ps.NoError(arrow.UnregisterExtensionType("smallint")) + defer arrow.RegisterExtensionType(extType) } - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked written = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2101,16 +2093,16 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `expected` table with the extension type unregistered in the underlying type. - bldr := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + bldr := array.NewInt16Builder(mem) defer bldr.Release() bldr.AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked expected = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2147,13 +2139,55 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { ps.Truef(array.Equal(exc, tbc), "expected: %T %s\ngot: %T %s", exc, exc, tbc, tbc) expectedMd := arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: "uuid", - ipc.ExtensionMetadataKeyName: "uuid-serialized", + ipc.ExtensionTypeKeyName: "smallint", + ipc.ExtensionMetadataKeyName: "smallint-serialized", "PARQUET:field_id": "-1", }) ps.Truef(expectedMd.Equal(tbl.Column(0).Field().Metadata), "expected: %v\ngot: %v", expectedMd, tbl.Column(0).Field().Metadata) } +func (ps *ParquetIOTestSuite) TestArrowExtensionTypeLogicalType() { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(ps.T(), 0) + + jsonType, err := extensions.NewJSONType(arrow.BinaryTypes.String) + ps.NoError(err) + + sch := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + {Name: "json", Type: jsonType}, + }, + nil, + ) + bldr := array.NewRecordBuilder(mem, sch) + defer bldr.Release() + + bldr.Field(0).(*extensions.UUIDBuilder).Append(uuid.New()) + bldr.Field(1).(*array.ExtensionBuilder).AppendValueFromString(`{"hello": ["world", 2, true], "world": null}`) + rec := bldr.NewRecord() + defer rec.Release() + + var buf bytes.Buffer + wr, err := pqarrow.NewFileWriter( + sch, + &buf, + parquet.NewWriterProperties(), + pqarrow.DefaultWriterProps(), + ) + ps.Require().NoError(err) + + ps.Require().NoError(wr.Write(rec)) + ps.Require().NoError(wr.Close()) + + rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes())) + 
ps.Require().NoError(err) + defer rdr.Close() + + pqSchema := rdr.MetaData().Schema + ps.True(pqSchema.Column(0).LogicalType().Equals(schema.UUIDLogicalType{})) + ps.True(pqSchema.Column(1).LogicalType().Equals(schema.JSONLogicalType{})) +} + func TestWriteTableMemoryAllocation(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) sc := arrow.NewSchema([]arrow.Field{ @@ -2163,7 +2197,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true})}, {Name: "arr_i64", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)}, - {Name: "uuid", Type: types.NewUUIDType(), Nullable: true}, + {Name: "uuid", Type: extensions.NewUUIDType(), Nullable: true}, }, nil) bld := array.NewRecordBuilder(mem, sc) @@ -2176,7 +2210,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { abld := bld.Field(3).(*array.ListBuilder) abld.Append(true) abld.ValueBuilder().(*array.Int64Builder).Append(2) - bld.Field(4).(*types.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) + bld.Field(4).(*extensions.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) rec := bld.NewRecord() bld.Release() diff --git a/go/parquet/pqarrow/path_builder_test.go b/go/parquet/pqarrow/path_builder_test.go index 9bbae426b8a46..364f836d0bbca 100644 --- a/go/parquet/pqarrow/path_builder_test.go +++ b/go/parquet/pqarrow/path_builder_test.go @@ -22,8 +22,8 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -364,12 +364,12 @@ func TestNestedExtensionListsWithSomeNulls(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - listType := arrow.ListOf(types.NewUUIDType()) + listType := arrow.ListOf(extensions.NewUUIDType()) bldr := array.NewListBuilder(mem, listType) defer bldr.Release() nestedBldr := bldr.ValueBuilder().(*array.ListBuilder) - vb := nestedBldr.ValueBuilder().(*types.UUIDBuilder) + vb := nestedBldr.ValueBuilder().(*extensions.UUIDBuilder) uuid1 := uuid.New() uuid3 := uuid.New() diff --git a/go/parquet/pqarrow/schema.go b/go/parquet/pqarrow/schema.go index ce5cc6f905084..4882077671f0f 100644 --- a/go/parquet/pqarrow/schema.go +++ b/go/parquet/pqarrow/schema.go @@ -25,7 +25,6 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/flight" - "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/file" @@ -120,6 +119,15 @@ func (sm *SchemaManifest) GetFieldIndices(indices []int) ([]int, error) { return ret, nil } +// ExtensionCustomParquetType is an interface that Arrow ExtensionTypes may implement +// to specify the target LogicalType to use when converting to Parquet. +// +// The PrimitiveType is not configurable, and is determined by a fixed mapping from +// the extension's StorageType to a Parquet type (see getParquetType in pqarrow source). 
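The pqarrow hunk just below declares the `ExtensionCustomParquetType` interface that the comment above describes. As a hedged illustration of how a user-defined extension type might opt in, here is a sketch modeled on the `UUIDType`/`JSONType` implementations in this patch; all names are invented, and only the interface shape, the `arrow.ExtensionBase` embedding, and the `schema` logical types come from the diff:

```go
package myext

import (
	"fmt"
	"reflect"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/arrow/array"
	"github.com/apache/arrow/go/v18/parquet/schema"
)

// RatingArray and RatingType are hypothetical names used for illustration only.
type RatingArray struct {
	array.ExtensionArrayBase
}

// RatingType stores an int8 rating; by implementing ParquetLogicalType it asks
// pqarrow to annotate the Parquet column as an 8-bit signed integer rather than
// falling back to the storage type's default annotation.
type RatingType struct {
	arrow.ExtensionBase
}

func NewRatingType() *RatingType {
	return &RatingType{ExtensionBase: arrow.ExtensionBase{Storage: arrow.PrimitiveTypes.Int8}}
}

func (*RatingType) ArrayType() reflect.Type { return reflect.TypeOf(RatingArray{}) }
func (*RatingType) ExtensionName() string   { return "myorg.rating" }
func (*RatingType) Serialize() string       { return "" }

func (t *RatingType) String() string {
	return fmt.Sprintf("extension<%s>", t.ExtensionName())
}

func (t *RatingType) ExtensionEquals(other arrow.ExtensionType) bool {
	return t.ExtensionName() == other.ExtensionName()
}

func (*RatingType) Deserialize(storageType arrow.DataType, _ string) (arrow.ExtensionType, error) {
	if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int8) {
		return nil, fmt.Errorf("invalid storage type for RatingType: %s", storageType.Name())
	}
	return NewRatingType(), nil
}

// ParquetLogicalType implements the ExtensionCustomParquetType hook.
func (*RatingType) ParquetLogicalType() schema.LogicalType {
	return schema.NewIntLogicalType(8, true)
}
```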
+type ExtensionCustomParquetType interface { + ParquetLogicalType() schema.LogicalType +} + func isDictionaryReadSupported(dt arrow.DataType) bool { return arrow.IsBinaryLike(dt.ID()) } @@ -250,104 +258,14 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq } func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { - var ( - logicalType schema.LogicalType = schema.NoLogicalType{} - typ parquet.Type - repType = repFromNullable(field.Nullable) - length = -1 - precision = -1 - scale = -1 - err error - ) + repType := repFromNullable(field.Nullable) + // Handle complex types i.e. GroupNodes switch field.Type.ID() { case arrow.NULL: - typ = parquet.Types.Int32 - logicalType = &schema.NullLogicalType{} if repType != parquet.Repetitions.Optional { return nil, xerrors.New("nulltype arrow field must be nullable") } - case arrow.BOOL: - typ = parquet.Types.Boolean - case arrow.UINT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, false) - case arrow.INT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, true) - case arrow.UINT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, false) - case arrow.INT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, true) - case arrow.UINT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, false) - case arrow.INT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, true) - case arrow.UINT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, false) - case arrow.INT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, true) - case arrow.FLOAT32: - typ = parquet.Types.Float - case arrow.FLOAT64: - typ = parquet.Types.Double - case arrow.STRING, arrow.LARGE_STRING: - logicalType = schema.StringLogicalType{} - fallthrough - case arrow.BINARY, arrow.LARGE_BINARY: - typ = parquet.Types.ByteArray - case arrow.FIXED_SIZE_BINARY: - typ = parquet.Types.FixedLenByteArray - length = field.Type.(*arrow.FixedSizeBinaryType).ByteWidth - case arrow.DECIMAL, arrow.DECIMAL256: - dectype := field.Type.(arrow.DecimalType) - precision = int(dectype.GetPrecision()) - scale = int(dectype.GetScale()) - - if props.StoreDecimalAsInteger() && 1 <= precision && precision <= 18 { - if precision <= 9 { - typ = parquet.Types.Int32 - } else { - typ = parquet.Types.Int64 - } - } else { - typ = parquet.Types.FixedLenByteArray - length = int(DecimalSize(int32(precision))) - } - - logicalType = schema.NewDecimalLogicalType(int32(precision), int32(scale)) - case arrow.DATE32: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.DATE64: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.TIMESTAMP: - typ, logicalType, err = getTimestampMeta(field.Type.(*arrow.TimestampType), props, arrprops) - if err != nil { - return nil, err - } - case arrow.TIME32: - typ = parquet.Types.Int32 - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMillis) - case arrow.TIME64: - typ = parquet.Types.Int64 - timeType := field.Type.(*arrow.Time64Type) - if timeType.Unit == arrow.Nanosecond { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitNanos) - } else { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMicros) - } - case arrow.FLOAT16: - typ = parquet.Types.FixedLenByteArray - length = arrow.Float16SizeBytes - logicalType = 
schema.Float16LogicalType{} case arrow.STRUCT: return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops) case arrow.FIXED_SIZE_LIST, arrow.LIST: @@ -369,16 +287,6 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties dictType := field.Type.(*arrow.DictionaryType) return fieldToNode(name, arrow.Field{Name: name, Type: dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata}, props, arrprops) - case arrow.EXTENSION: - return fieldToNode(name, arrow.Field{ - Name: name, - Type: field.Type.(arrow.ExtensionType).StorageType(), - Nullable: field.Nullable, - Metadata: arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: field.Type.(arrow.ExtensionType).ExtensionName(), - ipc.ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize(), - }), - }, props, arrprops) case arrow.MAP: mapType := field.Type.(*arrow.MapType) keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops) @@ -402,8 +310,12 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties }, -1) } return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1) - default: - return nil, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, field.Type.ID()) + } + + // Not a GroupNode + typ, logicalType, length, err := getParquetType(field.Type, props, arrprops) + if err != nil { + return nil, err } return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata)) @@ -472,7 +384,7 @@ func (s schemaTree) RecordLeaf(leaf *SchemaField) { s.manifest.ColIndexToField[leaf.ColIndex] = leaf } -func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { +func arrowInt(log schema.IntLogicalType) (arrow.DataType, error) { switch log.BitWidth() { case 8: if log.IsSigned() { @@ -499,7 +411,7 @@ func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { } } -func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { +func arrowTime32(logical schema.TimeLogicalType) (arrow.DataType, error) { if logical.TimeUnit() == schema.TimeUnitMillis { return arrow.FixedWidthTypes.Time32ms, nil } @@ -507,7 +419,7 @@ func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { return nil, xerrors.New(logical.String() + " cannot annotate a time32") } -func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) { +func arrowTime64(logical schema.TimeLogicalType) (arrow.DataType, error) { switch logical.TimeUnit() { case schema.TimeUnitMicros: return arrow.FixedWidthTypes.Time64us, nil @@ -518,7 +430,7 @@ func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) { } } -func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error) { +func arrowTimestamp(logical schema.TimestampLogicalType) (arrow.DataType, error) { tz := "" // ConvertedTypes are adjusted to UTC per backward compatibility guidelines @@ -539,7 +451,7 @@ func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error } } -func arrowDecimal(logical *schema.DecimalLogicalType) arrow.DataType { +func arrowDecimal(logical schema.DecimalLogicalType) arrow.DataType { if logical.Precision() <= decimal128.MaxPrecision { return &arrow.Decimal128Type{Precision: logical.Precision(), Scale: logical.Scale()} } @@ -550,11 +462,11 @@ func arrowFromInt32(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.NoLogicalType: return 
arrow.PrimitiveTypes.Int32, nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime32(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) case schema.DateLogicalType: return arrow.FixedWidthTypes.Date32, nil @@ -569,13 +481,13 @@ func arrowFromInt64(logical schema.LogicalType) (arrow.DataType, error) { } switch logtype := logical.(type) { - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime64(logtype) - case *schema.TimestampLogicalType: + case schema.TimestampLogicalType: return arrowTimestamp(logtype) default: return nil, xerrors.New(logical.String() + " cannot annotate int64") @@ -586,7 +498,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.StringLogicalType: return arrow.BinaryTypes.String, nil - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.EnumLogicalType, @@ -600,7 +512,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, error) { switch logtype := logical.(type) { - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType: return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil @@ -611,6 +523,84 @@ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, erro } } +func getParquetType(typ arrow.DataType, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (parquet.Type, schema.LogicalType, int, error) { + switch typ.ID() { + case arrow.NULL: + return parquet.Types.Int32, schema.NullLogicalType{}, -1, nil + case arrow.BOOL: + return parquet.Types.Boolean, schema.NoLogicalType{}, -1, nil + case arrow.UINT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, false), -1, nil + case arrow.INT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, true), -1, nil + case arrow.UINT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, false), -1, nil + case arrow.INT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, true), -1, nil + case arrow.UINT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, false), -1, nil + case arrow.INT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, true), -1, nil + case arrow.UINT64: + return parquet.Types.Int64, schema.NewIntLogicalType(64, false), -1, nil + case arrow.INT64: + return parquet.Types.Int64, schema.NewIntLogicalType(64, true), -1, nil + case arrow.FLOAT32: + return parquet.Types.Float, schema.NoLogicalType{}, -1, nil + case arrow.FLOAT64: + return parquet.Types.Double, schema.NoLogicalType{}, -1, nil + case arrow.STRING, arrow.LARGE_STRING: + return parquet.Types.ByteArray, schema.StringLogicalType{}, -1, nil + case arrow.BINARY, arrow.LARGE_BINARY: + return parquet.Types.ByteArray, schema.NoLogicalType{}, -1, nil + case arrow.FIXED_SIZE_BINARY: + return parquet.Types.FixedLenByteArray, schema.NoLogicalType{}, typ.(*arrow.FixedSizeBinaryType).ByteWidth, nil + case arrow.DECIMAL, 
arrow.DECIMAL256: + dectype := typ.(arrow.DecimalType) + precision := int(dectype.GetPrecision()) + scale := int(dectype.GetScale()) + + if !props.StoreDecimalAsInteger() || precision > 18 { + return parquet.Types.FixedLenByteArray, schema.NewDecimalLogicalType(int32(precision), int32(scale)), int(DecimalSize(int32(precision))), nil + } + + pqType := parquet.Types.Int32 + if precision > 9 { + pqType = parquet.Types.Int64 + } + + return pqType, schema.NoLogicalType{}, -1, nil + case arrow.DATE32: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.DATE64: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.TIMESTAMP: + pqType, logicalType, err := getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops) + return pqType, logicalType, -1, err + case arrow.TIME32: + return parquet.Types.Int32, schema.NewTimeLogicalType(true, schema.TimeUnitMillis), -1, nil + case arrow.TIME64: + pqTimeUnit := schema.TimeUnitMicros + if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond { + pqTimeUnit = schema.TimeUnitNanos + } + + return parquet.Types.Int64, schema.NewTimeLogicalType(true, pqTimeUnit), -1, nil + case arrow.FLOAT16: + return parquet.Types.FixedLenByteArray, schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil + case arrow.EXTENSION: + storageType := typ.(arrow.ExtensionType).StorageType() + pqType, logicalType, length, err := getParquetType(storageType, props, arrprops) + if withCustomType, ok := typ.(ExtensionCustomParquetType); ok { + logicalType = withCustomType.ParquetLogicalType() + } + + return pqType, logicalType, length, err + default: + return parquet.Type(0), nil, 0, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, typ.ID()) + } +} + func getArrowType(physical parquet.Type, logical schema.LogicalType, typeLen int) (arrow.DataType, error) { if !logical.IsValid() || logical.Equals(schema.NullLogicalType{}) { return arrow.Null, nil diff --git a/go/parquet/pqarrow/schema_test.go b/go/parquet/pqarrow/schema_test.go index 24b031c174bf2..528200fd0e7d9 100644 --- a/go/parquet/pqarrow/schema_test.go +++ b/go/parquet/pqarrow/schema_test.go @@ -21,10 +21,10 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/flight" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/metadata" "github.com/apache/arrow/go/v18/parquet/pqarrow" @@ -34,7 +34,7 @@ import ( ) func TestGetOriginSchemaBase64(t *testing.T) { - uuidType := types.NewUUIDType() + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) extMd := arrow.NewMetadata([]string{ipc.ExtensionMetadataKeyName, ipc.ExtensionTypeKeyName, "PARQUET:field_id"}, []string{uuidType.Serialize(), uuidType.ExtensionName(), "-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ @@ -44,10 +44,6 @@ func TestGetOriginSchemaBase64(t *testing.T) { }, nil) arrSerializedSc := flight.SerializeSchema(origArrSc, memory.DefaultAllocator) - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - defer arrow.UnregisterExtensionType(uuidType.ExtensionName()) pqschema, err := pqarrow.ToParquet(origArrSc, nil, pqarrow.DefaultWriterProps()) require.NoError(t, err) @@ -71,11 +67,7 @@ func TestGetOriginSchemaBase64(t *testing.T) { } func TestGetOriginSchemaUnregisteredExtension(t 
*testing.T) { - uuidType := types.NewUUIDType() - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ {Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md}, @@ -90,6 +82,7 @@ func TestGetOriginSchemaUnregisteredExtension(t *testing.T) { kv.Append("ARROW:schema", base64.StdEncoding.EncodeToString(arrSerializedSc)) arrow.UnregisterExtensionType(uuidType.ExtensionName()) + defer arrow.RegisterExtensionType(uuidType) arrsc, err := pqarrow.FromParquet(pqschema, nil, kv) require.NoError(t, err) diff --git a/go/parquet/schema/converted_types.go b/go/parquet/schema/converted_types.go index 5fc10f61cebc1..b2b6f50cbf682 100644 --- a/go/parquet/schema/converted_types.go +++ b/go/parquet/schema/converted_types.go @@ -113,13 +113,9 @@ func (p ConvertedType) ToLogicalType(convertedDecimal DecimalMetadata) LogicalTy case ConvertedTypes.TimeMicros: return NewTimeLogicalType(true /* adjustedToUTC */, TimeUnitMicros) case ConvertedTypes.TimestampMillis: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMillis) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMillis), WithTSFromConverted()) case ConvertedTypes.TimestampMicros: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMicros) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMicros), WithTSFromConverted()) case ConvertedTypes.Interval: return IntervalLogicalType{} case ConvertedTypes.Int8: diff --git a/go/parquet/schema/logical_types.go b/go/parquet/schema/logical_types.go index e8adce1ca140e..fa46ea0172f76 100644 --- a/go/parquet/schema/logical_types.go +++ b/go/parquet/schema/logical_types.go @@ -45,21 +45,21 @@ func getLogicalType(l *format.LogicalType) LogicalType { case l.IsSetENUM(): return EnumLogicalType{} case l.IsSetDECIMAL(): - return &DecimalLogicalType{typ: l.DECIMAL} + return DecimalLogicalType{typ: l.DECIMAL} case l.IsSetDATE(): return DateLogicalType{} case l.IsSetTIME(): if timeUnitFromThrift(l.TIME.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Time logical type") } - return &TimeLogicalType{typ: l.TIME} + return TimeLogicalType{typ: l.TIME} case l.IsSetTIMESTAMP(): if timeUnitFromThrift(l.TIMESTAMP.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Timestamp logical type") } - return &TimestampLogicalType{typ: l.TIMESTAMP} + return TimestampLogicalType{typ: l.TIMESTAMP} case l.IsSetINTEGER(): - return &IntLogicalType{typ: l.INTEGER} + return IntLogicalType{typ: l.INTEGER} case l.IsSetUNKNOWN(): return NullLogicalType{} case l.IsSetJSON(): @@ -344,7 +344,7 @@ func NewDecimalLogicalType(precision int32, scale int32) LogicalType { if scale < 0 || scale > precision { panic("parquet: scale must be a non-negative integer that does not exceed precision for decimal logical type") } - return &DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} + return DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} } // DecimalLogicalType is used to represent a decimal value of a given @@ -405,7 +405,7 @@ func (t DecimalLogicalType) toThrift() *format.LogicalType { } func (t 
DecimalLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*DecimalLogicalType) + other, ok := rhs.(DecimalLogicalType) if !ok { return false } @@ -509,7 +509,7 @@ func createTimeUnit(unit TimeUnitType) *format.TimeUnit { // NewTimeLogicalType returns a time type of the given unit. func NewTimeLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimeLogicalType{typ: &format.TimeType{ + return TimeLogicalType{typ: &format.TimeType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), }} @@ -584,7 +584,7 @@ func (t TimeLogicalType) toThrift() *format.LogicalType { } func (t TimeLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*TimeLogicalType) + other, ok := rhs.(TimeLogicalType) if !ok { return false } @@ -595,7 +595,7 @@ func (t TimeLogicalType) Equals(rhs LogicalType) bool { // NewTimestampLogicalType returns a logical timestamp type with "forceConverted" // set to false func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -608,7 +608,7 @@ func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalTyp // NewTimestampLogicalTypeForce returns a timestamp logical type with // "forceConverted" set to true func NewTimestampLogicalTypeForce(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -654,14 +654,14 @@ func WithTSFromConverted() TimestampOpt { // // TimestampType Unit defaults to milliseconds (TimeUnitMillis) func NewTimestampLogicalTypeWithOpts(opts ...TimestampOpt) LogicalType { - ts := &TimestampLogicalType{ + ts := TimestampLogicalType{ typ: &format.TimestampType{ Unit: createTimeUnit(TimeUnitMillis), // default to milliseconds }, } for _, o := range opts { - o(ts) + o(&ts) } return ts @@ -760,7 +760,7 @@ func (t TimestampLogicalType) toThrift() *format.LogicalType { } func (t TimestampLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*TimestampLogicalType) + other, ok := rhs.(TimestampLogicalType) if !ok { return false } @@ -778,7 +778,7 @@ func NewIntLogicalType(bitWidth int8, signed bool) LogicalType { default: panic("parquet: bit width must be exactly 8, 16, 32, or 64 for Int logical type") } - return &IntLogicalType{ + return IntLogicalType{ typ: &format.IntType{ BitWidth: bitWidth, IsSigned: signed, @@ -864,7 +864,7 @@ func (t IntLogicalType) toThrift() *format.LogicalType { } func (t IntLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*IntLogicalType) + other, ok := rhs.(IntLogicalType) if !ok { return false } diff --git a/go/parquet/schema/logical_types_test.go b/go/parquet/schema/logical_types_test.go index e33925966e178..395d1504182fe 100644 --- a/go/parquet/schema/logical_types_test.go +++ b/go/parquet/schema/logical_types_test.go @@ -38,18 +38,18 @@ func TestConvertedLogicalEquivalences(t *testing.T) { {"list", schema.ConvertedTypes.List, schema.NewListLogicalType(), schema.NewListLogicalType()}, {"enum", schema.ConvertedTypes.Enum, schema.EnumLogicalType{}, schema.EnumLogicalType{}}, {"date", schema.ConvertedTypes.Date, schema.DateLogicalType{}, schema.DateLogicalType{}}, - {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - 
{"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimestampLogicalType{}}, - {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimestampLogicalType{}}, - {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, + {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimestampLogicalType{}}, + {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimestampLogicalType{}}, + {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, {"json", schema.ConvertedTypes.JSON, schema.JSONLogicalType{}, schema.JSONLogicalType{}}, {"bson", schema.ConvertedTypes.BSON, schema.BSONLogicalType{}, schema.BSONLogicalType{}}, {"interval", schema.ConvertedTypes.Interval, schema.IntervalLogicalType{}, schema.IntervalLogicalType{}}, 
@@ -72,8 +72,8 @@ func TestConvertedLogicalEquivalences(t *testing.T) { fromMake := schema.NewDecimalLogicalType(10, 4) assert.IsType(t, fromMake, fromConverted) assert.True(t, fromConverted.Equals(fromMake)) - assert.IsType(t, &schema.DecimalLogicalType{}, fromConverted) - assert.IsType(t, &schema.DecimalLogicalType{}, fromMake) + assert.IsType(t, schema.DecimalLogicalType{}, fromConverted) + assert.IsType(t, schema.DecimalLogicalType{}, fromMake) assert.True(t, schema.NewDecimalLogicalType(16, 0).Equals(schema.NewDecimalLogicalType(16, 0))) }) } @@ -160,12 +160,12 @@ func TestNewTypeIncompatibility(t *testing.T) { {"uuid", schema.UUIDLogicalType{}, schema.UUIDLogicalType{}}, {"float16", schema.Float16LogicalType{}, schema.Float16LogicalType{}}, {"null", schema.NullLogicalType{}, schema.NullLogicalType{}}, - {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, - {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, + {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, + {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, } for _, tt := range tests { diff --git a/go/parquet/schema/schema_element_test.go b/go/parquet/schema/schema_element_test.go index 7da55ce93abe6..e427ba6485e64 100644 --- a/go/parquet/schema/schema_element_test.go +++ b/go/parquet/schema/schema_element_test.go @@ -192,7 +192,7 @@ func (s *SchemaElementConstructionSuite) TestSimple() { func (s *SchemaElementConstructionSuite) reconstructDecimal(c schemaElementConstructArgs) *decimalSchemaElementConstruction { ret := s.reconstruct(c) - dec := c.logical.(*DecimalLogicalType) + dec := c.logical.(DecimalLogicalType) return &decimalSchemaElementConstruction{*ret, int(dec.Precision()), int(dec.Scale())} } @@ -359,7 +359,7 @@ func (s *SchemaElementConstructionSuite) TestTemporal() { func (s *SchemaElementConstructionSuite) reconstructInteger(c schemaElementConstructArgs) *intSchemaElementConstruction { base := s.reconstruct(c) - l := c.logical.(*IntLogicalType) + l := c.logical.(IntLogicalType) return &intSchemaElementConstruction{ *base, l.BitWidth(), From 82ecf3e6ed8cb58a08d600041617ce85c9bdb7c1 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 22 Aug 2024 22:57:14 +0200 Subject: [PATCH 
20/32] MINOR: [CI][C++][Python] Fix Cuda builds on git main (#43789) On the Cuda self-hosted runners, we need to use legacy `docker-compose` on all Archery Docker invocations, including the "image push" step. This is because the Docker client version on those runners is too old to accept the `--file` option to the `compose` subcommand. This is a followup to https://github.com/apache/arrow/pull/43586 . The image push step cannot easily be verified in a PR, hence this second PR. Authored-by: Antoine Pitrou Signed-off-by: Sutou Kouhei --- dev/tasks/docker-tests/github.cuda.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/tasks/docker-tests/github.cuda.yml b/dev/tasks/docker-tests/github.cuda.yml index 9c7adf53a6f70..8c04da8a91a4f 100644 --- a/dev/tasks/docker-tests/github.cuda.yml +++ b/dev/tasks/docker-tests/github.cuda.yml @@ -26,6 +26,8 @@ jobs: runs-on: ['self-hosted', 'cuda'] {{ macros.github_set_env(env) }} timeout-minutes: {{ timeout|default(60) }} + env: + ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 steps: {{ macros.github_checkout_arrow(fetch_depth=fetch_depth|default(1))|indent }} # python 3.8 is installed on the runner, no need to install @@ -34,7 +36,6 @@ jobs: - name: Execute Docker Build shell: bash env: - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 {{ macros.github_set_sccache_envvars()|indent(8) }} run: | archery docker run \ From bad064f705ec9fc72efac2d13a1fc3fac6d3d137 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 22 Aug 2024 14:08:26 -0700 Subject: [PATCH 21/32] MINOR: [C++] Ensure setting the default CMAKE_BUILD_TYPE (#43794) ### Rationale for this change The current logic for detecting whether the `CMAKE_BUILD_TYPE` is set is incorrect. That variable is never fully undefined; by default, in cases where it is unset is actually set to the empty string. Therefore, the condition that must be checked is not whether the variable is defined, but whether it tests to a truthy value (i.e. is a non-empty string). I consider this a minor change so I have not opened an associated issue. ### What changes are included in this PR? This PR changes `if(NOT DEFINED CMAKE_BUILD_TYPE)` to `if(NOT CMAKE_BUILD_TYPE)`. ### Are these changes tested? Since this fixes a particular CMake build scenario I am not sure if a test is merited, or where one would be added. ### Are there any user-facing changes? No. 
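As a minimal illustration of the behaviour described above (this fragment is not part of the patch; it only restates the check that the diff below switches to):

```cmake
# CMAKE_BUILD_TYPE is typically defined but empty when the user passes nothing,
# so `if(NOT DEFINED CMAKE_BUILD_TYPE)` never falls through to the default.
# Testing the value itself treats the empty string as "not set".
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE
      Release
      CACHE STRING "Choose the type of build.")
endif()
message(STATUS "Building with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
```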
Authored-by: Vyas Ramasubramani Signed-off-by: Sutou Kouhei --- cpp/CMakeLists.txt | 2 +- cpp/examples/minimal_build/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a1e3138da9e0b..5ead9e4b063cd 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -84,7 +84,7 @@ set(ARROW_VERSION "18.0.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}") # if no build type is specified, default to release builds -if(NOT DEFINED CMAKE_BUILD_TYPE) +if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.") diff --git a/cpp/examples/minimal_build/CMakeLists.txt b/cpp/examples/minimal_build/CMakeLists.txt index b4a7cde938c87..95dad34221add 100644 --- a/cpp/examples/minimal_build/CMakeLists.txt +++ b/cpp/examples/minimal_build/CMakeLists.txt @@ -30,7 +30,7 @@ endif() # We require a C++17 compliant compiler set(CMAKE_CXX_STANDARD_REQUIRED ON) -if(NOT DEFINED CMAKE_BUILD_TYPE) +if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() From 53b15b61691dde1ea86e14b7a2216fa0a26f8054 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:17:29 -0400 Subject: [PATCH 22/32] MINOR: [Go] Fix Flakey TestRowsPrematureCloseDuringNextLoop Test (#43804) ### Rationale for this change Fixes a race condition in rows initialization that has been causing intermittent test failures. ### What changes are included in this PR? Split query and init context. Update test to check for failure _after_ reading rows. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/arrow/flight/flightsql/driver/driver.go | 10 ++++++---- go/arrow/flight/flightsql/driver/driver_test.go | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/go/arrow/flight/flightsql/driver/driver.go b/go/arrow/flight/flightsql/driver/driver.go index 0f2b02deaca7c..0513fe1ecd346 100644 --- a/go/arrow/flight/flightsql/driver/driver.go +++ b/go/arrow/flight/flightsql/driver/driver.go @@ -266,13 +266,14 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, err } + execCtx := ctx if _, set := ctx.Deadline(); !set && s.timeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, s.timeout) + execCtx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() } - info, err := s.stmt.Execute(ctx) + info, err := s.stmt.Execute(execCtx) if err != nil { return nil, err } @@ -497,13 +498,14 @@ func (c *Connection) QueryContext(ctx context.Context, query string, args []driv return nil, driver.ErrSkip } + execCtx := ctx if _, set := ctx.Deadline(); !set && c.timeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, c.timeout) + execCtx, cancel = context.WithTimeout(ctx, c.timeout) defer cancel() } - info, err := c.client.Execute(ctx, query) + info, err := c.client.Execute(execCtx, query) if err != nil { return nil, err } diff --git a/go/arrow/flight/flightsql/driver/driver_test.go b/go/arrow/flight/flightsql/driver/driver_test.go index e5060ccbe33d0..c00dfe3c5d9a0 100644 --- a/go/arrow/flight/flightsql/driver/driver_test.go +++ b/go/arrow/flight/flightsql/driver/driver_test.go @@ -626,7 +626,6 @@ func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoop() { rows, err := db.QueryContext(context.TODO(), sqlSelectAll) require.NoError(t, err) 
require.NotNil(t, rows) - require.NoError(t, rows.Err()) const closeAfterNRows = 10 var ( @@ -645,6 +644,7 @@ func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoop() { require.NoError(t, rows.Close()) } } + require.NoError(t, rows.Err()) require.Equal(t, closeAfterNRows, i) From cb645a1b27dd66fddb88458c939e2851f9dadf35 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 24 Aug 2024 06:08:18 +0900 Subject: [PATCH 23/32] GH-43802: [GLib] Add `GAFlightRecordBatchWriter` (#43803) ### Rationale for this change This is needed to implement `DoPut`. ### What changes are included in this PR? We can't add tests for it because it's an abstract class. I'm not sure `is_owner` is needed like `GAFlightRecordBatchReader`. `is_owner` may be removed later if we find that it's needless. ### Are these changes tested? No. ### Are there any user-facing changes? Yes. `GAFlightRecordBatchWriter` is a new public API. * GitHub Issue: #43802 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-flight-glib/common.cpp | 198 ++++++++++++++++++++++++++-- c_glib/arrow-flight-glib/common.h | 32 +++++ c_glib/arrow-flight-glib/common.hpp | 4 + 3 files changed, 224 insertions(+), 10 deletions(-) diff --git a/c_glib/arrow-flight-glib/common.cpp b/c_glib/arrow-flight-glib/common.cpp index efc544f10cf66..f7eea08c264b3 100644 --- a/c_glib/arrow-flight-glib/common.cpp +++ b/c_glib/arrow-flight-glib/common.cpp @@ -48,7 +48,11 @@ G_BEGIN_DECLS * * #GAFlightStreamChunk is a class for a chunk in stream. * - * #GAFlightRecordBatchReader is a class for reading record batches. + * #GAFlightRecordBatchReader is an abstract class for reading record + * batches with metadata. + * + * #GAFlightRecordBatchWeriter is an abstract class for + * writing record batches with metadata. 
* * Since: 5.0.0 */ @@ -1172,13 +1176,13 @@ typedef struct GAFlightRecordBatchReaderPrivate_ } GAFlightRecordBatchReaderPrivate; enum { - PROP_READER = 1, - PROP_IS_OWNER, + PROP_RECORD_BATCH_READER_READER = 1, + PROP_RECORD_BATCH_READER_IS_OWNER, }; -G_DEFINE_TYPE_WITH_PRIVATE(GAFlightRecordBatchReader, - gaflight_record_batch_reader, - G_TYPE_OBJECT) +G_DEFINE_ABSTRACT_TYPE_WITH_PRIVATE(GAFlightRecordBatchReader, + gaflight_record_batch_reader, + G_TYPE_OBJECT) #define GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(obj) \ static_cast( \ @@ -1204,11 +1208,11 @@ gaflight_record_batch_reader_set_property(GObject *object, auto priv = GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(object); switch (prop_id) { - case PROP_READER: + case PROP_RECORD_BATCH_READER_READER: priv->reader = static_cast(g_value_get_pointer(value)); break; - case PROP_IS_OWNER: + case PROP_RECORD_BATCH_READER_IS_OWNER: priv->is_owner = g_value_get_boolean(value); break; default: @@ -1236,7 +1240,7 @@ gaflight_record_batch_reader_class_init(GAFlightRecordBatchReaderClass *klass) nullptr, nullptr, static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_READER, spec); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_READER_READER, spec); spec = g_param_spec_boolean( "is-owner", @@ -1244,7 +1248,7 @@ gaflight_record_batch_reader_class_init(GAFlightRecordBatchReaderClass *klass) nullptr, TRUE, static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_IS_OWNER, spec); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_READER_IS_OWNER, spec); } /** @@ -1296,6 +1300,173 @@ gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, GError } } +typedef struct GAFlightRecordBatchWriterPrivate_ +{ + arrow::flight::MetadataRecordBatchWriter *writer; + bool is_owner; +} GAFlightRecordBatchWriterPrivate; + +enum { + PROP_RECORD_BATCH_WRITER_WRITER = 1, + PROP_RECORD_BATCH_WRITER_IS_OWNER, +}; + +G_DEFINE_ABSTRACT_TYPE_WITH_PRIVATE(GAFlightRecordBatchWriter, + gaflight_record_batch_writer, + GARROW_TYPE_RECORD_BATCH_WRITER) + +#define GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object) \ + static_cast( \ + gaflight_record_batch_writer_get_instance_private( \ + GAFLIGHT_RECORD_BATCH_WRITER(object))) + +static void +gaflight_record_batch_writer_finalize(GObject *object) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); + if (priv->is_owner) { + delete priv->writer; + } + G_OBJECT_CLASS(gaflight_info_parent_class)->finalize(object); +} + +static void +gaflight_record_batch_writer_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_RECORD_BATCH_WRITER_WRITER: + priv->writer = + static_cast(g_value_get_pointer(value)); + break; + case PROP_RECORD_BATCH_WRITER_IS_OWNER: + priv->is_owner = g_value_get_boolean(value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gaflight_record_batch_writer_init(GAFlightRecordBatchWriter *object) +{ +} + +static void +gaflight_record_batch_writer_class_init(GAFlightRecordBatchWriterClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = gaflight_record_batch_writer_finalize; + gobject_class->set_property = gaflight_record_batch_writer_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer( + "writer", + 
nullptr, + nullptr, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_WRITER, spec); + + spec = g_param_spec_boolean( + "is-owner", + nullptr, + nullptr, + TRUE, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_IS_OWNER, spec); +} + +/** + * gaflight_record_batch_writer_begin: + * @writer: A #GAFlightRecordBatchWriter. + * @schema: A #GArrowSchema. + * @options: (nullable): A #GArrowWriteOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Begins writing data with the given schema. Only used with + * `DoExchange`. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_begin(GAFlightRecordBatchWriter *writer, + GArrowSchema *schema, + GArrowWriteOptions *options, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_schema = garrow_schema_get_raw(schema); + arrow::ipc::IpcWriteOptions arrow_write_options; + if (options) { + arrow_write_options = *garrow_write_options_get_raw(options); + } else { + arrow_write_options = arrow::ipc::IpcWriteOptions::Defaults(); + } + return garrow::check(error, + flight_writer->Begin(arrow_schema, arrow_write_options), + "[flight-record-batch-writer][begin]"); +} + +/** + * gaflight_record_batch_writer_write_metadata: + * @writer: A #GAFlightRecordBatchWriter. + * @metadata: A #GArrowBuffer. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Write metadata. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, + GArrowBuffer *metadata, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_metadata = garrow_buffer_get_raw(metadata); + return garrow::check(error, + flight_writer->WriteMetadata(arrow_metadata), + "[flight-record-batch-writer][write-metadata]"); +} + +/** + * gaflight_record_batch_writer_write: + * @writer: A #GAFlightRecordBatchWriter. + * @record_batch: A #GArrowRecordBatch. + * @metadata: (nullable): A #GArrowBuffer. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Write a record batch with metadata. + * + * Returns: %TRUE on success, %FALSE on error. 
+ * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); + auto arrow_metadata = garrow_buffer_get_raw(metadata); + return garrow::check( + error, + flight_writer->WriteWithMetadata(*arrow_record_batch, arrow_metadata), + "[flight-record-batch-writer][write]"); +} + G_END_DECLS GAFlightCriteria * @@ -1428,3 +1599,10 @@ gaflight_record_batch_reader_get_raw(GAFlightRecordBatchReader *reader) auto priv = GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(reader); return priv->reader; } + +arrow::flight::MetadataRecordBatchWriter * +gaflight_record_batch_writer_get_raw(GAFlightRecordBatchWriter *writer) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(writer); + return priv->writer; +} diff --git a/c_glib/arrow-flight-glib/common.h b/c_glib/arrow-flight-glib/common.h index b1d89f79c357e..91c828caabb36 100644 --- a/c_glib/arrow-flight-glib/common.h +++ b/c_glib/arrow-flight-glib/common.h @@ -232,4 +232,36 @@ GAFLIGHT_AVAILABLE_IN_6_0 GArrowTable * gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, GError **error); +#define GAFLIGHT_TYPE_RECORD_BATCH_WRITER (gaflight_record_batch_writer_get_type()) +GAFLIGHT_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE(GAFlightRecordBatchWriter, + gaflight_record_batch_writer, + GAFLIGHT, + RECORD_BATCH_WRITER, + GArrowRecordBatchWriter) +struct _GAFlightRecordBatchWriterClass +{ + GArrowRecordBatchWriterClass parent_class; +}; + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_begin(GAFlightRecordBatchWriter *writer, + GArrowSchema *schema, + GArrowWriteOptions *options, + GError **error); + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, + GArrowBuffer *metadata, + GError **error); + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error); + G_END_DECLS diff --git a/c_glib/arrow-flight-glib/common.hpp b/c_glib/arrow-flight-glib/common.hpp index db56fff579baf..ae5a7703397dd 100644 --- a/c_glib/arrow-flight-glib/common.hpp +++ b/c_glib/arrow-flight-glib/common.hpp @@ -79,3 +79,7 @@ gaflight_stream_chunk_get_raw(GAFlightStreamChunk *chunk); GAFLIGHT_EXTERN arrow::flight::MetadataRecordBatchReader * gaflight_record_batch_reader_get_raw(GAFlightRecordBatchReader *reader); + +GAFLIGHT_EXTERN +arrow::flight::MetadataRecordBatchWriter * +gaflight_record_batch_writer_get_raw(GAFlightRecordBatchWriter *writer); From 146b4e9669071984c883ec5791676638014bd655 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 24 Aug 2024 06:22:26 +0900 Subject: [PATCH 24/32] GH-43743: [CI][Docs] Ensure creating build directory (#43744) ### Rationale for this change It's used as a volume. If it doesn't exist, `docker compose` reports an error: Error response from daemon: invalid mount config for type "bind": bind source path does not exist: /home/runner/work/crossbow/crossbow/build/ ### What changes are included in this PR? * Create build directory * Move required `-v $PWD/build/:/build/` to `docs/github.linux.yml` ### Are these changes tested? Yes. ### Are there any user-facing changes? No. 
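For illustration only (the flags and image name simply mirror the workflow template changed below; nothing here is a new option), the failure mode and the fix look roughly like this when driving the docs build through Archery by hand:

```sh
# Without the bind-mount source, Docker refuses to start the container:
#   Error response from daemon: invalid mount config for type "bind":
#   bind source path does not exist: .../build/
mkdir -p build                  # create the volume source first, as the workflow now does
archery docker run \
  -v "$PWD/build/:/build/" \
  debian-docs
```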
* GitHub Issue: #43743 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- dev/tasks/docs/github.linux.yml | 4 +++- dev/tasks/tasks.yml | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/tasks/docs/github.linux.yml b/dev/tasks/docs/github.linux.yml index 8ab8a593c3ef3..5863d68d2c828 100644 --- a/dev/tasks/docs/github.linux.yml +++ b/dev/tasks/docs/github.linux.yml @@ -34,8 +34,10 @@ jobs: env: ARROW_JAVA_SKIP_GIT_PLUGIN: true run: | + mkdir -p build archery docker run \ -e SETUPTOOLS_SCM_PRETEND_VERSION="{{ arrow.no_rc_version }}" \ + -v $PWD/build/:/build/ \ {{ flags|default("") }} \ {{ image }} \ {{ command|default("") }} @@ -45,7 +47,7 @@ jobs: ref: {{ default_branch|default("main") }} path: crossbow fetch-depth: 1 - {% if publish %} + {% if publish %} - name: Prepare Docs Preview run: | # build files are created by the docker user diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 60114d6930878..cae34c3231381 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1487,7 +1487,7 @@ tasks: image: debian-go {% endfor %} - # be sure to update binary-task.rb when upgrading ubuntu + # be sure to update binary-task.rb when upgrading Debian test-debian-12-docs: ci: github template: docs/github.linux.yml @@ -1495,7 +1495,6 @@ tasks: env: JDK: 17 pr_number: Unset - flags: "-v $PWD/build/:/build/" image: debian-docs publish: false artifacts: @@ -1621,6 +1620,5 @@ tasks: env: JDK: 17 pr_number: Unset - flags: "-v $PWD/build/:/build/" image: debian-docs publish: true From e61c105c73dfabb51d5afc972ff21cc5326b3d93 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Sat, 24 Aug 2024 07:07:09 +0530 Subject: [PATCH 25/32] GH-41584: [Java] ListView Implementation for C Data Interface (#43686) ### Rationale for this change C Data Interface is missing `ListView` and `LargeListView` after recently merging core functionalities. Also closes; - [x] https://github.com/apache/arrow/issues/41585 ### What changes are included in this PR? This PR includes C Data interface related component additions to `ListView` and `LargeListView` along with the corresponding test cases. ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #41584 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- dev/archery/archery/integration/datagen.py | 1 - .../arrow/c/BufferImportTypeVisitor.java | 14 +- .../main/java/org/apache/arrow/c/Format.java | 8 ++ .../org/apache/arrow/c/RoundtripTest.java | 42 ++++++ java/c/src/test/python/integration_tests.py | 47 ++++++ .../BaseLargeRepeatedValueViewVector.java | 29 ++-- .../complex/BaseRepeatedValueViewVector.java | 30 ++-- .../vector/complex/LargeListViewVector.java | 10 +- .../arrow/vector/complex/ListViewVector.java | 6 +- .../arrow/vector/TestLargeListViewVector.java | 134 ++++++++++++++++++ .../arrow/vector/TestListViewVector.java | 132 +++++++++++++++++ .../testing/ValueVectorDataPopulator.java | 34 +++++ 12 files changed, 451 insertions(+), 36 deletions(-) diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 47310c905a9ff..d395d26cb71d3 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1936,7 +1936,6 @@ def _temp_path(): generate_list_view_case() .skip_tester('C#') # Doesn't support large list views - .skip_tester('Java') .skip_tester('JS') .skip_tester('nanoarrow') .skip_tester('Rust'), diff --git a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java index 633ecd43bd570..93fef6d7ca801 100644 --- a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java +++ b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java @@ -47,7 +47,9 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -400,13 +402,17 @@ public List visit(ArrowType.Duration type) { @Override public List visit(ArrowType.ListView type) { - throw new UnsupportedOperationException( - "Importing buffers for view type: " + type + " not supported"); + return Arrays.asList( + maybeImportBitmap(type), + importFixedBytes(type, 1, ListViewVector.OFFSET_WIDTH), + importFixedBytes(type, 2, ListViewVector.SIZE_WIDTH)); } @Override public List visit(ArrowType.LargeListView type) { - throw new UnsupportedOperationException( - "Importing buffers for view type: " + type + " not supported"); + return Arrays.asList( + maybeImportBitmap(type), + importFixedBytes(type, 1, LargeListViewVector.OFFSET_WIDTH), + importFixedBytes(type, 2, LargeListViewVector.SIZE_WIDTH)); } } diff --git a/java/c/src/main/java/org/apache/arrow/c/Format.java b/java/c/src/main/java/org/apache/arrow/c/Format.java index aff51e7b734ab..f77a555d18481 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Format.java +++ b/java/c/src/main/java/org/apache/arrow/c/Format.java @@ -229,6 +229,10 @@ static String asString(ArrowType arrowType) { return "vu"; case BinaryView: return "vz"; + case ListView: + return "+vl"; + case LargeListView: + return "+vL"; case NONE: throw new IllegalArgumentException("Arrow type ID is NONE"); default: @@ -313,6 +317,10 @@ static ArrowType asType(String format, long flags) return new ArrowType.Utf8View(); case "vz": return new ArrowType.BinaryView(); + case "+vl": + 
return new ArrowType.ListView(); + case "+vL": + return new ArrowType.LargeListView(); default: String[] parts = format.split(":", 2); if (parts.length == 2) { diff --git a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java index 6591d1f730990..18b2e94adde47 100644 --- a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java +++ b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java @@ -84,7 +84,9 @@ import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; @@ -683,6 +685,46 @@ public void testFixedSizeListVector() { } } + @Test + public void testListViewVector() { + try (final ListViewVector vector = ListViewVector.empty("v", allocator)) { + setVector( + vector, + Arrays.stream(new int[] {1, 2}).boxed().collect(Collectors.toList()), + Arrays.stream(new int[] {3, 4}).boxed().collect(Collectors.toList()), + new ArrayList()); + assertTrue(roundtrip(vector, ListViewVector.class)); + } + } + + @Test + public void testEmptyListViewVector() { + try (final ListViewVector vector = ListViewVector.empty("v", allocator)) { + setVector(vector, new ArrayList()); + assertTrue(roundtrip(vector, ListViewVector.class)); + } + } + + @Test + public void testLargeListViewVector() { + try (final LargeListViewVector vector = LargeListViewVector.empty("v", allocator)) { + setVector( + vector, + Arrays.stream(new int[] {1, 2}).boxed().collect(Collectors.toList()), + Arrays.stream(new int[] {3, 4}).boxed().collect(Collectors.toList()), + new ArrayList()); + assertTrue(roundtrip(vector, LargeListViewVector.class)); + } + } + + @Test + public void testEmptyLargeListViewVector() { + try (final LargeListViewVector vector = LargeListViewVector.empty("v", allocator)) { + setVector(vector, new ArrayList()); + assertTrue(roundtrip(vector, LargeListViewVector.class)); + } + } + @Test public void testMapVector() { int count = 5; diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index ab2ee1742f366..b0a86e9c66e59 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -352,6 +352,53 @@ def test_reader_complex_roundtrip(self): ] self.round_trip_reader(schema, data) + def test_listview_array(self): + self.round_trip_array(lambda: pa.array( + [[], [0], [1, 2], [4, 5, 6]], pa.list_view(pa.int64()) + # disabled check_metadata since in Java API the listview + # internal field name ("item") is not preserved + # during round trips (it becomes "$data$"). 
+ ), check_metadata=False) + + def test_empty_listview_array(self): + with pa.BufferOutputStream() as bos: + schema = pa.schema([pa.field("f0", pa.list_view(pa.int32()), True)]) + with ipc.new_stream(bos, schema) as writer: + src = pa.RecordBatch.from_arrays( + [pa.array([[]], pa.list_view(pa.int32()))], schema=schema) + writer.write(src) + data_bytes = bos.getvalue() + + def recreate_batch(): + with pa.input_stream(data_bytes) as ios: + with ipc.open_stream(ios) as reader: + return reader.read_next_batch() + + self.round_trip_record_batch(recreate_batch) + + def test_largelistview_array(self): + self.round_trip_array(lambda: pa.array( + [[], [0], [1, 2], [4, 5, 6]], pa.large_list_view(pa.int64()) + # disabled check_metadata since in Java API the listview + # internal field name ("item") is not preserved + # during round trips (it becomes "$data$"). + ), check_metadata=False) + + def test_empty_largelistview_array(self): + with pa.BufferOutputStream() as bos: + schema = pa.schema([pa.field("f0", pa.large_list_view(pa.int32()), True)]) + with ipc.new_stream(bos, schema) as writer: + src = pa.RecordBatch.from_arrays( + [pa.array([[]], pa.large_list_view(pa.int32()))], schema=schema) + writer.write(src) + data_bytes = bos.getvalue() + + def recreate_batch(): + with pa.input_stream(data_bytes) as ios: + with ipc.open_stream(ios) as reader: + return reader.read_next_batch() + + self.round_trip_record_batch(recreate_batch) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java index f643306cfdcff..12edd6557bd9c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java @@ -305,38 +305,43 @@ public void setValueCount(int valueCount) { while (valueCount > getOffsetBufferValueCapacity()) { reallocateBuffers(); } - final int childValueCount = valueCount == 0 ? 0 : getLengthOfChildVector(); + final int childValueCount = valueCount == 0 ? 0 : getMaxViewEndChildVector(); vector.setValueCount(childValueCount); } - protected int getLengthOfChildVector() { + /** + * Get the end of the child vector via the maximum view length. This method deduces the length by + * considering the condition i.e., argmax_i(offsets[i] + size[i]). + * + * @return the end of the child vector. + */ + protected int getMaxViewEndChildVector() { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < valueCount; i++) { int currentOffset = offsetBuffer.getInt((long) i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt((long) i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } - protected int getLengthOfChildVectorByIndex(int index) { + /** + * Get the end of the child vector via the maximum view length of the child vector by index. 
+ * + * @return the end of the child vector by index + */ + protected int getMaxViewEndChildVectorByIndex(int index) { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < index; i++) { int currentOffset = offsetBuffer.getInt((long) i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt((long) i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } /** @@ -390,7 +395,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt((long) index * OFFSET_WIDTH, prevOffset); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java index 031cc8037bb8b..e6213316b55a3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java @@ -304,38 +304,44 @@ public void setValueCount(int valueCount) { while (valueCount > getOffsetBufferValueCapacity()) { reallocateBuffers(); } - final int childValueCount = valueCount == 0 ? 0 : getLengthOfChildVector(); + final int childValueCount = valueCount == 0 ? 0 : getMaxViewEndChildVector(); vector.setValueCount(childValueCount); } - protected int getLengthOfChildVector() { + /** + * Get the end of the child vector via the maximum view length. This method deduces the length by + * considering the condition i.e., argmax_i(offsets[i] + size[i]). + * + * @return the end of the child vector. + */ + protected int getMaxViewEndChildVector() { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < valueCount; i++) { int currentOffset = offsetBuffer.getInt(i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt(i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } - protected int getLengthOfChildVectorByIndex(int index) { + /** + * Get the end of the child vector via the maximum view length of the child vector by index. 
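+   * The end is computed as max(offsets[i] + sizes[i]) over the first {@code index} elements,
+   * mirroring {@link #getMaxViewEndChildVector()} but restricted to a prefix of the views.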
+ * + * @return the end of the child vector by index + */ + protected int getMaxViewEndChildVectorByIndex(int index) { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); + // int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < index; i++) { int currentOffset = offsetBuffer.getInt(i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt(i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } /** @@ -389,7 +395,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 2c61f799a4cf9..84c6f03edb25d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -250,7 +250,9 @@ public List getFieldBuffers() { */ @Override public void exportCDataBuffers(List buffers, ArrowBuf buffersPtr, long nullValue) { - throw new UnsupportedOperationException("exportCDataBuffers Not implemented yet"); + exportBuffer(validityBuffer, buffers, buffersPtr, nullValue, true); + exportBuffer(offsetBuffer, buffers, buffersPtr, nullValue, true); + exportBuffer(sizeBuffer, buffers, buffersPtr, nullValue, true); } @Override @@ -851,7 +853,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } @@ -943,7 +945,7 @@ public void setValueCount(int valueCount) { } } /* valueCount for the data vector is the current end offset */ - final long childValueCount = (valueCount == 0) ? 0 : getLengthOfChildVector(); + final long childValueCount = (valueCount == 0) ? 0 : getMaxViewEndChildVector(); /* set the value count of data vector and this will take care of * checking whether data buffer needs to be reallocated. 
* TODO: revisit when 64-bit vectors are supported @@ -1001,7 +1003,7 @@ public double getDensity() { if (valueCount == 0) { return 0.0D; } - final double totalListSize = getLengthOfChildVector(); + final double totalListSize = getMaxViewEndChildVector(); return totalListSize / valueCount; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 7f6d92f3be9c8..9b4e6b4c0cd4a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -858,7 +858,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } @@ -942,7 +942,7 @@ public void setValueCount(int valueCount) { } } /* valueCount for the data vector is the current end offset */ - final int childValueCount = (valueCount == 0) ? 0 : getLengthOfChildVector(); + final int childValueCount = (valueCount == 0) ? 0 : getMaxViewEndChildVector(); /* set the value count of data vector and this will take care of * checking whether data buffer needs to be reallocated. */ @@ -1005,7 +1005,7 @@ public double getDensity() { if (valueCount == 0) { return 0.0D; } - final double totalListSize = getLengthOfChildVector(); + final double totalListSize = getMaxViewEndChildVector(); return totalListSize / valueCount; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java index 2ed8d4d7005ea..26e7bb4a0d3b2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java @@ -2095,6 +2095,140 @@ public void testOutOfOrderOffsetSplitAndTransfer() { } } + @Test + public void testRangeChildVector1() { + /* + * Non-overlapping ranges + * offsets: [0, 2] + * sizes: [4, 1] + * values: [0, 1, 2, 3] + * + * vector: [[0, 1, 2, 3], [2]] + * */ + try (LargeListViewVector largeListViewVector = + LargeListViewVector.empty("largelistview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + largeListViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + largeListViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = largeListViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
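+      // Element i of the list-view spans child indices [offset[i], offset[i] + size[i]).
+      // With offsets [0, 2] and sizes [4, 1] the views are [0, 1, 2, 3] and [2], so
+      // setValueCount(2) below trims the child vector to max(offset + size) = 4 values.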
+ largeListViewVector.setValidity(0, 1); + largeListViewVector.setValidity(1, 1); + + largeListViewVector.setOffset(0, 0); + largeListViewVector.setOffset(1, 2); + + largeListViewVector.setSize(0, 4); + largeListViewVector.setSize(1, 1); + + assertEquals(8, largeListViewVector.getDataVector().getValueCount()); + + largeListViewVector.setValueCount(2); + assertEquals(4, largeListViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) largeListViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + + @Test + public void testRangeChildVector2() { + /* + * Overlapping ranges + * offsets: [0, 2] + * sizes: [3, 1] + * values: [0, 1, 2, 3] + * + * vector: [[1, 2, 3], [2]] + * */ + try (LargeListViewVector largeListViewVector = + LargeListViewVector.empty("largelistview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + largeListViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + largeListViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = largeListViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ largeListViewVector.setValidity(0, 1); + largeListViewVector.setValidity(1, 1); + + largeListViewVector.setOffset(0, 1); + largeListViewVector.setOffset(1, 2); + + largeListViewVector.setSize(0, 3); + largeListViewVector.setSize(1, 1); + + assertEquals(8, largeListViewVector.getDataVector().getValueCount()); + + largeListViewVector.setValueCount(2); + assertEquals(4, largeListViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) largeListViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + private void writeIntValues(UnionLargeListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java index 4fa808c18aece..639585fc48d0a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java @@ -2084,6 +2084,138 @@ public void testOutOfOrderOffsetSplitAndTransfer() { } } + @Test + public void testRangeChildVector1() { + /* + * Non-overlapping ranges + * offsets: [0, 2] + * sizes: [4, 1] + * values: [0, 1, 2, 3] + * + * vector: [[0, 1, 2, 3], [2]] + * */ + try (ListViewVector listViewVector = ListViewVector.empty("listview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + listViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + listViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = listViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ listViewVector.setValidity(0, 1); + listViewVector.setValidity(1, 1); + + listViewVector.setOffset(0, 0); + listViewVector.setOffset(1, 2); + + listViewVector.setSize(0, 4); + listViewVector.setSize(1, 1); + + assertEquals(8, listViewVector.getDataVector().getValueCount()); + + listViewVector.setValueCount(2); + assertEquals(4, listViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) listViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + + @Test + public void testRangeChildVector2() { + /* + * Overlapping ranges + * offsets: [0, 2] + * sizes: [3, 1] + * values: [0, 1, 2, 3] + * + * vector: [[1, 2, 3], [2]] + * */ + try (ListViewVector listViewVector = ListViewVector.empty("listview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + listViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + listViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = listViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
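+      // Overlapping views: the offsets set below are [1, 2] with sizes [3, 1], selecting
+      // child ranges [1, 4) and [2, 3). setValueCount(2) therefore still leaves
+      // max(offset + size) = 4 child values even though the two views share an element.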
+ listViewVector.setValidity(0, 1); + listViewVector.setValidity(1, 1); + + listViewVector.setOffset(0, 1); + listViewVector.setOffset(1, 2); + + listViewVector.setSize(0, 3); + listViewVector.setSize(1, 1); + + assertEquals(8, listViewVector.getDataVector().getValueCount()); + + listViewVector.setValueCount(2); + assertEquals(4, listViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) listViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + private void writeIntValues(UnionListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java index 69e16dc470351..afbc30f019ef6 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java @@ -60,10 +60,12 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VariableWidthFieldVector; +import org.apache.arrow.vector.complex.BaseLargeRepeatedValueViewVector; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.StructVector; @@ -760,4 +762,36 @@ public static void setVector(ListViewVector vector, List... values) { dataVector.setValueCount(curPos); vector.setValueCount(values.length); } + + /** Populate values for {@link ListViewVector}. */ + public static void setVector(LargeListViewVector vector, List... values) { + vector.allocateNewSafe(); + Types.MinorType type = Types.MinorType.INT; + vector.addOrGetVector(FieldType.nullable(type.getType())); + + IntVector dataVector = (IntVector) vector.getDataVector(); + dataVector.allocateNew(); + + // set underlying vectors + int curPos = 0; + for (int i = 0; i < values.length; i++) { + vector + .getOffsetBuffer() + .setInt((long) i * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH, curPos); + if (values[i] == null) { + BitVectorHelper.unsetBit(vector.getValidityBuffer(), i); + } else { + BitVectorHelper.setBit(vector.getValidityBuffer(), i); + for (int value : values[i]) { + dataVector.setSafe(curPos, value); + curPos += 1; + } + } + vector + .getSizeBuffer() + .setInt((long) i * BaseRepeatedValueViewVector.SIZE_WIDTH, values[i].size()); + } + dataVector.setValueCount(curPos); + vector.setValueCount(values.length); + } } From 83d915a3d2ac2acecbb2cb2dc0dd7f5a213dd625 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:38:38 +0900 Subject: [PATCH 26/32] MINOR: [Java] Bump dep.slf4j.version from 2.0.13 to 2.0.16 in /java (#43652) Bumps `dep.slf4j.version` from 2.0.13 to 2.0.16. 
Updates `org.slf4j:slf4j-api` from 2.0.13 to 2.0.16
Updates `org.slf4j:slf4j-jdk14` from 2.0.13 to 2.0.16
Updates `org.slf4j:jul-to-slf4j` from 2.0.13 to 2.0.16
Updates `org.slf4j:jcl-over-slf4j` from 2.0.13 to 2.0.16
Updates `org.slf4j:log4j-over-slf4j` from 2.0.13 to 2.0.16

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@ dependabot rebase` will rebase this PR
- `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@ dependabot merge` will merge this PR after your CI passes on it
- `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@ dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@ dependabot reopen` will reopen this PR if it is closed
- `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index a73453df68fd2..54bb7a0ae0eb9 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -94,7 +94,7 @@ under the License. ${project.build.directory}/generated-sources 1.9.0 5.10.3 - 2.0.13 + 2.0.16 33.2.1-jre 4.1.112.Final 1.66.0 From cbb5f96306972aa236750602aba4b40ceb4219c4 Mon Sep 17 00:00:00 2001 From: Bryce Mecum Date: Sun, 25 Aug 2024 21:33:51 -0700 Subject: [PATCH 27/32] MINOR: [R] Add missing PR num to news.md item (#43811) ### Rationale for this change We normally link to somewhere to give the user more context on news items. I noticed the link was missing for this one. ### What changes are included in this PR? Added PR number to news item. ### Are these changes tested? No. ### Are there any user-facing changes? No. Authored-by: Bryce Mecum Signed-off-by: Jacob Wujciak-Jens --- r/NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/NEWS.md b/r/NEWS.md index 0e6e4634a0af8..b9568afe66542 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -32,7 +32,7 @@ functions (UDFs); for UDFs, see `register_scalar_function()`. (#41223) * `mutate()` expressions can now include aggregations, such as `x - mean(x)`. (#41350) * `summarize()` supports more complex expressions, and correctly handles cases - where column names are reused in expressions. + where column names are reused in expressions. (#41223) * The `na_matches` argument to the `dplyr::*_join()` functions is now supported. This argument controls whether `NA` values are considered equal when joining. (#41358) * R metadata, stored in the Arrow schema to support round-tripping data between From 51e9f70f94cd09a0a08196afdd2f4fc644666b5e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:20:20 +0900 Subject: [PATCH 28/32] MINOR: [Java] Bump dep.junit.jupiter.version from 5.10.3 to 5.11.0 in /java (#43751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps `dep.junit.jupiter.version` from 5.10.3 to 5.11.0. Updates `org.junit.jupiter:junit-jupiter-engine` from 5.10.3 to 5.11.0
Release notes

Sourced from org.junit.jupiter:junit-jupiter-engine's releases.

JUnit 5.11.0 = Platform 1.11.0 + Jupiter 5.11.0 + Vintage 5.11.0

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.10.3...r5.11.0

JUnit 5.11.0-RC1 = Platform 1.11.0-RC1 + Jupiter 5.11.0-RC1 + Vintage 5.11.0-RC1

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.11.0-M2...r5.11.0-RC1

JUnit 5.11.0-M2 = Platform 1.11.0-M2 + Jupiter 5.11.0-M2 + Vintage 5.11.0-M2

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.11.0-M1...r5.11.0-M2

JUnit 5.11.0-M1 = Platform 1.11.0-M1 + Jupiter 5.11.0-M1 + Vintage 5.11.0-M1

... (truncated)

Commits
  • 6b8e42b Release 5.11
  • 9430ece Allow potentially unlimited maxCharsPerColumn in Csv{File}Source (#3924)
  • 0b10f86 Polish release notes
  • 4dbd0f9 Let @ TempDir fail fast with File annotated element and non-default file s...
  • 57f1ad4 Fix syntax
  • d78730a Prioritize tasks on critical path of task graph
  • b6719e2 Remove obsolete directory
  • d8ec757 Apply Spotless formatting to Gradle script plugins
  • dae525d Disable caching of some Spotless tasks due to negative avoidance savings
  • c63d118 Re-enable caching verifyOSGi tasks (issue was fixed in bnd 7.0.0)
  • Additional commits viewable in compare view

Updates `org.junit.jupiter:junit-jupiter-api` from 5.10.3 to 5.11.0
Updates `org.junit.jupiter:junit-jupiter-params` from 5.10.3 to 5.11.0

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 54bb7a0ae0eb9..77feed12f3f1d 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -93,7 +93,7 @@ under the License. ${project.build.directory}/generated-sources 1.9.0 - 5.10.3 + 5.11.0 2.0.16 33.2.1-jre 4.1.112.Final From 2328b6ee39b497d9f48e6d342db9f7d0c34d9791 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 26 Aug 2024 16:34:18 +0200 Subject: [PATCH 29/32] GH-15058: [C++][Python] Native support for UUID (#37298) ### Rationale for this change See #15058. UUID datatype is common in throughout the ecosystem and Arrow as supporting it as a native type would reduce friction. ### What changes are included in this PR? This PR implements logic for Arrow canonical extension type in C++ and a Python wrapper. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes, new extension type is added. * Closes: #15058 Authored-by: Rok Mihevc Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 3 +- cpp/src/arrow/acero/hash_join_node_test.cc | 1 + cpp/src/arrow/extension/CMakeLists.txt | 2 +- .../extension/fixed_shape_tensor_test.cc | 17 +-- cpp/src/arrow/extension/uuid.cc | 58 ++++++++++ cpp/src/arrow/extension/uuid.h | 61 ++++++++++ cpp/src/arrow/extension/uuid_test.cc | 72 ++++++++++++ cpp/src/arrow/extension_type.cc | 4 +- cpp/src/arrow/extension_type_test.cc | 19 +--- .../integration/json_integration_test.cc | 2 +- cpp/src/arrow/ipc/test_common.cc | 35 ++++-- cpp/src/arrow/ipc/test_common.h | 3 + cpp/src/arrow/scalar_test.cc | 5 +- cpp/src/arrow/testing/extension_type.h | 6 +- cpp/src/arrow/testing/gtest_util.cc | 16 ++- dev/archery/archery/integration/datagen.py | 2 +- docs/source/format/CanonicalExtensions.rst | 2 + docs/source/status.rst | 2 +- python/pyarrow/__init__.py | 18 +-- python/pyarrow/array.pxi | 6 + python/pyarrow/includes/libarrow.pxd | 10 ++ python/pyarrow/lib.pxd | 3 + python/pyarrow/public-api.pxi | 11 +- python/pyarrow/scalar.pxi | 10 ++ python/pyarrow/src/arrow/python/gdb.cc | 27 +---- python/pyarrow/tests/extensions.pyx | 2 +- python/pyarrow/tests/test_extension_type.py | 105 ++++++++++++------ python/pyarrow/tests/test_gdb.py | 8 +- python/pyarrow/types.pxi | 34 ++++++ 29 files changed, 412 insertions(+), 132 deletions(-) create mode 100644 cpp/src/arrow/extension/uuid.cc create mode 100644 cpp/src/arrow/extension/uuid.h create mode 100644 cpp/src/arrow/extension/uuid_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 89f28ee416ede..6b0ac8c23c75a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -375,6 +375,7 @@ set(ARROW_SRCS device.cc extension_type.cc extension/bool8.cc + extension/uuid.cc pretty_print.cc record_batch.cc result.cc @@ -1225,6 +1226,7 @@ add_subdirectory(testing) add_subdirectory(array) add_subdirectory(c) add_subdirectory(compute) +add_subdirectory(extension) add_subdirectory(io) add_subdirectory(tensor) add_subdirectory(util) @@ -1267,7 +1269,6 @@ endif() if(ARROW_JSON) add_subdirectory(json) - add_subdirectory(extension) endif() if(ARROW_ORC) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 9065e286a2228..76ad9c7d650eb 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -29,6 +29,7 @@ #include "arrow/compute/kernels/test_util.h" 
#include "arrow/compute/light_array_internal.h" #include "arrow/compute/row/row_encoder_internal.h" +#include "arrow/extension/uuid.h" #include "arrow/testing/extension_type.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 5cb4bc77af2a4..065ea3f1ddb16 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -set(CANONICAL_EXTENSION_TESTS bool8_test.cc) +set(CANONICAL_EXTENSION_TESTS bool8_test.cc uuid_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc index 3fd39a11ff50d..842a78e1a4f7a 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc @@ -23,7 +23,7 @@ #include "arrow/array/array_primitive.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" +#include "arrow/ipc/test_common.h" #include "arrow/record_batch.h" #include "arrow/tensor.h" #include "arrow/testing/gtest_util.h" @@ -33,6 +33,7 @@ namespace arrow { using FixedShapeTensorType = extension::FixedShapeTensorType; +using arrow::ipc::test::RoundtripBatch; using extension::fixed_shape_tensor; using extension::FixedShapeTensorArray; @@ -71,20 +72,6 @@ class TestExtensionType : public ::testing::Test { std::string serialized_; }; -auto RoundtripBatch = [](const std::shared_ptr& batch, - std::shared_ptr* out) { - ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), - out_stream.get())); - - ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); - - io::BufferReader reader(complete_ipc_stream); - std::shared_ptr batch_reader; - ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); - ASSERT_OK(batch_reader->ReadNext(out)); -}; - TEST_F(TestExtensionType, CheckDummyRegistration) { // We need a registered dummy type at runtime to allow for IPC deserialization auto registered_type = GetExtensionType("arrow.fixed_shape_tensor"); diff --git a/cpp/src/arrow/extension/uuid.cc b/cpp/src/arrow/extension/uuid.cc new file mode 100644 index 0000000000000..43b917a17f8b2 --- /dev/null +++ b/cpp/src/arrow/extension/uuid.cc @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/extension_type.h" +#include "arrow/util/logging.h" + +#include "arrow/extension/uuid.h" + +namespace arrow::extension { + +bool UuidType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr UuidType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.uuid", + static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> UuidType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized) const { + if (!serialized.empty()) { + return Status::Invalid("Unexpected serialized metadata: '", serialized, "'"); + } + if (!storage_type->Equals(*fixed_size_binary(16))) { + return Status::Invalid("Invalid storage type for UuidType: ", + storage_type->ToString()); + } + return std::make_shared(); +} + +std::string UuidType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() << ">"; + return ss.str(); +} + +std::shared_ptr uuid() { return std::make_shared(); } + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/uuid.h b/cpp/src/arrow/extension/uuid.h new file mode 100644 index 0000000000000..42bb21cf0b2ed --- /dev/null +++ b/cpp/src/arrow/extension/uuid.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/extension_type.h" + +namespace arrow::extension { + +/// \brief UuidArray stores array of UUIDs. Underlying storage type is +/// FixedSizeBinary(16). +class ARROW_EXPORT UuidArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief UuidType is a canonical arrow extension type for UUIDs. +/// UUIDs are stored as FixedSizeBinary(16) with big-endian notation and this +/// does not interpret the bytes in any way. Specific UUID version is not +/// required or guaranteed. +class ARROW_EXPORT UuidType : public ExtensionType { + public: + /// \brief Construct a UuidType. + UuidType() : ExtensionType(fixed_size_binary(16)) {} + + std::string extension_name() const override { return "arrow.uuid"; } + std::string ToString(bool show_metadata = false) const override; + + bool ExtensionEquals(const ExtensionType& other) const override; + + /// Create a UuidArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override; + + std::string Serialize() const override { return ""; } + + /// \brief Create a UuidType instance + static Result> Make() { return std::make_shared(); } +}; + +/// \brief Return a UuidType instance. 
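+///
+/// The returned type stores values as fixed_size_binary(16) and uses the
+/// canonical extension name "arrow.uuid".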
+ARROW_EXPORT std::shared_ptr uuid(); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/uuid_test.cc b/cpp/src/arrow/extension/uuid_test.cc new file mode 100644 index 0000000000000..3bbb6eeb4aef1 --- /dev/null +++ b/cpp/src/arrow/extension/uuid_test.cc @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/uuid.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/test_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +#include "arrow/testing/extension_type.h" + +namespace arrow { + +using arrow::ipc::test::RoundtripBatch; + +TEST(TestUuuidExtensionType, ExtensionTypeTest) { + auto type = uuid(); + ASSERT_EQ(type->id(), Type::EXTENSION); + + const auto& ext_type = static_cast(*type); + std::string serialized = ext_type.Serialize(); + + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type.Deserialize(fixed_size_binary(16), serialized)); + ASSERT_TRUE(deserialized->Equals(*type)); + ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16))); +} + +TEST(TestUuuidExtensionType, RoundtripBatch) { + auto ext_type = extension::uuid(); + auto exact_ext_type = internal::checked_pointer_cast(ext_type); + auto arr = ArrayFromJSON(fixed_size_binary(16), R"(["abcdefghijklmnop", null])"); + auto ext_arr = ExtensionType::WrapArray(ext_type, arr); + + // Pass extension array, expect getting back extension array + std::shared_ptr read_batch; + auto ext_field = field(/*name=*/"f0", /*type=*/ext_type); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); + + // Pass extension metadata and storage array, expect getting back extension array + std::shared_ptr read_batch2; + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", ""}}); + ext_field = field(/*name=*/"f0", /*type=*/exact_ext_type->storage_type(), + /*nullable=*/true, /*metadata=*/ext_metadata); + auto batch2 = RecordBatch::Make(schema({ext_field}), arr->length(), {arr}); + RoundtripBatch(batch2, &read_batch2); + CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); +} + +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 83c7ebed4f319..fc220f73a6beb 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -32,6 +32,7 @@ #include "arrow/extension/fixed_shape_tensor.h" #include "arrow/extension/opaque.h" #endif +#include "arrow/extension/uuid.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" 
@@ -147,14 +148,13 @@ static void CreateGlobalRegistry() { // Register canonical extension types g_registry = std::make_shared(); - std::vector> ext_types{extension::bool8()}; + std::vector> ext_types{extension::bool8(), extension::uuid()}; #ifdef ARROW_JSON ext_types.push_back(extension::fixed_shape_tensor(int64(), {})); ext_types.push_back(extension::opaque(null(), "", "")); #endif - // Register canonical extension types for (const auto& ext_type : ext_types) { ARROW_CHECK_OK( g_registry->RegisterType(checked_pointer_cast(ext_type))); diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index f104c984a64b4..f49ffc5cba553 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -30,6 +30,7 @@ #include "arrow/io/memory.h" #include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/test_common.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -41,6 +42,8 @@ namespace arrow { +using arrow::ipc::test::RoundtripBatch; + class Parametric1Array : public ExtensionArray { public: using ExtensionArray::ExtensionArray; @@ -178,7 +181,7 @@ class ExtStructType : public ExtensionType { class TestExtensionType : public ::testing::Test { public: - void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared())); } + void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared())); } void TearDown() { if (GetExtensionType("uuid")) { @@ -211,20 +214,6 @@ TEST_F(TestExtensionType, ExtensionTypeTest) { ASSERT_EQ(deserialized->byte_width(), 16); } -auto RoundtripBatch = [](const std::shared_ptr& batch, - std::shared_ptr* out) { - ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), - out_stream.get())); - - ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); - - io::BufferReader reader(complete_ipc_stream); - std::shared_ptr batch_reader; - ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); - ASSERT_OK(batch_reader->ReadNext(out)); -}; - TEST_F(TestExtensionType, IpcRoundtrip) { auto ext_arr = ExampleUuid(); auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr}); diff --git a/cpp/src/arrow/integration/json_integration_test.cc b/cpp/src/arrow/integration/json_integration_test.cc index 9b56928c68843..0e84ea6124d5d 100644 --- a/cpp/src/arrow/integration/json_integration_test.cc +++ b/cpp/src/arrow/integration/json_integration_test.cc @@ -1046,7 +1046,7 @@ TEST(TestJsonFileReadWrite, JsonExample2) { auto storage_array = ArrayFromJSON(fixed_size_binary(16), R"(["0123456789abcdef", null])"); - AssertArraysEqual(*batch->column(0), UuidArray(uuid_type, storage_array)); + AssertArraysEqual(*batch->column(0), ExampleUuidArray(uuid_type, storage_array)); AssertArraysEqual(*batch->column(1), NullArray(2)); } diff --git a/cpp/src/arrow/ipc/test_common.cc b/cpp/src/arrow/ipc/test_common.cc index 87c02e2d87a1e..fb4f6bd8eadcf 100644 --- a/cpp/src/arrow/ipc/test_common.cc +++ b/cpp/src/arrow/ipc/test_common.cc @@ -27,8 +27,10 @@ #include "arrow/array.h" #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" -#include "arrow/array/builder_time.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/test_common.h" +#include "arrow/ipc/writer.h" #include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -242,11 +244,11 @@ 
Status MakeRandomBooleanArray(const int length, bool include_nulls, std::shared_ptr* out) { std::vector values(length); random_null_bytes(length, 0.5, values.data()); - ARROW_ASSIGN_OR_RAISE(auto data, internal::BytesToBits(values)); + ARROW_ASSIGN_OR_RAISE(auto data, arrow::internal::BytesToBits(values)); if (include_nulls) { std::vector valid_bytes(length); - ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(valid_bytes)); + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(valid_bytes)); random_null_bytes(length, 0.1, valid_bytes.data()); *out = std::make_shared(length, data, null_bitmap, -1); } else { @@ -596,7 +598,7 @@ Status MakeStruct(std::shared_ptr* out) { std::shared_ptr no_nulls(new StructArray(type, list_batch->num_rows(), columns)); std::vector null_bytes(list_batch->num_rows(), 1); null_bytes[0] = 0; - ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(null_bytes)); + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(null_bytes)); std::shared_ptr with_nulls( new StructArray(type, list_batch->num_rows(), columns, null_bitmap, 1)); @@ -1088,9 +1090,9 @@ Status MakeUuid(std::shared_ptr* out) { auto f1 = field("f1", uuid_type, /*nullable=*/false); auto schema = ::arrow::schema({f0, f1}); - auto a0 = std::make_shared( + auto a0 = std::make_shared( uuid_type, ArrayFromJSON(storage_type, R"(["0123456789abcdef", null])")); - auto a1 = std::make_shared( + auto a1 = std::make_shared( uuid_type, ArrayFromJSON(storage_type, R"(["ZYXWVUTSRQPONMLK", "JIHGFEDBA9876543"])")); @@ -1176,12 +1178,13 @@ enable_if_t::value, void> FillRandomData( Status MakeRandomTensor(const std::shared_ptr& type, const std::vector& shape, bool row_major_p, std::shared_ptr* out, uint32_t seed) { - const auto& element_type = internal::checked_cast(*type); + const auto& element_type = arrow::internal::checked_cast(*type); std::vector strides; if (row_major_p) { - RETURN_NOT_OK(internal::ComputeRowMajorStrides(element_type, shape, &strides)); + RETURN_NOT_OK(arrow::internal::ComputeRowMajorStrides(element_type, shape, &strides)); } else { - RETURN_NOT_OK(internal::ComputeColumnMajorStrides(element_type, shape, &strides)); + RETURN_NOT_OK( + arrow::internal::ComputeColumnMajorStrides(element_type, shape, &strides)); } const int64_t element_size = element_type.bit_width() / CHAR_BIT; @@ -1233,6 +1236,20 @@ Status MakeRandomTensor(const std::shared_ptr& type, return Tensor::Make(type, buf, shape, strides).Value(out); } +void RoundtripBatch(const std::shared_ptr& batch, + std::shared_ptr* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +} + } // namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/test_common.h b/cpp/src/arrow/ipc/test_common.h index db8613cbb1e6a..9b7e7f13e3a8e 100644 --- a/cpp/src/arrow/ipc/test_common.h +++ b/cpp/src/arrow/ipc/test_common.h @@ -184,6 +184,9 @@ Status MakeRandomTensor(const std::shared_ptr& type, const std::vector& shape, bool row_major_p, std::shared_ptr* out, uint32_t seed = 0); +ARROW_TESTING_EXPORT void RoundtripBatch(const std::shared_ptr& batch, + std::shared_ptr* out); + } 
// namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 104a5697b5727..e9ec13e98b4ee 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -43,7 +43,6 @@ namespace arrow { using compute::Cast; using compute::CastOptions; - using internal::checked_cast; using internal::checked_pointer_cast; @@ -2038,7 +2037,7 @@ class TestExtensionScalar : public ::testing::Test { void SetUp() { type_ = uuid(); storage_type_ = fixed_size_binary(16); - uuid_type_ = checked_cast(type_.get()); + uuid_type_ = checked_cast(type_.get()); } protected: @@ -2049,7 +2048,7 @@ class TestExtensionScalar : public ::testing::Test { } std::shared_ptr type_, storage_type_; - const UuidType* uuid_type_{nullptr}; + const ExampleUuidType* uuid_type_{nullptr}; const std::string_view uuid_string1_{UUID_STRING1}; const std::string_view uuid_string2_{UUID_STRING2}; diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index 6515631f202ae..a4526e31c2b93 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -27,14 +27,14 @@ namespace arrow { -class ARROW_TESTING_EXPORT UuidArray : public ExtensionArray { +class ARROW_TESTING_EXPORT ExampleUuidArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; }; -class ARROW_TESTING_EXPORT UuidType : public ExtensionType { +class ARROW_TESTING_EXPORT ExampleUuidType : public ExtensionType { public: - UuidType() : ExtensionType(fixed_size_binary(16)) {} + ExampleUuidType() : ExtensionType(fixed_size_binary(16)) {} std::string extension_name() const override { return "uuid"; } diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 95de16c715f19..ae2e53b30a3ee 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -49,9 +49,13 @@ #include "arrow/buffer.h" #include "arrow/compute/api_vector.h" #include "arrow/datum.h" +#include "arrow/io/memory.h" #include "arrow/ipc/json_simple.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep #include "arrow/pretty_print.h" +#include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/table.h" #include "arrow/tensor.h" @@ -847,17 +851,17 @@ Future<> SleepABitAsync() { /////////////////////////////////////////////////////////////////////////// // Extension types -bool UuidType::ExtensionEquals(const ExtensionType& other) const { +bool ExampleUuidType::ExtensionEquals(const ExtensionType& other) const { return (other.extension_name() == this->extension_name()); } -std::shared_ptr UuidType::MakeArray(std::shared_ptr data) const { +std::shared_ptr ExampleUuidType::MakeArray(std::shared_ptr data) const { DCHECK_EQ(data->type->id(), Type::EXTENSION); DCHECK_EQ("uuid", static_cast(*data->type).extension_name()); - return std::make_shared(data); + return std::make_shared(data); } -Result> UuidType::Deserialize( +Result> ExampleUuidType::Deserialize( std::shared_ptr storage_type, const std::string& serialized) const { if (serialized != "uuid-serialized") { return Status::Invalid("Type identifier did not match: '", serialized, "'"); @@ -866,7 +870,7 @@ Result> UuidType::Deserialize( return Status::Invalid("Invalid storage type for UuidType: ", storage_type->ToString()); } - return std::make_shared(); + return std::make_shared(); } bool SmallintType::ExtensionEquals(const ExtensionType& other) const { 
@@ -982,7 +986,7 @@ Result> Complex128Type::Deserialize( return std::make_shared(); } -std::shared_ptr uuid() { return std::make_shared(); } +std::shared_ptr uuid() { return std::make_shared(); } std::shared_ptr smallint() { return std::make_shared(); } diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index d395d26cb71d3..f63aa0d95a484 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1845,7 +1845,7 @@ def generate_nested_dictionary_case(): def generate_extension_case(): dict0 = Dictionary(0, StringField('dictionary0'), size=5, name='DICT0') - uuid_type = ExtensionType('uuid', 'uuid-serialized', + uuid_type = ExtensionType('arrow.uuid', '', FixedSizeBinaryField('', 16)) dict_ext_type = ExtensionType( 'dict-extension', 'dict-extension-serialized', diff --git a/docs/source/format/CanonicalExtensions.rst b/docs/source/format/CanonicalExtensions.rst index 5658f949ceeaa..1106f8aaffdd3 100644 --- a/docs/source/format/CanonicalExtensions.rst +++ b/docs/source/format/CanonicalExtensions.rst @@ -272,6 +272,8 @@ JSON In the future, additional fields may be added, but they are not required to interpret the array. +.. _uuid_extension: + UUID ==== diff --git a/docs/source/status.rst b/docs/source/status.rst index 5e2c2cc19c890..b685d4bbf8add 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -121,7 +121,7 @@ Data Types +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | JSON | | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| UUID | | | ✓ | | | | | | +| UUID | ✓ | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | 8-bit Boolean | ✓ | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 807bcdc315036..d31c93119b73a 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -172,9 +172,7 @@ def print_entry(label, value): union, sparse_union, dense_union, dictionary, run_end_encoded, - fixed_shape_tensor, - opaque, - bool8, + bool8, fixed_shape_tensor, opaque, uuid, field, type_for_alias, DataType, DictionaryType, StructType, @@ -184,8 +182,9 @@ def print_entry(label, value): TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, - RunEndEncodedType, FixedShapeTensorType, OpaqueType, - Bool8Type, PyExtensionType, UnknownExtensionType, + RunEndEncodedType, Bool8Type, FixedShapeTensorType, + OpaqueType, UuidType, + PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, KeyValueMetadata, @@ -218,8 +217,9 @@ def print_entry(label, value): Time32Array, Time64Array, DurationArray, MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, - RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, - Bool8Array, scalar, NA, _NULL as NULL, Scalar, + RunEndEncodedArray, Bool8Array, FixedShapeTensorArray, + OpaqueArray, UuidArray, + scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, @@ -235,8 +235,8 @@ def print_entry(label, value): StringScalar, LargeStringScalar, StringViewScalar, 
FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, - RunEndEncodedScalar, ExtensionScalar, - FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar) + RunEndEncodedScalar, Bool8Scalar, ExtensionScalar, + FixedShapeTensorScalar, OpaqueScalar, UuidScalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 77d6c9c06d2de..1587de0e6b744 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4338,6 +4338,12 @@ cdef class ExtensionArray(Array): return result +class UuidArray(ExtensionArray): + """ + Concrete class for Arrow arrays of UUID data type. + """ + + cdef class FixedShapeTensorArray(ExtensionArray): """ Concrete class for fixed shape tensor extension arrays. diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 6f510cfc0c06c..c2346750a196f 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2865,6 +2865,16 @@ cdef extern from "arrow/extension_type.h" namespace "arrow": shared_ptr[CArray] storage() +cdef extern from "arrow/extension/uuid.h" namespace "arrow::extension" nogil: + cdef cppclass CUuidType" arrow::extension::UuidType"(CExtensionType): + + @staticmethod + CResult[shared_ptr[CDataType]] Make() + + cdef cppclass CUuidArray" arrow::extension::UuidArray"(CExtensionArray): + pass + + cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension" nogil: cdef cppclass CFixedShapeTensorType \ " arrow::extension::FixedShapeTensorType"(CExtensionType): diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index a7c3b496a0045..5c3d981c3adc7 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -222,6 +222,9 @@ cdef class OpaqueType(BaseExtensionType): cdef: const COpaqueType* opaque_ext_type +cdef class UuidType(BaseExtensionType): + cdef: + const CUuidType* uuid_ext_type cdef class PyExtensionType(ExtensionType): pass diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 19a26bd6c683d..d3e2ff2e99d91 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -120,14 +120,17 @@ cdef api object pyarrow_wrap_data_type( elif type.get().id() == _Type_EXTENSION: ext_type = type.get() cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type) + extension_name = ext_type.extension_name() if cpy_ext_type != nullptr: return cpy_ext_type.GetInstance() - elif ext_type.extension_name() == b"arrow.fixed_shape_tensor": + elif extension_name == b"arrow.bool8": + out = Bool8Type.__new__(Bool8Type) + elif extension_name == b"arrow.fixed_shape_tensor": out = FixedShapeTensorType.__new__(FixedShapeTensorType) - elif ext_type.extension_name() == b"arrow.opaque": + elif extension_name == b"arrow.opaque": out = OpaqueType.__new__(OpaqueType) - elif ext_type.extension_name() == b"arrow.bool8": - out = Bool8Type.__new__(Bool8Type) + elif extension_name == b"arrow.uuid": + out = UuidType.__new__(UuidType) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 72ae2aee5f8b3..68f77832c4342 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -17,6 +17,7 @@ import collections from cython cimport binding +from uuid import UUID cdef class Scalar(_Weakrefable): @@ -1043,6 +1044,15 @@ cdef class ExtensionScalar(Scalar): return pyarrow_wrap_scalar( sp_scalar) +class UuidScalar(ExtensionScalar): 
+ """ + Concrete class for Uuid extension scalar. + """ + + def as_py(self): + return None if self.value is None else UUID(bytes=self.value.as_py()) + + cdef class FixedShapeTensorScalar(ExtensionScalar): """ Concrete class for fixed shape tensor extension scalar. diff --git a/python/pyarrow/src/arrow/python/gdb.cc b/python/pyarrow/src/arrow/python/gdb.cc index 6941769e4efe8..7c58bae3342c2 100644 --- a/python/pyarrow/src/arrow/python/gdb.cc +++ b/python/pyarrow/src/arrow/python/gdb.cc @@ -22,7 +22,7 @@ #include "arrow/array.h" #include "arrow/chunked_array.h" #include "arrow/datum.h" -#include "arrow/extension_type.h" +#include "arrow/extension/uuid.h" #include "arrow/ipc/json_simple.h" #include "arrow/python/gdb.h" #include "arrow/record_batch.h" @@ -37,6 +37,8 @@ namespace arrow { +using extension::uuid; +using extension::UuidType; using ipc::internal::json::ArrayFromJSON; using ipc::internal::json::ChunkedArrayFromJSON; using ipc::internal::json::ScalarFromJSON; @@ -56,29 +58,6 @@ class CustomStatusDetail : public StatusDetail { std::string ToString() const override { return "This is a detail"; } }; -class UuidType : public ExtensionType { - public: - UuidType() : ExtensionType(fixed_size_binary(16)) {} - - std::string extension_name() const override { return "uuid"; } - - bool ExtensionEquals(const ExtensionType& other) const override { - return (other.extension_name() == this->extension_name()); - } - - std::shared_ptr MakeArray(std::shared_ptr data) const override { - return std::make_shared(data); - } - - Result> Deserialize( - std::shared_ptr storage_type, - const std::string& serialized) const override { - return Status::NotImplemented(""); - } - - std::string Serialize() const override { return "uuid-serialized"; } -}; - std::shared_ptr SliceArrayFromJSON(const std::shared_ptr& ty, std::string_view json, int64_t offset = 0, int64_t length = -1) { diff --git a/python/pyarrow/tests/extensions.pyx b/python/pyarrow/tests/extensions.pyx index c1bf9aae1ec03..309b574dc0264 100644 --- a/python/pyarrow/tests/extensions.pyx +++ b/python/pyarrow/tests/extensions.pyx @@ -37,7 +37,7 @@ cdef extern from * namespace "arrow::py" nogil: class UuidType : public ExtensionType { public: UuidType() : ExtensionType(fixed_size_binary(16)) {} - std::string extension_name() const override { return "uuid"; } + std::string extension_name() const override { return "example-uuid"; } bool ExtensionEquals(const ExtensionType& other) const override { return other.extension_name() == this->extension_name(); diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 0d50c467e96bd..aacbd2cb6e756 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -95,18 +95,21 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): return cls() -class UuidScalarType(pa.ExtensionScalar): +class ExampleUuidScalarType(pa.ExtensionScalar): def as_py(self): return None if self.value is None else UUID(bytes=self.value.as_py()) -class UuidType(pa.ExtensionType): +class ExampleUuidType(pa.ExtensionType): def __init__(self): - super().__init__(pa.binary(16), 'pyarrow.tests.UuidType') + super().__init__(pa.binary(16), 'pyarrow.tests.ExampleUuidType') + + def __reduce__(self): + return ExampleUuidType, () def __arrow_ext_scalar_class__(self): - return UuidScalarType + return ExampleUuidScalarType def __arrow_ext_serialize__(self): return b'' @@ -116,10 +119,10 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): 
return cls() -class UuidType2(pa.ExtensionType): +class ExampleUuidType2(pa.ExtensionType): def __init__(self): - super().__init__(pa.binary(16), 'pyarrow.tests.UuidType2') + super().__init__(pa.binary(16), 'pyarrow.tests.ExampleUuidType2') def __arrow_ext_serialize__(self): return b'' @@ -250,8 +253,8 @@ def ipc_read_batch(buf): def test_ext_type_basics(): - ty = UuidType() - assert ty.extension_name == "pyarrow.tests.UuidType" + ty = ExampleUuidType() + assert ty.extension_name == "pyarrow.tests.ExampleUuidType" def test_ext_type_str(): @@ -267,16 +270,16 @@ def test_ext_type_repr(): def test_ext_type_lifetime(): - ty = UuidType() + ty = ExampleUuidType() wr = weakref.ref(ty) del ty assert wr() is None def test_ext_type_storage_type(): - ty = UuidType() + ty = ExampleUuidType() assert ty.storage_type == pa.binary(16) - assert ty.__class__ is UuidType + assert ty.__class__ is ExampleUuidType ty = ParamExtType(5) assert ty.storage_type == pa.binary(5) assert ty.__class__ is ParamExtType @@ -284,7 +287,7 @@ def test_ext_type_storage_type(): def test_ext_type_byte_width(): # Test for fixed-size binary types - ty = UuidType() + ty = pa.uuid() assert ty.byte_width == 16 ty = ParamExtType(5) assert ty.byte_width == 5 @@ -297,7 +300,7 @@ def test_ext_type_byte_width(): def test_ext_type_bit_width(): # Test for fixed-size binary types - ty = UuidType() + ty = pa.uuid() assert ty.bit_width == 128 ty = ParamExtType(5) assert ty.bit_width == 40 @@ -309,7 +312,7 @@ def test_ext_type_bit_width(): def test_ext_type_as_py(): - ty = UuidType() + ty = ExampleUuidType() expected = uuid4() scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes) assert scalar.as_py() == expected @@ -342,12 +345,22 @@ def test_ext_type_as_py(): def test_uuid_type_pickle(pickle_module): for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1): - ty = UuidType() + ty = ExampleUuidType() ser = pickle_module.dumps(ty, protocol=proto) del ty ty = pickle_module.loads(ser) wr = weakref.ref(ty) - assert ty.extension_name == "pyarrow.tests.UuidType" + assert ty.extension_name == "pyarrow.tests.ExampleUuidType" + del ty + assert wr() is None + + for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1): + ty = pa.uuid() + ser = pickle_module.dumps(ty, protocol=proto) + del ty + ty = pickle_module.loads(ser) + wr = weakref.ref(ty) + assert ty.extension_name == "arrow.uuid" del ty assert wr() is None @@ -358,8 +371,8 @@ def test_ext_type_equality(): c = ParamExtType(6) assert a != b assert b == c - d = UuidType() - e = UuidType() + d = ExampleUuidType() + e = ExampleUuidType() assert a != d assert d == e @@ -403,7 +416,7 @@ def test_ext_array_equality(): storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) storage3 = pa.array([], type=pa.binary(16)) - ty1 = UuidType() + ty1 = ExampleUuidType() ty2 = ParamExtType(16) a = pa.ExtensionArray.from_storage(ty1, storage1) @@ -451,9 +464,9 @@ def test_ext_scalar_from_array(): data = [b"0123456789abcdef", b"0123456789abcdef", b"zyxwvutsrqponmlk", None] storage = pa.array(data, type=pa.binary(16)) - ty1 = UuidType() + ty1 = ExampleUuidType() ty2 = ParamExtType(16) - ty3 = UuidType2() + ty3 = ExampleUuidType2() a = pa.ExtensionArray.from_storage(ty1, storage) b = pa.ExtensionArray.from_storage(ty2, storage) @@ -462,9 +475,9 @@ def test_ext_scalar_from_array(): scalars_a = list(a) assert len(scalars_a) == 4 - assert ty1.__arrow_ext_scalar_class__() == UuidScalarType - assert isinstance(a[0], UuidScalarType) - assert 
isinstance(scalars_a[0], UuidScalarType) + assert ty1.__arrow_ext_scalar_class__() == ExampleUuidScalarType + assert isinstance(a[0], ExampleUuidScalarType) + assert isinstance(scalars_a[0], ExampleUuidScalarType) for s, val in zip(scalars_a, data): assert isinstance(s, pa.ExtensionScalar) @@ -505,7 +518,7 @@ def test_ext_scalar_from_array(): def test_ext_scalar_from_storage(): - ty = UuidType() + ty = ExampleUuidType() s = pa.ExtensionScalar.from_storage(ty, None) assert isinstance(s, pa.ExtensionScalar) @@ -706,14 +719,14 @@ def test_cast_between_extension_types(): tiny_int_arr.cast(pa.int64()).cast(IntegerType()) # Between the same extension types is okay - array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(UuidType()) - out = array.cast(UuidType()) - assert out.type == UuidType() + array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(ExampleUuidType()) + out = array.cast(ExampleUuidType()) + assert out.type == ExampleUuidType() # Will still fail casting between extensions who share storage type, # can only cast between exactly the same extension types. with pytest.raises(TypeError, match='Casting from *'): - array.cast(UuidType2()) + array.cast(ExampleUuidType2()) def test_cast_to_extension_with_extension_storage(): @@ -744,10 +757,10 @@ def test_cast_nested_extension_types(data, type_factory): def test_casting_dict_array_to_extension_type(): storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) - arr = pa.ExtensionArray.from_storage(UuidType(), storage) + arr = pa.ExtensionArray.from_storage(ExampleUuidType(), storage) dict_arr = pa.DictionaryArray.from_arrays(pa.array([0, 0], pa.int32()), arr) - out = dict_arr.cast(UuidType()) + out = dict_arr.cast(ExampleUuidType()) assert isinstance(out, pa.ExtensionArray) assert out.to_pylist() == [UUID('30313233-3435-3637-3839-616263646566'), UUID('30313233-3435-3637-3839-616263646566')] @@ -1347,7 +1360,7 @@ def test_cpp_extension_in_python(tmpdir): mod = __import__('extensions') uuid_type = mod._make_uuid_type() - assert uuid_type.extension_name == "uuid" + assert uuid_type.extension_name == "example-uuid" assert uuid_type.storage_type == pa.binary(16) array = mod._make_uuid_array() @@ -1356,6 +1369,31 @@ def test_cpp_extension_in_python(tmpdir): assert array[0].as_py() == b'abcdefghijklmno0' assert array[1].as_py() == b'0onmlkjihgfedcba' + buf = ipc_write_batch(pa.RecordBatch.from_arrays([array], ["example-uuid"])) + + batch = ipc_read_batch(buf) + reconstructed_array = batch.column(0) + assert reconstructed_array.type == uuid_type + assert reconstructed_array == array + + +def test_uuid_extension(): + data = [b"0123456789abcdef", b"0123456789abcdef", + b"zyxwvutsrqponmlk", None] + + uuid_type = pa.uuid() + assert uuid_type.extension_name == "arrow.uuid" + assert uuid_type.storage_type == pa.binary(16) + assert uuid_type.__class__ is pa.UuidType + + storage = pa.array(data, pa.binary(16)) + array = pa.ExtensionArray.from_storage(uuid_type, storage) + assert array.type == uuid_type + + assert array.to_pylist() == [x if x is None else UUID(bytes=x) for x in data] + assert array[0].as_py() == UUID(bytes=data[0]) + assert array[3].as_py() is None + buf = ipc_write_batch(pa.RecordBatch.from_arrays([array], ["uuid"])) batch = ipc_read_batch(buf) @@ -1363,6 +1401,9 @@ def test_cpp_extension_in_python(tmpdir): assert reconstructed_array.type == uuid_type assert reconstructed_array == array + assert uuid_type.__arrow_ext_scalar_class__() == pa.UuidScalar + assert isinstance(array[0], pa.UuidScalar) + def test_tensor_type(): 
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3]) diff --git a/python/pyarrow/tests/test_gdb.py b/python/pyarrow/tests/test_gdb.py index 0d12d710dcf64..2ac2f55754fe5 100644 --- a/python/pyarrow/tests/test_gdb.py +++ b/python/pyarrow/tests/test_gdb.py @@ -409,7 +409,7 @@ def test_types_stack(gdb_arrow): check_stack_repr( gdb_arrow, "uuid_type", - ('arrow::ExtensionType "extension" ' + ('arrow::ExtensionType "extension" ' 'with storage type arrow::fixed_size_binary(16)')) @@ -447,7 +447,7 @@ def test_types_heap(gdb_arrow): check_heap_repr( gdb_arrow, "heap_uuid_type", - ('arrow::ExtensionType "extension" ' + ('arrow::ExtensionType "extension" ' 'with storage type arrow::fixed_size_binary(16)')) @@ -716,12 +716,12 @@ def test_scalars_stack(gdb_arrow): check_stack_repr( gdb_arrow, "extension_scalar", - ('arrow::ExtensionScalar of type "extension", ' + ('arrow::ExtensionScalar of type "extension", ' 'value arrow::FixedSizeBinaryScalar of size 16, ' 'value "0123456789abcdef"')) check_stack_repr( gdb_arrow, "extension_scalar_null", - 'arrow::ExtensionScalar of type "extension", null value') + 'arrow::ExtensionScalar of type "extension", null value') def test_scalars_heap(gdb_arrow): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 563782f0c2643..f83ecc3aa4326 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1765,6 +1765,25 @@ cdef class ExtensionType(BaseExtensionType): return ExtensionScalar +cdef class UuidType(BaseExtensionType): + """ + Concrete class for UUID extension type. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.uuid_ext_type = type.get() + + def __arrow_ext_class__(self): + return UuidArray + + def __reduce__(self): + return uuid, () + + def __arrow_ext_scalar_class__(self): + return UuidScalar + + cdef class FixedShapeTensorType(BaseExtensionType): """ Concrete class for fixed shape tensor extension type. @@ -5208,6 +5227,21 @@ def run_end_encoded(run_end_type, value_type): return pyarrow_wrap_data_type(ree_type) +def uuid(): + """ + Create UuidType instance. + + Returns + ------- + type : UuidType + """ + + cdef UuidType out = UuidType.__new__(UuidType) + c_uuid_ext_type = GetResultValue(CUuidType.Make()) + out.init(c_uuid_ext_type) + return out + + def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=None): """ Create instance of fixed shape tensor extension type with shape and optional From 8eb7bd4115da0027aad6362f0fe0901ec44b0616 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 09:12:57 +0900 Subject: [PATCH 30/32] MINOR: [Go] Bump github.com/hamba/avro/v2 from 2.24.1 to 2.25.0 in /go (#43829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github.com/hamba/avro/v2](https://github.com/hamba/avro) from 2.24.1 to 2.25.0.
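
Stepping back to the PyArrow changes earlier in this series: below is a minimal sketch of how the new canonical UUID extension type can be exercised, assuming a PyArrow build that includes those changes (the sample bytes and the field name "u" are placeholders, not part of the patch):

```
import uuid

import pyarrow as pa
import pyarrow.ipc

# Canonical UUID extension type: storage is fixed_size_binary(16),
# extension name is "arrow.uuid".
uuid_type = pa.uuid()
assert uuid_type.extension_name == "arrow.uuid"
assert uuid_type.storage_type == pa.binary(16)

# Wrap 16-byte storage values; scalars round-trip to uuid.UUID via as_py().
storage = pa.array([uuid.uuid4().bytes, None], type=pa.binary(16))
arr = pa.ExtensionArray.from_storage(uuid_type, storage)
assert isinstance(arr[0].as_py(), uuid.UUID)
assert arr[1].as_py() is None

# The extension type survives an IPC round trip, as in the test above.
batch = pa.record_batch([arr], names=["u"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
roundtripped = pa.ipc.open_stream(sink.getvalue()).read_all()
assert roundtripped.column("u").type == uuid_type
```

On the wire this is still a fixed_size_binary(16) column plus the usual extension metadata, so readers that do not recognize arrow.uuid can fall back to the storage values.
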
Release notes

Sourced from github.com/hamba/avro/v2's releases.

v2.25.0

Full Changelog: https://github.com/hamba/avro/compare/v2.24.1...v2.24.2

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/hamba/avro/v2&package-manager=go_modules&previous-version=2.24.1&new-version=2.25.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`.
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 9f4222a541bb6..97ac05685970c 100644 --- a/go/go.mod +++ b/go/go.mod @@ -47,7 +47,7 @@ require ( require ( github.com/google/uuid v1.6.0 - github.com/hamba/avro/v2 v2.24.1 + github.com/hamba/avro/v2 v2.25.0 github.com/huandu/xstrings v1.4.0 github.com/substrait-io/substrait-go v0.6.0 github.com/tidwall/sjson v1.2.5 diff --git a/go/go.sum b/go/go.sum index c7eb3a66deeec..bd761e1589453 100644 --- a/go/go.sum +++ b/go/go.sum @@ -43,8 +43,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hamba/avro/v2 v2.24.1 h1:Xi+7AnhaAc41aA/jmmYpxMsdEDOf1rdup6NJ85P7q2I= -github.com/hamba/avro/v2 v2.24.1/go.mod h1:7vDfy/2+kYCE8WUHoj2et59GTv0ap7ptktMXu0QHePI= +github.com/hamba/avro/v2 v2.25.0 h1:9qig/K4VP5tMq6DuKGfI6YdXncTkPJT1IJDMSv82EeI= +github.com/hamba/avro/v2 v2.25.0/go.mod h1:I8glyswHnpED3Nlx2ZdUe+4LJnCOOyiCzLMno9i/Uu0= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= From 93c5ddb957bb93421a8f84dbd7c5a5b7be2d6d45 Mon Sep 17 00:00:00 2001 From: PANKAJ9768 <48675737+PANKAJ9768@users.noreply.github.com> Date: Tue, 27 Aug 2024 05:59:09 +0530 Subject: [PATCH 31/32] GH-43667: [Java] Keeping Flight default header size consistent between server and client (#43697) ### Rationale for this change ### What changes are included in this PR? Flight client can send header size larger than server can accept. This PR is to keep default values consistent across server and client. ### Are these changes tested? ### Are there any user-facing changes? 
* GitHub Issue: #43667 Authored-by: pankaj kesari Signed-off-by: David Li --- .../org/apache/arrow/flight/FlightServer.java | 7 ++ .../arrow/flight/TestFlightService.java | 73 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index 05dbe42c49172..ac761457f57fd 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -188,6 +188,7 @@ public static final class Builder { private CallHeaderAuthenticator headerAuthenticator = CallHeaderAuthenticator.NO_OP; private ExecutorService executor = null; private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE; + private int maxHeaderListSize = MAX_GRPC_MESSAGE_SIZE; private int backpressureThreshold = DEFAULT_BACKPRESSURE_THRESHOLD; private InputStream certChain; private InputStream key; @@ -324,6 +325,7 @@ public FlightServer build() { builder .executor(exec) .maxInboundMessageSize(maxInboundMessageSize) + .maxInboundMetadataSize(maxHeaderListSize) .addService( ServerInterceptors.intercept( flightService, @@ -366,6 +368,11 @@ public FlightServer build() { return new FlightServer(location, builder.build(), grpcExecutor); } + public Builder setMaxHeaderListSize(int maxHeaderListSize) { + this.maxHeaderListSize = maxHeaderListSize; + return this; + } + /** * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying * transport. diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index 5ebeb44c1d36e..fc3f83e4eafd3 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Optional; +import java.util.Random; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -152,4 +153,76 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor assertEquals("No schema is present in FlightInfo", e.getMessage()); } } + + /** + * Test for GH-41584 where flight defaults for header size was not in sync b\w client and server. 
+ */ + @Test + public void testHeaderSizeExchangeInService() throws Exception { + final FlightProducer producer = + new NoOpFlightProducer() { + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + String longHeader = + context.getMiddleware(FlightConstants.HEADER_KEY).headers().get("long-header"); + return new FlightInfo( + null, + descriptor, + Collections.emptyList(), + 0, + 0, + false, + IpcOption.DEFAULT, + longHeader.getBytes(StandardCharsets.UTF_8)); + } + }; + + String headerVal = generateRandom(1024 * 10); + FlightCallHeaders callHeaders = new FlightCallHeaders(); + callHeaders.insert("long-header", headerVal); + // sever with default header limit same as client + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer) + .build() + .start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders)); + assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + assertArrayEquals(flightInfo.getAppMetadata(), headerVal.getBytes(StandardCharsets.UTF_8)); + } + // server with 15kb header limit + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer) + .setMaxHeaderListSize(1024 * 15) + .build() + .start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders)); + assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + assertArrayEquals(flightInfo.getAppMetadata(), headerVal.getBytes(StandardCharsets.UTF_8)); + + callHeaders.insert("another-header", headerVal + headerVal); + FlightRuntimeException e = + assertThrows( + FlightRuntimeException.class, + () -> + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders))); + assertEquals("http2 exception", e.getMessage()); + } + } + + private static String generateRandom(int size) { + String aToZ = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + Random random = new Random(); + StringBuilder res = new StringBuilder(); + for (int i = 0; i < size; i++) { + int randIndex = random.nextInt(aToZ.length()); + res.append(aToZ.charAt(randIndex)); + } + return res.toString(); + } } From 11f92491b1d2ecf700e6e023a1e413ec4c4345ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:06:13 +0900 Subject: [PATCH 32/32] MINOR: [Go] Bump github.com/substrait-io/substrait-go from 0.6.0 to 0.7.0 in /go (#43830) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github.com/substrait-io/substrait-go](https://github.com/substrait-io/substrait-go) from 0.6.0 to 0.7.0.
Release notes

Sourced from github.com/substrait-io/substrait-go's releases.

v0.7.0 (2024-08-25)

Features

  • Add convenience literal APIs (#47) (597afdb)
    • Introduce literal package

Changes to the build process or auxiliary tools and libraries such as documentation generation

  • extensions: Minor refactoring in extension_mgr.go (#45) (cbd28cb)
    • Minor refactoring in extension_mgr.go
  • Move typeName maps to types package (#46) (5556c23)

Commits
  • 597afdb feat: Add convenience literal APIs (#47)
  • e77df67 feat(types) Make time precision value explicit (#49)
  • a3e8ee0 feat(substrait) Update to substrait v0.55.0 (#48)
  • 2229c12 ci(build-test): golangci should use the go.mod version of golang (#51)
  • cbd28cb chore(extensions): Minor refactoring in extension_mgr.go (#45)
  • 5556c23 chore: Move typeName maps to types package (#46)
  • dd790cb Add a function registry for a given BFT dialect (#32)
  • 828636c ci(build-test): Add golangci-lint to do import checking and other linting (#42)
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/substrait-io/substrait-go&package-manager=go_modules&previous-version=0.6.0&new-version=0.7.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`.
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 97ac05685970c..a995eee24d563 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.25.0 github.com/huandu/xstrings v1.4.0 - github.com/substrait-io/substrait-go v0.6.0 + github.com/substrait-io/substrait-go v0.7.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go/go.sum b/go/go.sum index bd761e1589453..6f22e11aef03a 100644 --- a/go/go.sum +++ b/go/go.sum @@ -99,8 +99,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/substrait-io/substrait-go v0.6.0 h1:n2G/SGmrn7U5Q39VA8WeM2UfVL5Y/6HX8WAP9uJLNk4= -github.com/substrait-io/substrait-go v0.6.0/go.mod h1:cl8Wsc7aBPDfcHp9+OrUqGpjkgrYlhcDsH/lMP6KUZA= +github.com/substrait-io/substrait-go v0.7.0 h1:53yi73t4wW383+RD1YuhXhbjhP1KzF9GCxPC7SsRlqc= +github.com/substrait-io/substrait-go v0.7.0/go.mod h1:7mjSvIaxk94bOF+YZn/vBOpHK4DWTpBv7nC/btjXCmc= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=