Skip to content

Commit 6f71f77

Browse files
committed
Update Python bindings, refine error types
1 parent d2f70b5 commit 6f71f77

File tree

16 files changed

+126
-145
lines changed

16 files changed

+126
-145
lines changed

cpp/src/arrow/compute/api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
#include "arrow/compute/function.h" // IWYU pragma: export
2828
#include "arrow/compute/kernel.h" // IWYU pragma: export
2929
#include "arrow/compute/registry.h" // IWYU pragma: export
30+
#include "arrow/datum.h" // IWYU pragma: export

cpp/src/arrow/compute/cast.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ Result<const ScalarKernel*> CastFunction::DispatchExact(
129129
}
130130

131131
if (candidate_kernels.size() == 0) {
132-
return Status::Invalid("Function ", this->name(),
133-
" has no kernel matching input type ", values[0].ToString());
132+
return Status::NotImplemented("Function ", this->name(),
133+
" has no kernel matching input type ",
134+
values[0].ToString());
134135
} else if (candidate_kernels.size() == 1) {
135136
// One match, return it
136137
return candidate_kernels[0];
@@ -173,7 +174,8 @@ Result<std::shared_ptr<const CastFunction>> GetCastFunction(
173174
internal::EnsureInitCastTable();
174175
auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
175176
if (it == internal::g_cast_table.end()) {
176-
return Status::Invalid("No cast function available to cast to ", to_type->ToString());
177+
return Status::NotImplemented("No cast function available to cast to ",
178+
to_type->ToString());
177179
}
178180
return it->second;
179181
}

cpp/src/arrow/compute/kernels/scalar_cast_internal.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
6363
}
6464
}
6565

66-
void AddZeroCopyCast(InputType in_type, const std::shared_ptr<DataType>& out_type,
67-
CastFunction* func) {
66+
void AddZeroCopyCast(InputType in_type, OutputType out_type, CastFunction* func) {
6867
auto sig = KernelSignature::Make({in_type}, out_type);
6968
ScalarKernel kernel;
7069
kernel.exec = ZeroCopyCastExec;

cpp/src/arrow/compute/kernels/scalar_cast_internal.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,12 @@ struct FromDictionaryCast {
172172
const Array& dictionary = *input.dictionary;
173173
const DataType& values_type = *dictionary.type();
174174

175-
// Check if values and output type match
176-
DCHECK(values_type.Equals(*output->type))
177-
<< "Dictionary type: " << values_type << " target type: " << (*output->type);
175+
// ARROW-7077
176+
if (!values_type.Equals(*output->type)) {
177+
ctx->SetStatus(Status::Invalid("Cannot unpack dictionary of type ", type.ToString(),
178+
" to type ", output->type->ToString()));
179+
return;
180+
}
178181

179182
FromDictUnpackHelper<T> unpack_helper;
180183
switch (type.index_type()->id()) {
@@ -237,8 +240,7 @@ void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) {
237240

238241
void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out);
239242

240-
void AddZeroCopyCast(InputType in_type, const std::shared_ptr<DataType>& out_type,
241-
CastFunction* func);
243+
void AddZeroCopyCast(InputType in_type, OutputType out_type, CastFunction* func);
242244

243245
// OutputType::Resolver that returns a descr with the shape of the input
244246
// argument and the type from CastOptions
@@ -252,7 +254,8 @@ struct MaybeAddFromDictionary {
252254

253255
template <typename T>
254256
struct MaybeAddFromDictionary<
255-
T, enable_if_t<!is_boolean_type<T>::value && !is_nested_type<T>::value>> {
257+
T, enable_if_t<!is_boolean_type<T>::value && !is_nested_type<T>::value &&
258+
!std::is_same<DictionaryType, T>::value>> {
256259
static void Add(const OutputType& out_ty, CastFunction* func) {
257260
// Dictionary unpacking not implemented for boolean, nested, or dictionary types
258261
DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)},

cpp/src/arrow/compute/kernels/scalar_cast_nested.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,16 @@ std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() {
7676
AddCommonCasts<LargeListType>(kOutputTargetType, cast_large_list.get());
7777
AddListCast<LargeListType>(cast_large_list.get());
7878

79-
return {cast_list, cast_large_list};
79+
// Fixed-size list (FSL) casting is incomplete at the moment: only the common casts are registered
80+
auto cast_fsl =
81+
std::make_shared<CastFunction>("cast_fixed_size_list", Type::FIXED_SIZE_LIST);
82+
AddCommonCasts<FixedSizeListType>(kOutputTargetType, cast_fsl.get());
83+
84+
// So is struct
85+
auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT);
86+
AddCommonCasts<StructType>(kOutputTargetType, cast_struct.get());
87+
88+
return {cast_list, cast_large_list, cast_fsl, cast_struct};
8089
}
8190

8291
} // namespace internal

cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,12 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
546546
functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32"));
547547
functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));
548548

549+
// HalfFloat casting is incomplete for now: only the common casts are registered
550+
auto cast_half_float =
551+
std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
552+
AddCommonCasts<HalfFloatType>(float16(), cast_half_float.get());
553+
functions.push_back(cast_half_float);
554+
549555
functions.push_back(GetCastToFloating<FloatType>("cast_float"));
550556
functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
551557

cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ struct CastFunctor<
134134
const auto& in_type = checked_cast<const I&>(*batch[0].type());
135135
const auto& out_type = checked_cast<const O&>(*output->type);
136136

137-
DCHECK_NE(in_type.unit(), out_type.unit()) << "Do not cast equal types";
137+
// The units may be equal if the time zones are different. We might go to
138+
// lengths to make this zero copy in the future but we leave it for now
138139

139140
auto conversion = util::kTimestampConversionTable[static_cast<int>(in_type.unit())]
140141
[static_cast<int>(out_type.unit())];
@@ -344,10 +345,7 @@ std::shared_ptr<CastFunction> GetDurationCast() {
344345
auto nanos = duration(TimeUnit::NANO);
345346

346347
// Same integer representation
347-
AddZeroCopyCast(/*in_type=*/int64(), /*out_type=*/seconds, func.get());
348-
AddZeroCopyCast(int64(), millis, func.get());
349-
AddZeroCopyCast(int64(), micros, func.get());
350-
AddZeroCopyCast(int64(), nanos, func.get());
348+
AddZeroCopyCast(/*in_type=*/int64(), kOutputTargetType, func.get());
351349

352350
// Between durations
353351
AddCrossUnitCast<DurationType>(func.get());
@@ -359,12 +357,8 @@ std::shared_ptr<CastFunction> GetTime32Cast() {
359357
auto func = std::make_shared<CastFunction>("cast_time32", Type::TIME32);
360358
AddCommonCasts<Date32Type>(kOutputTargetType, func.get());
361359

362-
auto seconds = time32(TimeUnit::SECOND);
363-
auto millis = time32(TimeUnit::MILLI);
364-
365360
// Zero copy when the unit is the same or same integer representation
366-
AddZeroCopyCast(int32(), seconds, func.get());
367-
AddZeroCopyCast(int32(), millis, func.get());
361+
AddZeroCopyCast(/*in_type=*/int32(), kOutputTargetType, func.get());
368362

369363
// time64 -> time32
370364
AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64), kOutputTargetType,
@@ -380,12 +374,8 @@ std::shared_ptr<CastFunction> GetTime64Cast() {
380374
auto func = std::make_shared<CastFunction>("cast_time64", Type::TIME64);
381375
AddCommonCasts<Time64Type>(kOutputTargetType, func.get());
382376

383-
auto micros = time64(TimeUnit::MICRO);
384-
auto nanos = time64(TimeUnit::NANO);
385-
386377
// Zero copy when the unit is the same or same integer representation
387-
AddZeroCopyCast(int64(), micros, func.get());
388-
AddZeroCopyCast(int64(), nanos, func.get());
378+
AddZeroCopyCast(/*in_type=*/int64(), kOutputTargetType, func.get());
389379

390380
// time32 -> time64
391381
AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32), kOutputTargetType,
@@ -402,11 +392,7 @@ std::shared_ptr<CastFunction> GetTimestampCast() {
402392
AddCommonCasts<TimestampType>(kOutputTargetType, func.get());
403393

404394
// Same integer representation
405-
AddZeroCopyCast(/*in_type=*/int64(), /*out_type=*/timestamp(TimeUnit::SECOND),
406-
func.get());
407-
AddZeroCopyCast(int64(), timestamp(TimeUnit::MILLI), func.get());
408-
AddZeroCopyCast(int64(), timestamp(TimeUnit::MICRO), func.get());
409-
AddZeroCopyCast(int64(), timestamp(TimeUnit::NANO), func.get());
395+
AddZeroCopyCast(/*in_type=*/int64(), kOutputTargetType, func.get());
410396

411397
// From date types
412398
// TODO: ARROW-8876, these casts are not implemented

cpp/src/arrow/compute/kernels/scalar_cast_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ TEST_F(TestCast, UnsupportedTarget) {
12161216
std::shared_ptr<Array> arr;
12171217
ArrayFromVector<Int32Type, int32_t>(int32(), is_valid, v1, &arr);
12181218

1219-
ASSERT_RAISES(Invalid, Cast(*arr, list(utf8())));
1219+
ASSERT_RAISES(NotImplemented, Cast(*arr, list(utf8())));
12201220
}
12211221

12221222
TEST_F(TestCast, DateTimeZeroCopy) {
@@ -1322,8 +1322,8 @@ TEST_F(TestCast, ListToPrimitive) {
13221322
auto from_int = ArrayFromJSON(list(int8()), "[[1, 2], [3, 4]]");
13231323
auto from_binary = ArrayFromJSON(list(binary()), "[[\"1\", \"2\"], [\"3\", \"4\"]]");
13241324

1325-
ASSERT_RAISES(Invalid, Cast(*from_int, uint8()));
1326-
ASSERT_RAISES(Invalid, Cast(*from_binary, utf8()));
1325+
ASSERT_RAISES(NotImplemented, Cast(*from_int, uint8()));
1326+
ASSERT_RAISES(NotImplemented, Cast(*from_binary, utf8()));
13271327
}
13281328

13291329
TEST_F(TestCast, ListToList) {
@@ -1530,7 +1530,7 @@ TYPED_TEST(TestDictionaryCast, DISABLED_OutTypeError) {
15301530
auto out_type = (plain_array->type()->id() == Type::INT8) ? binary() : int8();
15311531
// Test an output type that's not part of TestTypes.
15321532
out_type = list(in_type);
1533-
ASSERT_RAISES(Invalid, GetCastFunction(out_type));
1533+
ASSERT_RAISES(NotImplemented, GetCastFunction(out_type));
15341534
}
15351535

15361536
std::shared_ptr<Array> SmallintArrayFromJSON(const std::string& json_data) {

python/pyarrow/_compute.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@
2020
from pyarrow.lib cimport (
2121
Array,
2222
wrap_datum,
23-
_context,
2423
check_status,
2524
ChunkedArray
2625
)
2726
from pyarrow.includes.libarrow cimport CDatum, Sum
27+
from pyarrow.includes.common cimport *
2828

2929

3030
cdef _sum_array(array: Array):
3131
cdef CDatum out
3232

3333
with nogil:
34-
check_status(Sum(_context(), CDatum(array.sp_array), &out))
34+
out = GetResultValue(Sum(CDatum(array.sp_array)))
3535

3636
return wrap_datum(out)
3737

@@ -40,7 +40,7 @@ cdef _sum_chunked_array(array: ChunkedArray):
4040
cdef CDatum out
4141

4242
with nogil:
43-
check_status(Sum(_context(), CDatum(array.sp_chunked_array), &out))
43+
out = GetResultValue(Sum(CDatum(array.sp_chunked_array)))
4444

4545
return wrap_datum(out)
4646

python/pyarrow/array.pxi

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -376,19 +376,6 @@ cdef Py_ssize_t _normalize_index(Py_ssize_t index,
376376
return index
377377

378378

379-
cdef class _FunctionContext:
380-
cdef:
381-
unique_ptr[CFunctionContext] ctx
382-
383-
def __cinit__(self):
384-
self.ctx.reset(new CFunctionContext(c_default_memory_pool()))
385-
386-
cdef _FunctionContext _global_ctx = _FunctionContext()
387-
388-
cdef CFunctionContext* _context() nogil:
389-
return _global_ctx.ctx.get()
390-
391-
392379
cdef wrap_datum(const CDatum& datum):
393380
if datum.kind() == DatumType_ARRAY:
394381
return pyarrow_wrap_array(MakeArray(datum.array()))
@@ -705,9 +692,7 @@ cdef class Array(_PandasConvertible):
705692
shared_ptr[CArray] result
706693

707694
with nogil:
708-
check_status(Cast(_context(), self.ap[0], type.sp_type,
709-
options, &result))
710-
695+
result = GetResultValue(Cast(self.ap[0], type.sp_type, options))
711696
return pyarrow_wrap_array(result)
712697

713698
def view(self, object target_type):
@@ -736,10 +721,8 @@ cdef class Array(_PandasConvertible):
736721
Sum the values in a numerical array.
737722
"""
738723
cdef CDatum out
739-
740724
with nogil:
741-
check_status(Sum(_context(), CDatum(self.sp_array), &out))
742-
725+
out = GetResultValue(Sum(CDatum(self.sp_array)))
743726
return wrap_datum(out)
744727

745728
def unique(self):
@@ -749,7 +732,7 @@ cdef class Array(_PandasConvertible):
749732
cdef shared_ptr[CArray] result
750733

751734
with nogil:
752-
check_status(Unique(_context(), CDatum(self.sp_array), &result))
735+
result = GetResultValue(Unique(CDatum(self.sp_array)))
753736

754737
return pyarrow_wrap_array(result)
755738

@@ -760,9 +743,7 @@ cdef class Array(_PandasConvertible):
760743
cdef CDatum out
761744

762745
with nogil:
763-
check_status(DictionaryEncode(_context(), CDatum(self.sp_array),
764-
&out))
765-
746+
out = GetResultValue(DictionaryEncode(CDatum(self.sp_array)))
766747
return wrap_datum(out)
767748

768749
def value_counts(self):
@@ -776,8 +757,7 @@ cdef class Array(_PandasConvertible):
776757
cdef shared_ptr[CArray] result
777758

778759
with nogil:
779-
check_status(ValueCounts(_context(), CDatum(self.sp_array),
780-
&result))
760+
result = GetResultValue(ValueCounts(CDatum(self.sp_array)))
781761
return pyarrow_wrap_array(result)
782762

783763
@staticmethod
@@ -1040,8 +1020,8 @@ cdef class Array(_PandasConvertible):
10401020
c_indices = asarray(indices)
10411021

10421022
with nogil:
1043-
check_status(Take(_context(), CDatum(self.sp_array),
1044-
CDatum(c_indices.sp_array), options, &out))
1023+
out = GetResultValue(Take(CDatum(self.sp_array),
1024+
CDatum(c_indices.sp_array), options))
10451025

10461026
return wrap_datum(out)
10471027

@@ -1091,8 +1071,8 @@ cdef class Array(_PandasConvertible):
10911071
options = _convert_filter_option(null_selection_behavior)
10921072

10931073
with nogil:
1094-
check_status(FilterKernel(_context(), CDatum(self.sp_array),
1095-
CDatum(mask.sp_array), options, &out))
1074+
out = GetResultValue(FilterKernel(CDatum(self.sp_array),
1075+
CDatum(mask.sp_array), options))
10961076

10971077
return wrap_datum(out)
10981078

0 commit comments

Comments
 (0)