[SPARK-14359] Create built-in functions for typed aggregates in Java #12168

Closed · wants to merge 3 commits

Changes from all commits

@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.TypedColumn
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT]
override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
override def finish(reduction: OUT): OUT = reduction

// TODO(ekl) Java API support once this is exposed in Scala
}


@@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
override def finish(reduction: Double): Double = reduction

// Java API support
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
toColumn(ExpressionEncoder(), ExpressionEncoder())
.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
}


@@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
override def reduce(b: Long, a: IN): Long = b + f(a)
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction

// Java API support
def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
toColumn(ExpressionEncoder(), ExpressionEncoder())
.asInstanceOf[TypedColumn[IN, java.lang.Long]]
}
}


@@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
}
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction

// Java API support
def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
toColumn(ExpressionEncoder(), ExpressionEncoder())
.asInstanceOf[TypedColumn[IN, java.lang.Long]]
}
}


@@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}

// Java API support
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
toColumn(ExpressionEncoder(), ExpressionEncoder())
.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
}
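Each aggregator above gains the same two Java-facing additions: an auxiliary constructor that accepts a MapFunction in place of a Scala IN => OUT function, and a toColumnJava() that builds ExpressionEncoders and casts the column to the boxed Java result type. A minimal sketch of using those pieces directly, assuming this diff is applied; the wrapper class and method names here are hypothetical, and the typed helpers in the next file are the intended entry point:

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.execution.aggregate.TypedSumDouble;

public class DirectAggregatorSketch {
  // Hypothetical helper: builds a typed sum-of-lengths column for a
  // Dataset<String> without going through the typed.* entry points.
  public static TypedColumn<String, Double> lengthSum() {
    MapFunction<String, Double> len = new MapFunction<String, Double>() {
      @Override
      public Double call(String s) {
        return (double) s.length();
      }
    };
    // The MapFunction constructor overload bridges to the Scala function
    // type; toColumnJava() supplies the encoders and boxes the result.
    return new TypedSumDouble<String>(len).toColumnJava();
  }
}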
@@ -18,7 +18,13 @@
package org.apache.spark.sql.expressions.java;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.execution.aggregate.TypedAverage;
import org.apache.spark.sql.execution.aggregate.TypedCount;
import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
import org.apache.spark.sql.execution.aggregate.TypedSumLong;

/**
* :: Experimental ::
@@ -30,5 +36,41 @@
*/
@Experimental
public class typed {
// Note: make sure to keep in sync with typed.scala

/**
* Average aggregate function.
*
* @since 2.0.0
*/
public static <T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) {
return new TypedAverage<T>(f).toColumnJava();
}

/**
* Count aggregate function.
*
* @since 2.0.0
*/
public static <T> TypedColumn<T, Long> count(MapFunction<T, Object> f) {
return new TypedCount<T>(f).toColumnJava();
}

/**
* Sum aggregate function for floating point (double) type.
*
* @since 2.0.0
*/
public static <T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) {
return new TypedSumDouble<T>(f).toColumnJava();
}

/**
* Sum aggregate function for integral (long, i.e. 64-bit integer) type.
*
* @since 2.0.0
*/
public static <T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
return new TypedSumLong<T>(f).toColumnJava();
}
}
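For context, an end-to-end sketch of calling the new helpers on a grouped Dataset; a hedged example assuming the Spark 2.x Java entry points (SparkSession.builder, createDataset, groupByKey), with an illustrative class name and toy data:

import java.util.Arrays;

import scala.Tuple2;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.java.typed;

public class TypedAggExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[2]").appName("TypedAggExample").getOrCreate();

    Dataset<Tuple2<String, Integer>> ds = spark.createDataset(
      Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)),
      Encoders.tuple(Encoders.STRING(), Encoders.INT()));

    // Group by the string key, then sum the integer field per key;
    // yields ("a", 3.0) and ("b", 3.0).
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped =
      ds.groupByKey(new MapFunction<Tuple2<String, Integer>, String>() {
        @Override
        public String call(Tuple2<String, Integer> v) { return v._1(); }
      }, Encoders.STRING());

    Dataset<Tuple2<String, Double>> sums = grouped.agg(typed.sum(
      new MapFunction<Tuple2<String, Integer>, Double>() {
        @Override
        public Double call(Tuple2<String, Integer> v) { return (double) v._2(); }
      }));
    sums.show();

    spark.stop();
  }
}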
@@ -35,6 +35,7 @@
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.java.typed;
import org.apache.spark.sql.test.TestSQLContext;

/**
@@ -120,4 +121,52 @@ public Integer finish(Integer reduction) {
return reduction;
}
}

@Test
public void testTypedAggregationAverage() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
new MapFunction<Tuple2<String, Integer>, Double>() {
public Double call(Tuple2<String, Integer> value) throws Exception {
return (double)(value._2() * 2);
}
}));
Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
}

@Test
public void testTypedAggregationCount() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
new MapFunction<Tuple2<String, Integer>, Object>() {
public Object call(Tuple2<String, Integer> value) throws Exception {
return value;
}
}));
Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList());
}

@Test
public void testTypedAggregationSumDouble() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
new MapFunction<Tuple2<String, Integer>, Double>() {
public Double call(Tuple2<String, Integer> value) throws Exception {
return (double)value._2();
}
}));
Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
}

@Test
public void testTypedAggregationSumLong() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
new MapFunction<Tuple2<String, Integer>, Long>() {
public Long call(Tuple2<String, Integer> value) throws Exception {
return (long)value._2();
}
}));
Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList());
}
}
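Since MapFunction is a single-method interface, the anonymous classes above collapse to lambdas when the suite is compiled under Java 8. A sketch of the average test in that style; the method would sit inside the suite above, and its name is hypothetical:

@Test
public void testTypedAggregationAverageWithLambda() {
  KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
  // Same aggregation as testTypedAggregationAverage, with the MapFunction
  // supplied as a lambda instead of an anonymous class.
  Dataset<Tuple2<String, Double>> agged =
    grouped.agg(typed.avg(value -> (double) (value._2() * 2)));
  Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
}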