2525#include  " arrow/compute/kernels/sum.h" 
2626#include  " arrow/compute/test-util.h" 
2727#include  " arrow/type.h" 
28+ #include  " arrow/type_traits.h" 
29+ #include  " arrow/util/checked_cast.h" 
2830
2931#include  " arrow/testing/gtest_common.h" 
3032#include  " arrow/testing/gtest_util.h" 
@@ -36,47 +38,47 @@ using std::vector;
3638namespace  arrow  {
3739namespace  compute  {
3840
39- template  <typename  CType , typename  Enable = void >
41+ template  <typename  Type , typename  Enable = void >
4042struct  DatumEqual  {
4143  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {}
4244};
4345
44- template  <typename  CType>
45- struct  DatumEqual <CType,
46-                   typename  std::enable_if<std::is_floating_point<CType>::value>::type> {
46+ template  <typename  Type>
47+ struct  DatumEqual <Type, typename  std::enable_if<IsFloatingPoint<Type>::Value>::type> {
4748  static  constexpr  double  kArbitraryDoubleErrorBound  = 1.0 ;
49+   using  ScalarType = typename  TypeTraits<Type>::ScalarType;
4850
4951  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {
5052    ASSERT_EQ (lhs.kind (), rhs.kind ());
5153    if  (lhs.kind () == Datum::SCALAR) {
52-       ASSERT_EQ (lhs.scalar ().kind (), rhs.scalar ().kind ());
53-       ASSERT_NEAR (util::get<CType>(lhs.scalar ().value ),
54-                   util::get<CType>(rhs.scalar ().value ), kArbitraryDoubleErrorBound );
54+       auto  left = static_cast <const  ScalarType*>(lhs.scalar ().get ());
55+       auto  right = static_cast <const  ScalarType*>(rhs.scalar ().get ());
56+       ASSERT_EQ (left->type ->id (), right->type ->id ());
57+       ASSERT_NEAR (left->value , right->value , kArbitraryDoubleErrorBound );
5558    }
5659  }
5760};
5861
59- template  <typename  CType >
60- struct  DatumEqual <CType, 
61-                    typename  std::enable_if<!std::is_floating_point<CType >::value>::type> { 
62+ template  <typename  Type >
63+ struct  DatumEqual <Type,  typename  std::enable_if<!IsFloatingPoint<Type>::value>::type> { 
64+   using  ScalarType =  typename  TypeTraits<Type >::ScalarType; 
6265  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {
6366    ASSERT_EQ (lhs.kind (), rhs.kind ());
6467    if  (lhs.kind () == Datum::SCALAR) {
65-       ASSERT_EQ (lhs.scalar ().kind (), rhs.scalar ().kind ());
66-       ASSERT_EQ (util::get<CType>(lhs.scalar ().value ),
67-                 util::get<CType>(rhs.scalar ().value ));
68+       auto  left = static_cast <const  ScalarType*>(lhs.scalar ().get ());
69+       auto  right = static_cast <const  ScalarType*>(rhs.scalar ().get ());
70+       ASSERT_EQ (left->type ->id (), right->type ->id ());
71+       ASSERT_EQ (left->value , right->value );
6872    }
6973  }
7074};
7175
7276template  <typename  ArrowType>
7377void  ValidateSum (FunctionContext* ctx, const  Array& input, Datum expected) {
74-   using  CType = typename  ArrowType::c_type;
75-   using  SumType = typename  FindAccumulatorType<CType>::Type;
76- 
78+   using  OutputType = typename  FindAccumulatorType<ArrowType>::Type;
7779  Datum result;
7880  ASSERT_OK (Sum (ctx, input, &result));
79-   DatumEqual<SumType >::EnsureEqual (result, expected);
81+   DatumEqual<OutputType >::EnsureEqual (result, expected);
8082}
8183
8284template  <typename  ArrowType>
@@ -87,11 +89,11 @@ void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
8789
8890template  <typename  ArrowType>
8991static  Datum DummySum (const  Array& array) {
90-   using  CType = typename  ArrowType::c_type;
9192  using  ArrayType = typename  TypeTraits<ArrowType>::ArrayType;
92-   using  SumType = typename  FindAccumulatorType<CType>::Type;
93+   using  SumType = typename  FindAccumulatorType<ArrowType>::Type;
94+   using  SumScalarType = typename  TypeTraits<SumType>::ScalarType;
9395
94-   SumType sum = 0 ;
96+   typename   SumType::c_type  sum = 0 ;
9597  int64_t  count = 0 ;
9698
9799  const  auto & array_numeric = reinterpret_cast <const  ArrayType&>(array);
@@ -104,7 +106,11 @@ static Datum DummySum(const Array& array) {
104106    }
105107  }
106108
107-   return  (count > 0 ) ? Datum (Scalar (sum)) : Datum ();
109+   if  (count > 0 ) {
110+     return  Datum (std::make_shared<SumScalarType>(sum));
111+   } else  {
112+     return  Datum (std::make_shared<SumScalarType>(0 , false ));
113+   }
108114}
109115
110116template  <typename  ArrowType>
@@ -115,24 +121,21 @@ void ValidateSum(FunctionContext* ctx, const Array& array) {
115121template  <typename  ArrowType>
116122class  TestSumKernelNumeric  : public  ComputeFixture , public  TestBase  {};
117123
118- typedef  ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
119-                          UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>
120-     NumericArrowTypes;
121- 
122124TYPED_TEST_CASE (TestSumKernelNumeric, NumericArrowTypes);
123125TYPED_TEST (TestSumKernelNumeric, SimpleSum) {
124-   using  CType = typename  TypeParam::c_type;
125-   using  SumType = typename  FindAccumulatorType<CType>::Type;
126+   using  SumType = typename  FindAccumulatorType<TypeParam>::Type;
127+   using  ScalarType = typename  TypeTraits<SumType>::ScalarType;
128+   using  T = typename  TypeParam::c_type;
126129
127-   ValidateSum<TypeParam>(&this ->ctx_ , " []" Datum ());
130+   ValidateSum<TypeParam>(&this ->ctx_ , " []" 
131+                          Datum (std::make_shared<ScalarType>(0 , false )));
128132
129133  ValidateSum<TypeParam>(&this ->ctx_ , " [0, 1, 2, 3, 4, 5]" 
130-                          Datum (Scalar (static_cast <SumType >(5  * 6  / 2 ))));
134+                          Datum (std::make_shared<ScalarType> (static_cast <T >(5  * 6  / 2 ))));
131135
132-   //  Avoid this tests for (U)Int8Type
133-   if  (sizeof (CType) > 1 )
134-     ValidateSum<TypeParam>(&this ->ctx_ , " [1000, null, 300, null, 30, null, 7]" 
135-                            Datum (Scalar (static_cast <SumType>(1337 ))));
136+   const  T expected_result = static_cast <T>(14 );
137+   ValidateSum<TypeParam>(&this ->ctx_ , " [1, null, 3, null, 3, null, 7]" 
138+                          Datum (std::make_shared<ScalarType>(expected_result)));
136139}
137140
138141template  <typename  ArrowType>
0 commit comments