Skip to content

Commit 5949e6c

Browse files
kiszkcloud-fan
authored andcommitted
[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing
## What changes were proposed in this pull request? This PR improve performance of Dataset.map() for primitive types by removing boxing/unbox operations. This is based on [the discussion](#16391 (comment)) with cloud-fan. Current Catalyst generates a method call to a `apply()` method of an anonymous function written in Scala. The types of an argument and return value are `java.lang.Object`. As a result, each method call for a primitive value involves a pair of unboxing and boxing for calling this `apply()` method and a pair of boxing and unboxing for returning from this `apply()` method. This PR directly calls a specialized version of a `apply()` method without boxing and unboxing. For example, if types of an arguments ant return value is `int`, this PR generates a method call to `apply$mcII$sp`. This PR supports any combination of `Int`, `Long`, `Float`, and `Double`. The following is a benchmark result using [this program](https://github.com/apache/spark/pull/16391/files) with 4.7x. Here is a Dataset part of this program. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1923 / 1952 52.0 19.2 1.0X DataFrame 526 / 548 190.2 5.3 3.7X Dataset 3094 / 3154 32.3 30.9 0.6X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1883 / 1892 53.1 18.8 1.0X DataFrame 502 / 642 199.1 5.0 3.7X Dataset 657 / 784 152.2 6.6 2.9X ``` ```java def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ val rdd = spark.sparkContext.range(0, numRows) val ds = spark.range(0, numRows) val func = (l: Long) => l + 1 val benchmark = new Benchmark("back-to-back map", numRows) ... benchmark.addCase("Dataset") { iter => var res = ds.as[Long] var i = 0 while (i < numChains) { res = res.map(func) i += 1 } res.queryExecution.toRdd.foreach(_ => Unit) } benchmark } ``` A motivating example ```java Seq(1, 2, 3).toDS.map(i => i * 7).show ``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ Object mapelements_funcResult = null; /* 054 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 055 */ if (mapelements_funcResult == null) { /* 056 */ mapelements_isNull = true; /* 057 */ } else { /* 058 */ mapelements_value = (Integer) mapelements_funcResult; /* 059 */ } /* 060 */ /* 061 */ } /* 062 */ /* 063 */ } /* 064 */ /* 065 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 066 */ /* 067 */ if (mapelements_isNull) { /* 068 */ serializefromobject_rowWriter.setNullAt(0); /* 069 */ } else { /* 070 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 071 */ } /* 072 */ append(serializefromobject_result); /* 073 */ if (shouldStop()) return; /* 074 */ } /* 075 */ } /* 076 */ } ``` Generated code with this PR (lines 48-56 are changed) ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue); /* 054 */ } /* 055 */ /* 056 */ } /* 057 */ /* 058 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 059 */ /* 060 */ if (mapelements_isNull) { /* 061 */ serializefromobject_rowWriter.setNullAt(0); /* 062 */ } else { /* 063 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 064 */ } /* 065 */ append(serializefromobject_result); /* 066 */ if (shouldStop()) return; /* 067 */ } /* 068 */ } /* 069 */ } ``` Java bytecode for methods for `i => i * 7` ```java $ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class Compiled from "Test.scala" public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable { public static final long serialVersionUID; public final int apply(int); Code: 0: aload_0 1: iload_1 2: invokevirtual #18 // Method apply$mcII$sp:(I)I 5: ireturn public int apply$mcII$sp(int); Code: 0: iload_1 1: bipush 7 3: imul 4: ireturn public final java.lang.Object apply(java.lang.Object); Code: 0: aload_0 1: aload_1 2: invokestatic #29 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I 5: invokevirtual #31 // Method apply:(I)I 8: invokestatic #35 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer; 11: areturn public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5); Code: 0: aload_0 1: invokespecial #42 // Method scala/runtime/AbstractFunction1$mcII$sp."<init>":()V 4: return } ``` ## How was this patch tested? Added new test suites to `DatasetPrimitiveSuite`. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #17172 from kiszk/SPARK-19008.
1 parent 82138e0 commit 5949e6c

File tree

4 files changed

+208
-9
lines changed

4 files changed

+208
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
2828
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
2929
import org.apache.spark.sql.streaming.OutputMode
3030
import org.apache.spark.sql.types._
31+
import org.apache.spark.util.Utils
3132

3233
object CatalystSerde {
3334
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -211,13 +212,48 @@ case class TypedFilter(
211212
def typedCondition(input: Expression): Expression = {
212213
val (funcClass, methodName) = func match {
213214
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
214-
case _ => classOf[Any => Boolean] -> "apply"
215+
case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType)
215216
}
216217
val funcObj = Literal.create(func, ObjectType(funcClass))
217218
Invoke(funcObj, methodName, BooleanType, input :: Nil)
218219
}
219220
}
220221

222+
object FunctionUtils {
223+
private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = {
224+
dt match {
225+
case BooleanType if isOutput => Some("Z")
226+
case IntegerType => Some("I")
227+
case LongType => Some("J")
228+
case FloatType => Some("F")
229+
case DoubleType => Some("D")
230+
case _ => None
231+
}
232+
}
233+
234+
def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = {
235+
// load "scala.Function1" using Java API to avoid requirements of type parameters
236+
Utils.classForName("scala.Function1") -> {
237+
// if a pair of an argument and return types is one of specific types
238+
// whose specialized method (apply$mc..$sp) is generated by scalac,
239+
// Catalyst generated a direct method call to the specialized method.
240+
// The followings are references for this specialization:
241+
// http://www.scala-lang.org/api/2.12.0/scala/Function1.html
242+
// https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
243+
// SpecializeTypes.scala
244+
// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
245+
// http://axel22.github.io/2013/11/03/specialization-quirks.html
246+
val inputType = getMethodType(inputDT, false)
247+
val outputType = getMethodType(outputDT, true)
248+
if (inputType.isDefined && outputType.isDefined) {
249+
s"apply$$mc${outputType.get}${inputType.get}$$sp"
250+
} else {
251+
"apply"
252+
}
253+
}
254+
}
255+
}
256+
221257
/** Factory for constructing new `AppendColumn` nodes. */
222258
object AppendColumns {
223259
def apply[T : Encoder, U : Encoder](

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.catalyst.expressions._
2929
import org.apache.spark.sql.catalyst.expressions.codegen._
3030
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
31+
import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
3132
import org.apache.spark.sql.catalyst.plans.physical._
3233
import org.apache.spark.sql.Row
3334
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
3435
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
35-
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
36+
import org.apache.spark.sql.types._
37+
import org.apache.spark.util.Utils
3638

3739

3840
/**
@@ -219,7 +221,7 @@ case class MapElementsExec(
219221
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
220222
val (funcClass, methodName) = func match {
221223
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
222-
case _ => classOf[Any => Any] -> "apply"
224+
case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
223225
}
224226
val funcObj = Literal.create(func, ObjectType(funcClass))
225227
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)

sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,49 @@ object DatasetBenchmark {
3131

3232
case class Data(l: Long, s: String)
3333

34+
def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
35+
import spark.implicits._
36+
37+
val rdd = spark.sparkContext.range(0, numRows)
38+
val ds = spark.range(0, numRows)
39+
val df = ds.toDF("l")
40+
val func = (l: Long) => l + 1
41+
42+
val benchmark = new Benchmark("back-to-back map long", numRows)
43+
44+
benchmark.addCase("RDD") { iter =>
45+
var res = rdd
46+
var i = 0
47+
while (i < numChains) {
48+
res = res.map(func)
49+
i += 1
50+
}
51+
res.foreach(_ => Unit)
52+
}
53+
54+
benchmark.addCase("DataFrame") { iter =>
55+
var res = df
56+
var i = 0
57+
while (i < numChains) {
58+
res = res.select($"l" + 1 as "l")
59+
i += 1
60+
}
61+
res.queryExecution.toRdd.foreach(_ => Unit)
62+
}
63+
64+
benchmark.addCase("Dataset") { iter =>
65+
var res = ds.as[Long]
66+
var i = 0
67+
while (i < numChains) {
68+
res = res.map(func)
69+
i += 1
70+
}
71+
res.queryExecution.toRdd.foreach(_ => Unit)
72+
}
73+
74+
benchmark
75+
}
76+
3477
def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
3578
import spark.implicits._
3679

@@ -72,6 +115,49 @@ object DatasetBenchmark {
72115
benchmark
73116
}
74117

118+
def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
119+
import spark.implicits._
120+
121+
val rdd = spark.sparkContext.range(1, numRows)
122+
val ds = spark.range(1, numRows)
123+
val df = ds.toDF("l")
124+
val func = (l: Long) => l % 2L == 0L
125+
126+
val benchmark = new Benchmark("back-to-back filter Long", numRows)
127+
128+
benchmark.addCase("RDD") { iter =>
129+
var res = rdd
130+
var i = 0
131+
while (i < numChains) {
132+
res = res.filter(func)
133+
i += 1
134+
}
135+
res.foreach(_ => Unit)
136+
}
137+
138+
benchmark.addCase("DataFrame") { iter =>
139+
var res = df
140+
var i = 0
141+
while (i < numChains) {
142+
res = res.filter($"l" % 2L === 0L)
143+
i += 1
144+
}
145+
res.queryExecution.toRdd.foreach(_ => Unit)
146+
}
147+
148+
benchmark.addCase("Dataset") { iter =>
149+
var res = ds.as[Long]
150+
var i = 0
151+
while (i < numChains) {
152+
res = res.filter(func)
153+
i += 1
154+
}
155+
res.queryExecution.toRdd.foreach(_ => Unit)
156+
}
157+
158+
benchmark
159+
}
160+
75161
def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
76162
import spark.implicits._
77163

@@ -165,9 +251,22 @@ object DatasetBenchmark {
165251
val numRows = 100000000
166252
val numChains = 10
167253

168-
val benchmark = backToBackMap(spark, numRows, numChains)
169-
val benchmark2 = backToBackFilter(spark, numRows, numChains)
170-
val benchmark3 = aggregate(spark, numRows)
254+
val benchmark0 = backToBackMapLong(spark, numRows, numChains)
255+
val benchmark1 = backToBackMap(spark, numRows, numChains)
256+
val benchmark2 = backToBackFilterLong(spark, numRows, numChains)
257+
val benchmark3 = backToBackFilter(spark, numRows, numChains)
258+
val benchmark4 = aggregate(spark, numRows)
259+
260+
/*
261+
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
262+
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
263+
back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
264+
------------------------------------------------------------------------------------------------
265+
RDD 1883 / 1892 53.1 18.8 1.0X
266+
DataFrame 502 / 642 199.1 5.0 3.7X
267+
Dataset 657 / 784 152.2 6.6 2.9X
268+
*/
269+
benchmark0.run()
171270

172271
/*
173272
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -178,7 +277,18 @@ object DatasetBenchmark {
178277
DataFrame 2647 / 3116 37.8 26.5 1.3X
179278
Dataset 4781 / 5155 20.9 47.8 0.7X
180279
*/
181-
benchmark.run()
280+
benchmark1.run()
281+
282+
/*
283+
OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic
284+
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
285+
back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
286+
------------------------------------------------------------------------------------------------
287+
RDD 846 / 1120 118.1 8.5 1.0X
288+
DataFrame 270 / 329 370.9 2.7 3.1X
289+
Dataset 545 / 789 183.5 5.4 1.6X
290+
*/
291+
benchmark2.run()
182292

183293
/*
184294
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -189,7 +299,7 @@ object DatasetBenchmark {
189299
DataFrame 59 / 72 1695.4 0.6 22.8X
190300
Dataset 2777 / 2805 36.0 27.8 0.5X
191301
*/
192-
benchmark2.run()
302+
benchmark3.run()
193303

194304
/*
195305
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
@@ -201,6 +311,6 @@ object DatasetBenchmark {
201311
Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X
202312
Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X
203313
*/
204-
benchmark3.run()
314+
benchmark4.run()
205315
}
206316
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,64 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
6262
2, 3, 4)
6363
}
6464

65+
test("mapPrimitive") {
66+
val dsInt = Seq(1, 2, 3).toDS()
67+
checkDataset(dsInt.map(_ > 1), false, true, true)
68+
checkDataset(dsInt.map(_ + 1), 2, 3, 4)
69+
checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
70+
checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
71+
checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
72+
73+
val dsLong = Seq(1L, 2L, 3L).toDS()
74+
checkDataset(dsLong.map(_ > 1), false, true, true)
75+
checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
76+
checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
77+
checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
78+
checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
79+
80+
val dsFloat = Seq(1F, 2F, 3F).toDS()
81+
checkDataset(dsFloat.map(_ > 1), false, true, true)
82+
checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
83+
checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
84+
checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
85+
checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
86+
87+
val dsDouble = Seq(1D, 2D, 3D).toDS()
88+
checkDataset(dsDouble.map(_ > 1), false, true, true)
89+
checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
90+
checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
91+
8589934593L, 8589934594L, 8589934595L)
92+
checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
93+
checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
94+
95+
val dsBoolean = Seq(true, false).toDS()
96+
checkDataset(dsBoolean.map(e => !e), false, true)
97+
}
98+
6599
test("filter") {
66100
val ds = Seq(1, 2, 3, 4).toDS()
67101
checkDataset(
68102
ds.filter(_ % 2 == 0),
69103
2, 4)
70104
}
71105

106+
test("filterPrimitive") {
107+
val dsInt = Seq(1, 2, 3).toDS()
108+
checkDataset(dsInt.filter(_ > 1), 2, 3)
109+
110+
val dsLong = Seq(1L, 2L, 3L).toDS()
111+
checkDataset(dsLong.filter(_ > 1), 2L, 3L)
112+
113+
val dsFloat = Seq(1F, 2F, 3F).toDS()
114+
checkDataset(dsFloat.filter(_ > 1), 2F, 3F)
115+
116+
val dsDouble = Seq(1D, 2D, 3D).toDS()
117+
checkDataset(dsDouble.filter(_ > 1), 2D, 3D)
118+
119+
val dsBoolean = Seq(true, false).toDS()
120+
checkDataset(dsBoolean.filter(e => !e), false)
121+
}
122+
72123
test("foreach") {
73124
val ds = Seq(1, 2, 3).toDS()
74125
val acc = sparkContext.longAccumulator

0 commit comments

Comments
 (0)