1717 */
1818package org .apache .beam .dsls .sql .transform ;
1919
20+ import java .math .BigDecimal ;
2021import java .sql .Types ;
2122import java .util .Arrays ;
2223import java .util .List ;
2627import org .apache .beam .dsls .sql .utils .CalciteUtils ;
2728
2829/**
29- * Build -in aggregations functions for COUNT/MAX/MIN/SUM/AVG.
30+ * Built -in aggregations functions for COUNT/MAX/MIN/SUM/AVG.
3031 */
3132class BeamBuiltinAggregations {
3233 /**
33- * Build -in aggregation for COUNT.
34+ * Built -in aggregation for COUNT.
3435 */
3536 public static class Count <T > extends BeamSqlUdaf <T , Long > {
3637 private BeamSqlRecordType accType ;
@@ -65,7 +66,7 @@ public Long result(BeamSqlRow accumulator) {
6566 }
6667
6768 /**
68- * Build -in aggregation for MAX.
69+ * Built -in aggregation for MAX.
6970 */
7071 public static class Max <T extends Comparable <T >> extends BeamSqlUdaf <T , T > {
7172 private BeamSqlRecordType accType ;
@@ -106,7 +107,7 @@ public T result(BeamSqlRow accumulator) {
106107 }
107108
108109 /**
109- * Build -in aggregation for MIN.
110+ * Built -in aggregation for MIN.
110111 */
111112 public static class Min <T extends Comparable <T >> extends BeamSqlUdaf <T , T > {
112113 private BeamSqlRecordType accType ;
@@ -147,12 +148,12 @@ public T result(BeamSqlRow accumulator) {
147148 }
148149
149150 /**
150- * Build -in aggregation for SUM.
151+ * Built -in aggregation for SUM.
151152 */
152153 public static class Sum <T > extends BeamSqlUdaf <T , T > {
153154 private static List <Integer > supportedType = Arrays .asList (Types .INTEGER ,
154155 Types .BIGINT , Types .SMALLINT , Types .TINYINT , Types .DOUBLE ,
155- Types .FLOAT );
156+ Types .FLOAT , Types . DECIMAL );
156157
157158 private int outputFieldType ;
158159 private BeamSqlRecordType accType ;
@@ -165,68 +166,69 @@ public Sum(int outputFieldType) {
165166
166167 this .outputFieldType = outputFieldType ;
167168 this .accType = BeamSqlRecordType .create (Arrays .asList ("__sum" ),
168- Arrays .asList (Types .DOUBLE )); //by default use DOUBLE to store the value.
169+ Arrays .asList (Types .DECIMAL )); //by default use DOUBLE to store the value.
169170 }
170171
171172 @ Override
172173 public BeamSqlRow init () {
173- return new BeamSqlRow (accType , Arrays .<Object >asList (0.0 ));
174+ return new BeamSqlRow (accType , Arrays .<Object >asList (new BigDecimal ( 0 ) ));
174175 }
175176
176177 @ Override
177178 public BeamSqlRow add (BeamSqlRow accumulator , T input ) {
178- return new BeamSqlRow (accType , Arrays .<Object >asList (accumulator .getDouble (0 )
179- + Double . valueOf ( input .toString ())));
179+ return new BeamSqlRow (accType , Arrays .<Object >asList (accumulator .getBigDecimal (0 )
180+ . add ( new BigDecimal ( input .toString () ))));
180181 }
181182
182183 @ Override
183184 public BeamSqlRow merge (Iterable <BeamSqlRow > accumulators ) {
184- double v = 0.0 ;
185+ BigDecimal v = new BigDecimal ( 0 ) ;
185186 while (accumulators .iterator ().hasNext ()) {
186- v += accumulators .iterator ().next ().getDouble ( 0 );
187+ v . add ( accumulators .iterator ().next ().getBigDecimal ( 0 ) );
187188 }
188189 return new BeamSqlRow (accType , Arrays .<Object >asList (v ));
189190 }
190191
191192 @ Override
192193 public T result (BeamSqlRow accumulator ) {
193- BeamSqlRow result = new BeamSqlRow (
194- BeamSqlRecordType .create (Arrays .asList ("__sum" ), Arrays .asList (outputFieldType )));
194+ Object result = null ;
195195 switch (outputFieldType ) {
196196 case Types .INTEGER :
197- result . addField ( 0 , ( int ) accumulator .getDouble (0 ));
197+ result = accumulator .getBigDecimal (0 ). intValue ( );
198198 break ;
199199 case Types .BIGINT :
200- result . addField ( 0 , ( long ) accumulator .getDouble (0 ));
200+ result = accumulator .getBigDecimal (0 ). longValue ( );
201201 break ;
202202 case Types .SMALLINT :
203- result . addField ( 0 , ( short ) accumulator .getDouble (0 ));
203+ result = accumulator .getBigDecimal (0 ). shortValue ( );
204204 break ;
205205 case Types .TINYINT :
206- result . addField ( 0 , ( byte ) accumulator .getDouble (0 ));
206+ result = accumulator .getBigDecimal (0 ). byteValue ( );
207207 break ;
208208 case Types .DOUBLE :
209- result . addField ( 0 , accumulator .getDouble (0 ));
209+ result = accumulator .getBigDecimal (0 ). doubleValue ( );
210210 break ;
211211 case Types .FLOAT :
212- result .addField (0 , (float ) accumulator .getDouble (0 ));
212+ result = accumulator .getBigDecimal (0 ).floatValue ();
213+ break ;
214+ case Types .DECIMAL :
215+ result = accumulator .getBigDecimal (0 );
213216 break ;
214-
215217 default :
216218 break ;
217219 }
218- return (T ) result . getFieldValue ( 0 ) ;
220+ return (T ) result ;
219221 }
220222
221223 }
222224
223225 /**
224- * Build -in aggregation for AVG.
226+ * Built -in aggregation for AVG.
225227 */
226228 public static class Avg <T > extends BeamSqlUdaf <T , T > {
227229 private static List <Integer > supportedType = Arrays .asList (Types .INTEGER ,
228230 Types .BIGINT , Types .SMALLINT , Types .TINYINT , Types .DOUBLE ,
229- Types .FLOAT );
231+ Types .FLOAT , Types . DECIMAL );
230232
231233 private int outputFieldType ;
232234 private BeamSqlRecordType accType ;
@@ -239,63 +241,65 @@ public Avg(int outputFieldType) {
239241
240242 this .outputFieldType = outputFieldType ;
241243 this .accType = BeamSqlRecordType .create (Arrays .asList ("__sum" , "size" ),
242- Arrays .asList (Types .DOUBLE , Types .BIGINT )); //by default use DOUBLE to store the value.
244+ Arrays .asList (Types .DECIMAL , Types .BIGINT )); //by default use DOUBLE to store the value.
243245 }
244246
245247 @ Override
246248 public BeamSqlRow init () {
247- return new BeamSqlRow (accType , Arrays .<Object >asList (0.0 , 0L ));
249+ return new BeamSqlRow (accType , Arrays .<Object >asList (new BigDecimal ( 0 ) , 0L ));
248250 }
249251
250252 @ Override
251253 public BeamSqlRow add (BeamSqlRow accumulator , T input ) {
252254 return new BeamSqlRow (accType ,
253255 Arrays .<Object >asList (
254- accumulator .getDouble (0 )
255- + Double .valueOf (input .toString ()),
256+ accumulator .getBigDecimal (0 ).add (new BigDecimal (input .toString ())),
256257 accumulator .getLong (1 ) + 1 ));
257258 }
258259
259260 @ Override
260261 public BeamSqlRow merge (Iterable <BeamSqlRow > accumulators ) {
261- double v = 0.0 ;
262+ BigDecimal v = new BigDecimal ( 0 ) ;
262263 long s = 0 ;
263264 while (accumulators .iterator ().hasNext ()) {
264265 BeamSqlRow r = accumulators .iterator ().next ();
265- v += r . getDouble ( 0 );
266+ v . add ( r . getBigDecimal ( 0 ) );
266267 s += r .getLong (1 );
267268 }
268269 return new BeamSqlRow (accType , Arrays .<Object >asList (v , s ));
269270 }
270271
271272 @ Override
272273 public T result (BeamSqlRow accumulator ) {
273- BeamSqlRow result = new BeamSqlRow (
274- BeamSqlRecordType .create (Arrays .asList ("__avg" ), Arrays .asList (outputFieldType )));
274+ Object result = null ;
275+ BigDecimal decimalAvg = accumulator .getBigDecimal (0 ).divide (
276+ new BigDecimal (accumulator .getLong (1 )));
275277 switch (outputFieldType ) {
276278 case Types .INTEGER :
277- result . addField ( 0 , ( int ) ( accumulator . getDouble ( 0 ) / accumulator . getLong ( 1 )) );
279+ result = decimalAvg . intValue ( );
278280 break ;
279281 case Types .BIGINT :
280- result . addField ( 0 , ( long ) ( accumulator . getDouble ( 0 ) / accumulator . getLong ( 1 )) );
282+ result = decimalAvg . longValue ( );
281283 break ;
282284 case Types .SMALLINT :
283- result . addField ( 0 , ( short ) ( accumulator . getDouble ( 0 ) / accumulator . getLong ( 1 )) );
285+ result = decimalAvg . shortValue ( );
284286 break ;
285287 case Types .TINYINT :
286- result . addField ( 0 , ( byte ) ( accumulator . getDouble ( 0 ) / accumulator . getLong ( 1 )) );
288+ result = decimalAvg . byteValue ( );
287289 break ;
288290 case Types .DOUBLE :
289- result . addField ( 0 , accumulator . getDouble ( 0 ) / accumulator . getLong ( 1 ) );
291+ result = decimalAvg . doubleValue ( );
290292 break ;
291293 case Types .FLOAT :
292- result .addField (0 , (float ) (accumulator .getDouble (0 ) / accumulator .getLong (1 )));
294+ result = decimalAvg .floatValue ();
295+ break ;
296+ case Types .DECIMAL :
297+ result = decimalAvg ;
293298 break ;
294-
295299 default :
296300 break ;
297301 }
298- return (T ) result . getFieldValue ( 0 ) ;
302+ return (T ) result ;
299303 }
300304
301305 }
0 commit comments