2525#include  " arrow/compute/kernels/sum.h" 
2626#include  " arrow/compute/test-util.h" 
2727#include  " arrow/type.h" 
28+ #include  " arrow/type_traits.h" 
29+ #include  " arrow/util/checked_cast.h" 
2830
2931#include  " arrow/testing/gtest_common.h" 
3032#include  " arrow/testing/gtest_util.h" 
@@ -36,47 +38,46 @@ using std::vector;
3638namespace  arrow  {
3739namespace  compute  {
3840
39- template  <typename  CType , typename  Enable = void >
41+ template  <typename  Type , typename  Enable = void >
4042struct  DatumEqual  {
4143  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {}
4244};
4345
44- template  <typename  CType>
45- struct  DatumEqual <CType,
46-                   typename  std::enable_if<std::is_floating_point<CType>::value>::type> {
46+ template  <typename  Type>
47+ struct  DatumEqual <Type, typename  std::enable_if<IsFloatingPoint<Type>::Value>::type> {
4748  static  constexpr  double  kArbitraryDoubleErrorBound  = 1.0 ;
49+   using  ScalarType = typename  TypeTraits<Type>::ScalarType;
4850
4951  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {
5052    ASSERT_EQ (lhs.kind (), rhs.kind ());
5153    if  (lhs.kind () == Datum::SCALAR) {
52-       ASSERT_EQ (lhs.scalar ().kind (), rhs.scalar ().kind ());
53-       ASSERT_NEAR (util::get<CType>(lhs.scalar ().value ),
54-                   util::get<CType>(rhs.scalar ().value ), kArbitraryDoubleErrorBound );
54+       auto  left = static_cast <const  ScalarType*>(lhs.scalar ().get ());
55+       auto  right = static_cast <const  ScalarType*>(rhs.scalar ().get ());
56+       ASSERT_EQ (left->type ->id (), right->type ->id ());
57+       ASSERT_NEAR (left->value , right->value , kArbitraryDoubleErrorBound );
5558    }
5659  }
5760};
5861
59- template  <typename  CType >
60- struct  DatumEqual <CType, 
61-                    typename  std::enable_if<!std::is_floating_point<CType >::value>::type> { 
62+ template  <typename  Type >
63+ struct  DatumEqual <Type,  typename  std::enable_if<!IsFloatingPoint<Type>::value>::type> { 
64+   using  ScalarType =  typename  TypeTraits<Type >::ScalarType; 
6265  static  void  EnsureEqual (const  Datum& lhs, const  Datum& rhs) {
6366    ASSERT_EQ (lhs.kind (), rhs.kind ());
6467    if  (lhs.kind () == Datum::SCALAR) {
65-       ASSERT_EQ (lhs.scalar ().kind (), rhs.scalar ().kind ());
66-       ASSERT_EQ (util::get<CType>(lhs.scalar ().value ),
67-                 util::get<CType>(rhs.scalar ().value ));
68+       auto  left = static_cast <const  ScalarType*>(lhs.scalar ().get ());
69+       auto  right = static_cast <const  ScalarType*>(rhs.scalar ().get ());
70+       ASSERT_EQ (left->type ->id (), right->type ->id ());
71+       ASSERT_EQ (left->value , right->value );
6872    }
6973  }
7074};
7175
7276template  <typename  ArrowType>
7377void  ValidateSum (FunctionContext* ctx, const  Array& input, Datum expected) {
74-   using  CType = typename  ArrowType::c_type;
75-   using  SumType = typename  FindAccumulatorType<CType>::Type;
76- 
7778  Datum result;
7879  ASSERT_OK (Sum (ctx, input, &result));
79-   DatumEqual<SumType >::EnsureEqual (result, expected);
80+   DatumEqual<ArrowType >::EnsureEqual (result, expected);
8081}
8182
8283template  <typename  ArrowType>
@@ -87,11 +88,11 @@ void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
8788
8889template  <typename  ArrowType>
8990static  Datum DummySum (const  Array& array) {
90-   using  CType = typename  ArrowType::c_type;
9191  using  ArrayType = typename  TypeTraits<ArrowType>::ArrayType;
92-   using  SumType = typename  FindAccumulatorType<CType>::Type;
92+   using  SumType = typename  FindAccumulatorType<ArrowType>::Type;
93+   using  SumScalarType = typename  TypeTraits<SumType>::ScalarType;
9394
94-   SumType sum = 0 ;
95+   typename   SumType::c_type  sum = 0 ;
9596  int64_t  count = 0 ;
9697
9798  const  auto & array_numeric = reinterpret_cast <const  ArrayType&>(array);
@@ -104,7 +105,11 @@ static Datum DummySum(const Array& array) {
104105    }
105106  }
106107
107-   return  (count > 0 ) ? Datum (Scalar (sum)) : Datum ();
108+   if  (count > 0 ) {
109+     return  Datum (std::make_shared<SumScalarType>(sum));
110+   } else  {
111+     return  Datum (std::make_shared<SumScalarType>(0 , false ));
112+   }
108113}
109114
110115template  <typename  ArrowType>
@@ -115,24 +120,23 @@ void ValidateSum(FunctionContext* ctx, const Array& array) {
115120template  <typename  ArrowType>
116121class  TestSumKernelNumeric  : public  ComputeFixture , public  TestBase  {};
117122
118- typedef  ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
119-                          UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>
120-     NumericArrowTypes;
121- 
122123TYPED_TEST_CASE (TestSumKernelNumeric, NumericArrowTypes);
123124TYPED_TEST (TestSumKernelNumeric, SimpleSum) {
124-   using  CType = typename  TypeParam::c_type;
125-   using  SumType = typename  FindAccumulatorType<CType>::Type;
125+   using  SumType = typename  FindAccumulatorType<TypeParam>::Type;
126+   using  ScalarType = typename  TypeTraits<SumType>::ScalarType;
127+   using  T = typename  TypeParam::c_type;
126128
127-   ValidateSum<TypeParam>(&this ->ctx_ , " []" Datum ());
129+   ValidateSum<TypeParam>(&this ->ctx_ , " []" 
130+                          Datum (std::make_shared<ScalarType>(0 , false )));
128131
129132  ValidateSum<TypeParam>(&this ->ctx_ , " [0, 1, 2, 3, 4, 5]" 
130-                          Datum (Scalar (static_cast <SumType >(5  * 6  / 2 ))));
133+                          Datum (std::make_shared<ScalarType> (static_cast <T >(5  * 6  / 2 ))));
131134
132135  //  Avoid this tests for (U)Int8Type
133-   if  (sizeof (CType ) > 1 )
136+   if  (sizeof (typename  TypeParam::c_type ) > 1 ) { 
134137    ValidateSum<TypeParam>(&this ->ctx_ , " [1000, null, 300, null, 30, null, 7]" 
135-                            Datum (Scalar (static_cast <SumType>(1337 ))));
138+                            Datum (std::make_shared<ScalarType>(static_cast <T>(1337 ))));
139+   }
136140}
137141
138142template  <typename  ArrowType>
0 commit comments