Skip to content

Commit d42581d

Browse files
author
mingmxu
committed
1. support DECIMAL in built-in aggregators;
2. add JavaDoc for BeamSqlUdaf;
1 parent 0fc4724 commit d42581d

File tree

3 files changed

+75
-41
lines changed

3 files changed

+75
-41
lines changed

dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,41 @@
1818
package org.apache.beam.dsls.sql.schema;
1919

2020
import java.io.Serializable;
21+
import org.apache.beam.sdk.transforms.Combine.CombineFn;
2122

2223
/**
2324
* abstract class of aggregation functions in Beam SQL.
2425
*
2526
* <p>There are several constraints for a UDAF:<br>
2627
* 1. A constructor with an empty argument list is required;<br>
2728
* 2. The type of {@code InputT} and {@code OutputT} can only be Integer/Long/Short/Byte/Double
28-
* /Float/Date, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINE/DOUBLE/FLOAT/TIMESTAMP;<br>
29+
* /Float/Date/BigDecimal, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINT/DOUBLE/FLOAT
30+
* /TIMESTAMP/DECIMAL;<br>
2931
* 3. wrap intermediate data in a {@link BeamSqlRow}, and do not rely on elements in class;<br>
32+
* 4. The intermediate value of UDAF function is stored in a {@code BeamSqlRow} object.<br>
3033
*/
3134
public abstract class BeamSqlUdaf<InputT, OutputT> implements Serializable {
3235
public BeamSqlUdaf(){}
3336

37+
/**
38+
* create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}.
39+
*/
3440
public abstract BeamSqlRow init();
3541

42+
/**
43+
* add an input value, equals to {@link CombineFn#addInput(Object, Object)}.
44+
*/
3645
public abstract BeamSqlRow add(BeamSqlRow accumulator, InputT input);
3746

47+
/**
48+
* merge aggregation objects from parallel tasks, equals to
49+
* {@link CombineFn#mergeAccumulators(Iterable)}.
50+
*/
3851
public abstract BeamSqlRow merge(Iterable<BeamSqlRow> accumulators);
3952

53+
/**
54+
* extract output value from aggregation object, equals to
55+
* {@link CombineFn#extractOutput(Object)}.
56+
*/
4057
public abstract OutputT result(BeamSqlRow accumulator);
4158
}

dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.beam.dsls.sql.transform;
1919

2020
import java.io.Serializable;
21+
import java.math.BigDecimal;
2122
import java.util.ArrayList;
2223
import java.util.Date;
2324
import java.util.List;
@@ -184,6 +185,9 @@ public AggregationAdaptor(List<AggregateCall> aggregationCalls,
184185
case TIMESTAMP:
185186
aggregators.add(new BeamBuiltinAggregations.Max<Date>(outFieldType));
186187
break;
188+
case DECIMAL:
189+
aggregators.add(new BeamBuiltinAggregations.Max<BigDecimal>(outFieldType));
190+
break;
187191
default:
188192
throw new UnsupportedOperationException();
189193
}
@@ -211,6 +215,9 @@ public AggregationAdaptor(List<AggregateCall> aggregationCalls,
211215
case TIMESTAMP:
212216
aggregators.add(new BeamBuiltinAggregations.Min<Date>(outFieldType));
213217
break;
218+
case DECIMAL:
219+
aggregators.add(new BeamBuiltinAggregations.Min<BigDecimal>(outFieldType));
220+
break;
214221
default:
215222
throw new UnsupportedOperationException();
216223
}
@@ -235,6 +242,9 @@ public AggregationAdaptor(List<AggregateCall> aggregationCalls,
235242
case DOUBLE:
236243
aggregators.add(new BeamBuiltinAggregations.Sum<Double>(outFieldType));
237244
break;
245+
case DECIMAL:
246+
aggregators.add(new BeamBuiltinAggregations.Sum<BigDecimal>(outFieldType));
247+
break;
238248
default:
239249
throw new UnsupportedOperationException();
240250
}
@@ -259,6 +269,9 @@ public AggregationAdaptor(List<AggregateCall> aggregationCalls,
259269
case DOUBLE:
260270
aggregators.add(new BeamBuiltinAggregations.Avg<Double>(outFieldType));
261271
break;
272+
case DECIMAL:
273+
aggregators.add(new BeamBuiltinAggregations.Avg<BigDecimal>(outFieldType));
274+
break;
262275
default:
263276
throw new UnsupportedOperationException();
264277
}

dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.dsls.sql.transform;
1919

20+
import java.math.BigDecimal;
2021
import java.sql.Types;
2122
import java.util.Arrays;
2223
import java.util.List;
@@ -26,11 +27,11 @@
2627
import org.apache.beam.dsls.sql.utils.CalciteUtils;
2728

2829
/**
29-
* Build-in aggregations functions for COUNT/MAX/MIN/SUM/AVG.
30+
* Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG.
3031
*/
3132
class BeamBuiltinAggregations {
3233
/**
33-
* Build-in aggregation for COUNT.
34+
* Built-in aggregation for COUNT.
3435
*/
3536
public static class Count<T> extends BeamSqlUdaf<T, Long> {
3637
private BeamSqlRecordType accType;
@@ -65,7 +66,7 @@ public Long result(BeamSqlRow accumulator) {
6566
}
6667

6768
/**
68-
* Build-in aggregation for MAX.
69+
* Built-in aggregation for MAX.
6970
*/
7071
public static class Max<T extends Comparable<T>> extends BeamSqlUdaf<T, T> {
7172
private BeamSqlRecordType accType;
@@ -106,7 +107,7 @@ public T result(BeamSqlRow accumulator) {
106107
}
107108

108109
/**
109-
* Build-in aggregation for MIN.
110+
* Built-in aggregation for MIN.
110111
*/
111112
public static class Min<T extends Comparable<T>> extends BeamSqlUdaf<T, T> {
112113
private BeamSqlRecordType accType;
@@ -147,12 +148,12 @@ public T result(BeamSqlRow accumulator) {
147148
}
148149

149150
/**
150-
* Build-in aggregation for SUM.
151+
* Built-in aggregation for SUM.
151152
*/
152153
public static class Sum<T> extends BeamSqlUdaf<T, T> {
153154
private static List<Integer> supportedType = Arrays.asList(Types.INTEGER,
154155
Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE,
155-
Types.FLOAT);
156+
Types.FLOAT, Types.DECIMAL);
156157

157158
private int outputFieldType;
158159
private BeamSqlRecordType accType;
@@ -165,68 +166,69 @@ public Sum(int outputFieldType) {
165166

166167
this.outputFieldType = outputFieldType;
167168
this.accType = BeamSqlRecordType.create(Arrays.asList("__sum"),
168-
Arrays.asList(Types.DOUBLE)); //by default use DOUBLE to store the value.
169+
Arrays.asList(Types.DECIMAL)); //by default use DECIMAL to store the value.
169170
}
170171

171172
@Override
172173
public BeamSqlRow init() {
173-
return new BeamSqlRow(accType, Arrays.<Object>asList(0.0));
174+
return new BeamSqlRow(accType, Arrays.<Object>asList(new BigDecimal(0)));
174175
}
175176

176177
@Override
177178
public BeamSqlRow add(BeamSqlRow accumulator, T input) {
178-
return new BeamSqlRow(accType, Arrays.<Object>asList(accumulator.getDouble(0)
179-
+ Double.valueOf(input.toString())));
179+
return new BeamSqlRow(accType, Arrays.<Object>asList(accumulator.getBigDecimal(0)
180+
.add(new BigDecimal(input.toString()))));
180181
}
181182

182183
/**
 * Merges partial SUM accumulators into a single accumulator row.
 *
 * @param accumulators partial accumulators, each holding its running sum as a
 *     {@code BigDecimal} in field 0
 * @return a new accumulator row whose field 0 is the total of all partial sums
 */
@Override
183184
public BeamSqlRow merge(Iterable<BeamSqlRow> accumulators) {
184-
double v = 0.0;
185+
BigDecimal v = new BigDecimal(0);
185186
// Iterate once over the input; calling accumulators.iterator() on every pass
// restarts from the first element for a standard collection and never terminates.
for (BeamSqlRow accumulator : accumulators) {
186-
v += accumulators.iterator().next().getDouble(0);
187+
// BigDecimal is immutable: add() returns the new sum, so it must be re-assigned.
v = v.add(accumulator.getBigDecimal(0));
187188
}
188189
return new BeamSqlRow(accType, Arrays.<Object>asList(v));
189190
}
190191

191192
@Override
192193
public T result(BeamSqlRow accumulator) {
193-
BeamSqlRow result = new BeamSqlRow(
194-
BeamSqlRecordType.create(Arrays.asList("__sum"), Arrays.asList(outputFieldType)));
194+
Object result = null;
195195
switch (outputFieldType) {
196196
case Types.INTEGER:
197-
result.addField(0, (int) accumulator.getDouble(0));
197+
result = accumulator.getBigDecimal(0).intValue();
198198
break;
199199
case Types.BIGINT:
200-
result.addField(0, (long) accumulator.getDouble(0));
200+
result = accumulator.getBigDecimal(0).longValue();
201201
break;
202202
case Types.SMALLINT:
203-
result.addField(0, (short) accumulator.getDouble(0));
203+
result = accumulator.getBigDecimal(0).shortValue();
204204
break;
205205
case Types.TINYINT:
206-
result.addField(0, (byte) accumulator.getDouble(0));
206+
result = accumulator.getBigDecimal(0).byteValue();
207207
break;
208208
case Types.DOUBLE:
209-
result.addField(0, accumulator.getDouble(0));
209+
result = accumulator.getBigDecimal(0).doubleValue();
210210
break;
211211
case Types.FLOAT:
212-
result.addField(0, (float) accumulator.getDouble(0));
212+
result = accumulator.getBigDecimal(0).floatValue();
213+
break;
214+
case Types.DECIMAL:
215+
result = accumulator.getBigDecimal(0);
213216
break;
214-
215217
default:
216218
break;
217219
}
218-
return (T) result.getFieldValue(0);
220+
return (T) result;
219221
}
220222

221223
}
222224

223225
/**
224-
* Build-in aggregation for AVG.
226+
* Built-in aggregation for AVG.
225227
*/
226228
public static class Avg<T> extends BeamSqlUdaf<T, T> {
227229
private static List<Integer> supportedType = Arrays.asList(Types.INTEGER,
228230
Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE,
229-
Types.FLOAT);
231+
Types.FLOAT, Types.DECIMAL);
230232

231233
private int outputFieldType;
232234
private BeamSqlRecordType accType;
@@ -239,63 +241,65 @@ public Avg(int outputFieldType) {
239241

240242
this.outputFieldType = outputFieldType;
241243
this.accType = BeamSqlRecordType.create(Arrays.asList("__sum", "size"),
242-
Arrays.asList(Types.DOUBLE, Types.BIGINT)); //by default use DOUBLE to store the value.
244+
Arrays.asList(Types.DECIMAL, Types.BIGINT)); //by default use DECIMAL to store the value.
243245
}
244246

245247
@Override
246248
public BeamSqlRow init() {
247-
return new BeamSqlRow(accType, Arrays.<Object>asList(0.0, 0L));
249+
return new BeamSqlRow(accType, Arrays.<Object>asList(new BigDecimal(0), 0L));
248250
}
249251

250252
@Override
251253
public BeamSqlRow add(BeamSqlRow accumulator, T input) {
252254
return new BeamSqlRow(accType,
253255
Arrays.<Object>asList(
254-
accumulator.getDouble(0)
255-
+ Double.valueOf(input.toString()),
256+
accumulator.getBigDecimal(0).add(new BigDecimal(input.toString())),
256257
accumulator.getLong(1) + 1));
257258
}
258259

259260
/**
 * Merges partial AVG accumulators into a single accumulator row.
 *
 * @param accumulators partial accumulators, each holding a running sum as a
 *     {@code BigDecimal} in field 0 and an element count as a {@code long} in field 1
 * @return a new accumulator row holding the total sum (field 0) and total count (field 1)
 */
@Override
260261
public BeamSqlRow merge(Iterable<BeamSqlRow> accumulators) {
261-
double v = 0.0;
262+
BigDecimal v = new BigDecimal(0);
262263
long s = 0;
263264
// Iterate once over the input; calling accumulators.iterator() on every pass
// restarts from the first element for a standard collection and never terminates.
for (BeamSqlRow r : accumulators) {
265-
v += r.getDouble(0);
266+
// BigDecimal is immutable: add() returns the new sum, so it must be re-assigned.
v = v.add(r.getBigDecimal(0));
266267
s += r.getLong(1);
267268
}
268269
return new BeamSqlRow(accType, Arrays.<Object>asList(v, s));
269270
}
270271

271272
/**
 * Extracts the average from an AVG accumulator and converts it to the output SQL type.
 *
 * <p>The average is the sum (field 0) divided by the count (field 1). The division is
 * rounded to {@code MathContext.DECIMAL128} precision, because an unrounded
 * {@code BigDecimal.divide} throws {@code ArithmeticException} whenever the quotient has a
 * non-terminating decimal expansion (for example 4 / 3).
 *
 * @param accumulator accumulator row holding the sum in field 0 and the count in field 1
 * @return the average converted according to {@code outputFieldType}, or {@code null} when
 *     {@code outputFieldType} matches none of the handled cases
 * @throws ArithmeticException if the count in field 1 is zero
 */
@Override
272273
public T result(BeamSqlRow accumulator) {
273-
BeamSqlRow result = new BeamSqlRow(
274-
BeamSqlRecordType.create(Arrays.asList("__avg"), Arrays.asList(outputFieldType)));
274+
Object result = null;
275+
BigDecimal decimalAvg = accumulator.getBigDecimal(0).divide(
276+
new BigDecimal(accumulator.getLong(1)), java.math.MathContext.DECIMAL128);
275277
switch (outputFieldType) {
276278
case Types.INTEGER:
277-
result.addField(0, (int) (accumulator.getDouble(0) / accumulator.getLong(1)));
279+
result = decimalAvg.intValue();
278280
break;
279281
case Types.BIGINT:
280-
result.addField(0, (long) (accumulator.getDouble(0) / accumulator.getLong(1)));
282+
result = decimalAvg.longValue();
281283
break;
282284
case Types.SMALLINT:
283-
result.addField(0, (short) (accumulator.getDouble(0) / accumulator.getLong(1)));
285+
result = decimalAvg.shortValue();
284286
break;
285287
case Types.TINYINT:
286-
result.addField(0, (byte) (accumulator.getDouble(0) / accumulator.getLong(1)));
288+
result = decimalAvg.byteValue();
287289
break;
288290
case Types.DOUBLE:
289-
result.addField(0, accumulator.getDouble(0) / accumulator.getLong(1));
291+
result = decimalAvg.doubleValue();
290292
break;
291293
case Types.FLOAT:
292-
result.addField(0, (float) (accumulator.getDouble(0) / accumulator.getLong(1)));
294+
result = decimalAvg.floatValue();
295+
break;
296+
case Types.DECIMAL:
297+
result = decimalAvg;
293298
break;
294-
295299
default:
296300
break;
297301
}
298-
return (T) result.getFieldValue(0);
302+
return (T) result;
299303
}
300304

301305
}

0 commit comments

Comments
 (0)