Skip to content

Commit d922bd2

Browse files
committed
Use new Scalar objects in aggregation code
Change-Id: I87f690898c7af7b028f2e50247e96703ede4f457
1 parent fa89bd0 commit d922bd2

File tree

12 files changed

+193
-182
lines changed

12 files changed

+193
-182
lines changed

cpp/build-support/run_cpplint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def _check_some_files(completed_processes, filenames):
112112
if problem_files:
113113
msg = "{} had cpplint issues"
114114
print("\n".join(map(msg.format, problem_files)))
115+
if isinstance(stdout, bytes):
116+
stdout = stdout.decode('utf8')
115117
print(stdout, file=sys.stderr)
116118
error = True
117119
except Exception:

cpp/src/arrow/compute/compute-test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,16 @@ void CheckImplicitConstructor(enum Datum::type expected_kind) {
5959
}
6060

6161
TEST(TestDatum, ImplicitConstructors) {
62+
CheckImplicitConstructor<Scalar>(Datum::SCALAR);
63+
6264
CheckImplicitConstructor<Array>(Datum::ARRAY);
6365

6466
// Instantiate from array subclass
6567
CheckImplicitConstructor<BinaryArray>(Datum::ARRAY);
6668

6769
CheckImplicitConstructor<ChunkedArray>(Datum::CHUNKED_ARRAY);
6870
CheckImplicitConstructor<RecordBatch>(Datum::RECORD_BATCH);
71+
6972
CheckImplicitConstructor<Table>(Datum::TABLE);
7073
}
7174

cpp/src/arrow/compute/kernel.h

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "arrow/array.h"
2626
#include "arrow/record_batch.h"
27+
#include "arrow/scalar.h"
2728
#include "arrow/table.h"
2829
#include "arrow/util/macros.h"
2930
#include "arrow/util/variant.h" // IWYU pragma: export
@@ -55,68 +56,20 @@ class ARROW_EXPORT OpKernel {
5556
virtual ~OpKernel() = default;
5657
};
5758

58-
/// \brief Placeholder for Scalar values until we implement these
59-
struct ARROW_EXPORT Scalar {
60-
util::variant<bool, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t,
61-
int64_t, float, double>
62-
value;
63-
64-
explicit Scalar(bool value) : value(value) {}
65-
explicit Scalar(uint8_t value) : value(value) {}
66-
explicit Scalar(int8_t value) : value(value) {}
67-
explicit Scalar(uint16_t value) : value(value) {}
68-
explicit Scalar(int16_t value) : value(value) {}
69-
explicit Scalar(uint32_t value) : value(value) {}
70-
explicit Scalar(int32_t value) : value(value) {}
71-
explicit Scalar(uint64_t value) : value(value) {}
72-
explicit Scalar(int64_t value) : value(value) {}
73-
explicit Scalar(float value) : value(value) {}
74-
explicit Scalar(double value) : value(value) {}
75-
76-
Type::type kind() const {
77-
switch (this->value.which()) {
78-
case 0:
79-
return Type::BOOL;
80-
case 1:
81-
return Type::UINT8;
82-
case 2:
83-
return Type::INT8;
84-
case 3:
85-
return Type::UINT16;
86-
case 4:
87-
return Type::INT16;
88-
case 5:
89-
return Type::UINT32;
90-
case 6:
91-
return Type::INT32;
92-
case 7:
93-
return Type::UINT64;
94-
case 8:
95-
return Type::INT64;
96-
case 9:
97-
return Type::FLOAT;
98-
case 10:
99-
return Type::DOUBLE;
100-
default:
101-
return Type::NA;
102-
}
103-
}
104-
};
105-
10659
/// \class Datum
10760
/// \brief Variant type for various Arrow C++ data structures
10861
struct ARROW_EXPORT Datum {
10962
enum type { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION };
11063

111-
util::variant<decltype(NULLPTR), Scalar, std::shared_ptr<ArrayData>,
64+
util::variant<decltype(NULLPTR), std::shared_ptr<Scalar>, std::shared_ptr<ArrayData>,
11265
std::shared_ptr<ChunkedArray>, std::shared_ptr<RecordBatch>,
11366
std::shared_ptr<Table>, std::vector<Datum>>
11467
value;
11568

11669
/// \brief Empty datum, to be populated elsewhere
11770
Datum() : value(NULLPTR) {}
11871

119-
Datum(const Scalar& value) // NOLINT implicit conversion
72+
Datum(const std::shared_ptr<Scalar>& value) // NOLINT implicit conversion
12073
: value(value) {}
12174
Datum(const std::shared_ptr<ArrayData>& value) // NOLINT implicit conversion
12275
: value(value) {}
@@ -188,14 +141,18 @@ struct ARROW_EXPORT Datum {
188141
return util::get<std::vector<Datum>>(this->value);
189142
}
190143

191-
Scalar scalar() const { return util::get<Scalar>(this->value); }
144+
std::shared_ptr<Scalar> scalar() const {
145+
return util::get<std::shared_ptr<Scalar>>(this->value);
146+
}
192147

193148
bool is_array() const { return this->kind() == Datum::ARRAY; }
194149

195150
bool is_arraylike() const {
196151
return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY;
197152
}
198153

154+
bool is_scalar() const { return this->kind() == Datum::SCALAR; }
155+
199156
/// \brief The value type of the variant, if any
200157
///
201158
/// \return nullptr if no type

cpp/src/arrow/compute/kernels/aggregate-test.cc

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "arrow/compute/kernels/sum.h"
2626
#include "arrow/compute/test-util.h"
2727
#include "arrow/type.h"
28+
#include "arrow/type_traits.h"
29+
#include "arrow/util/checked_cast.h"
2830

2931
#include "arrow/testing/gtest_common.h"
3032
#include "arrow/testing/gtest_util.h"
@@ -36,47 +38,46 @@ using std::vector;
3638
namespace arrow {
3739
namespace compute {
3840

39-
template <typename CType, typename Enable = void>
41+
template <typename Type, typename Enable = void>
4042
struct DatumEqual {
4143
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {}
4244
};
4345

44-
template <typename CType>
45-
struct DatumEqual<CType,
46-
typename std::enable_if<std::is_floating_point<CType>::value>::type> {
46+
template <typename Type>
47+
struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::value>::type> {
4748
static constexpr double kArbitraryDoubleErrorBound = 1.0;
49+
using ScalarType = typename TypeTraits<Type>::ScalarType;
4850

4951
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
5052
ASSERT_EQ(lhs.kind(), rhs.kind());
5153
if (lhs.kind() == Datum::SCALAR) {
52-
ASSERT_EQ(lhs.scalar().kind(), rhs.scalar().kind());
53-
ASSERT_NEAR(util::get<CType>(lhs.scalar().value),
54-
util::get<CType>(rhs.scalar().value), kArbitraryDoubleErrorBound);
54+
auto left = static_cast<const ScalarType*>(lhs.scalar().get());
55+
auto right = static_cast<const ScalarType*>(rhs.scalar().get());
56+
ASSERT_EQ(left->type->id(), right->type->id());
57+
ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
5558
}
5659
}
5760
};
5861

59-
template <typename CType>
60-
struct DatumEqual<CType,
61-
typename std::enable_if<!std::is_floating_point<CType>::value>::type> {
62+
template <typename Type>
63+
struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
64+
using ScalarType = typename TypeTraits<Type>::ScalarType;
6265
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
6366
ASSERT_EQ(lhs.kind(), rhs.kind());
6467
if (lhs.kind() == Datum::SCALAR) {
65-
ASSERT_EQ(lhs.scalar().kind(), rhs.scalar().kind());
66-
ASSERT_EQ(util::get<CType>(lhs.scalar().value),
67-
util::get<CType>(rhs.scalar().value));
68+
auto left = static_cast<const ScalarType*>(lhs.scalar().get());
69+
auto right = static_cast<const ScalarType*>(rhs.scalar().get());
70+
ASSERT_EQ(left->type->id(), right->type->id());
71+
ASSERT_EQ(left->value, right->value);
6872
}
6973
}
7074
};
7175

7276
template <typename ArrowType>
7377
void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
74-
using CType = typename ArrowType::c_type;
75-
using SumType = typename FindAccumulatorType<CType>::Type;
76-
7778
Datum result;
7879
ASSERT_OK(Sum(ctx, input, &result));
79-
DatumEqual<SumType>::EnsureEqual(result, expected);
80+
DatumEqual<ArrowType>::EnsureEqual(result, expected);
8081
}
8182

8283
template <typename ArrowType>
@@ -87,11 +88,11 @@ void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
8788

8889
template <typename ArrowType>
8990
static Datum DummySum(const Array& array) {
90-
using CType = typename ArrowType::c_type;
9191
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
92-
using SumType = typename FindAccumulatorType<CType>::Type;
92+
using SumType = typename FindAccumulatorType<ArrowType>::Type;
93+
using SumScalarType = typename TypeTraits<SumType>::ScalarType;
9394

94-
SumType sum = 0;
95+
typename SumType::c_type sum = 0;
9596
int64_t count = 0;
9697

9798
const auto& array_numeric = reinterpret_cast<const ArrayType&>(array);
@@ -104,7 +105,11 @@ static Datum DummySum(const Array& array) {
104105
}
105106
}
106107

107-
return (count > 0) ? Datum(Scalar(sum)) : Datum();
108+
if (count > 0) {
109+
return Datum(std::make_shared<SumScalarType>(sum));
110+
} else {
111+
return Datum(std::make_shared<SumScalarType>(0, false));
112+
}
108113
}
109114

110115
template <typename ArrowType>
@@ -115,24 +120,23 @@ void ValidateSum(FunctionContext* ctx, const Array& array) {
115120
template <typename ArrowType>
116121
class TestSumKernelNumeric : public ComputeFixture, public TestBase {};
117122

118-
typedef ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
119-
UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>
120-
NumericArrowTypes;
121-
122123
TYPED_TEST_CASE(TestSumKernelNumeric, NumericArrowTypes);
123124
TYPED_TEST(TestSumKernelNumeric, SimpleSum) {
124-
using CType = typename TypeParam::c_type;
125-
using SumType = typename FindAccumulatorType<CType>::Type;
125+
using SumType = typename FindAccumulatorType<TypeParam>::Type;
126+
using ScalarType = typename TypeTraits<SumType>::ScalarType;
127+
using T = typename TypeParam::c_type;
126128

127-
ValidateSum<TypeParam>(&this->ctx_, "[]", Datum());
129+
ValidateSum<TypeParam>(&this->ctx_, "[]",
130+
Datum(std::make_shared<ScalarType>(0, false)));
128131

129132
ValidateSum<TypeParam>(&this->ctx_, "[0, 1, 2, 3, 4, 5]",
130-
Datum(Scalar(static_cast<SumType>(5 * 6 / 2))));
133+
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
131134

132135
// Avoid this test for (U)Int8Type
133-
if (sizeof(CType) > 1)
136+
if (sizeof(typename TypeParam::c_type) > 1) {
134137
ValidateSum<TypeParam>(&this->ctx_, "[1000, null, 300, null, 30, null, 7]",
135-
Datum(Scalar(static_cast<SumType>(1337))));
138+
Datum(std::make_shared<ScalarType>(static_cast<T>(1337))));
139+
}
136140
}
137141

138142
template <typename ArrowType>

cpp/src/arrow/compute/kernels/sum.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
namespace arrow {
2929
namespace compute {
3030

31-
template <typename CType, typename SumType = typename FindAccumulatorType<CType>::Type>
31+
template <typename ArrowType,
32+
typename SumType = typename FindAccumulatorType<ArrowType>::Type>
3233
struct SumState {
33-
using ThisType = SumState<CType, SumType>;
34+
using ThisType = SumState<ArrowType, SumType>;
3435

3536
ThisType operator+(const ThisType& rhs) const {
3637
return ThisType(this->count + rhs.count, this->sum + rhs.sum);
@@ -43,11 +44,16 @@ struct SumState {
4344
return *this;
4445
}
4546

47+
std::shared_ptr<Scalar> AsScalar() const {
48+
using ScalarType = typename TypeTraits<SumType>::ScalarType;
49+
return std::make_shared<ScalarType>(this->sum);
50+
}
51+
4652
size_t count = 0;
47-
SumType sum = 0;
53+
typename SumType::c_type sum = 0;
4854
};
4955

50-
template <typename ArrowType, typename StateType = SumState<typename ArrowType::c_type>>
56+
template <typename ArrowType, typename StateType = SumState<ArrowType>>
5157
class SumAggregateFunction final : public AggregateFunctionStaticState<StateType> {
5258
using CType = typename TypeTraits<ArrowType>::CType;
5359
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
@@ -71,7 +77,12 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
7177
}
7278

7379
Status Finalize(const StateType& src, Datum* output) const override {
74-
*output = (src.count > 0) ? Datum(Scalar(src.sum)) : Datum();
80+
auto boxed = src.AsScalar();
81+
if (src.count == 0) {
82+
// TODO(wesm): Currently null, but fix this
83+
boxed->is_valid = false;
84+
}
85+
*output = boxed;
7586
return Status::OK();
7687
}
7788

@@ -185,7 +196,7 @@ Status Sum(FunctionContext* ctx, const Datum& value, Datum* out) {
185196
}
186197

187198
Status Sum(FunctionContext* ctx, const Array& array, Datum* out) {
188-
return Sum(ctx, Datum(array.data()), out);
199+
return Sum(ctx, array.data(), out);
189200
}
190201

191202
} // namespace compute

cpp/src/arrow/compute/kernels/sum.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include <type_traits>
2323

2424
#include "arrow/status.h"
25+
#include "arrow/type.h"
26+
#include "arrow/type_traits.h"
2527
#include "arrow/util/visibility.h"
2628

2729
namespace arrow {
@@ -34,25 +36,22 @@ namespace compute {
3436
// Find the largest compatible primitive type for a primitive type.
3537
template <typename I, typename Enable = void>
3638
struct FindAccumulatorType {
37-
using Type = double;
39+
using Type = DoubleType;
3840
};
3941

4042
template <typename I>
41-
struct FindAccumulatorType<I, typename std::enable_if<std::is_integral<I>::value &&
42-
std::is_signed<I>::value>::type> {
43-
using Type = int64_t;
43+
struct FindAccumulatorType<I, typename std::enable_if<IsSignedInt<I>::value>::type> {
44+
using Type = Int64Type;
4445
};
4546

4647
template <typename I>
47-
struct FindAccumulatorType<I, typename std::enable_if<std::is_integral<I>::value &&
48-
std::is_unsigned<I>::value>::type> {
49-
using Type = uint64_t;
48+
struct FindAccumulatorType<I, typename std::enable_if<IsUnsignedInt<I>::value>::type> {
49+
using Type = UInt64Type;
5050
};
5151

5252
template <typename I>
53-
struct FindAccumulatorType<
54-
I, typename std::enable_if<std::is_floating_point<I>::value>::type> {
55-
using Type = double;
53+
struct FindAccumulatorType<I, typename std::enable_if<IsFloatingPoint<I>::value>::type> {
54+
using Type = DoubleType;
5655
};
5756

5857
struct Datum;

0 commit comments

Comments
 (0)