Skip to content

Commit ce39905

Browse files
committed
More factory methods
1 parent 8ab0931 commit ce39905

File tree

5 files changed

+33
-12
lines changed

5 files changed

+33
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
7878
|final InternalRow $output = new $rowClass($values);
7979
""".stripMargin
8080

81-
ExprCode(code, FalseLiteral, VariableValue(output, classOf[InternalRow]))
81+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow]))
8282
}
8383

8484
private def createCodeForArray(
@@ -110,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
110110
final ArrayData $output = new $arrayClass($values);
111111
"""
112112

113-
ExprCode(code, FalseLiteral, VariableValue(output, classOf[ArrayData]))
113+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData]))
114114
}
115115

116116
private def createCodeForMap(
@@ -131,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
131131
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
132132
"""
133133

134-
ExprCode(code, FalseLiteral, VariableValue(output, classOf[MapData]))
134+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData]))
135135
}
136136

137137
@tailrec

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
292292
|$writeExpressions
293293
""".stripMargin
294294
// `rowWriter` is declared as a class field, so we can access it directly in methods.
295-
ExprCode(code, FalseLiteral, SimpleExprValue(s"$rowWriter.getRow()", classOf[UnsafeRow]))
295+
ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
296296
}
297297

298298
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ import scala.language.{existentials, implicitConversions}
2424
import org.apache.spark.sql.types.{BooleanType, DataType}
2525

2626
/**
27-
+ * Trait representing an opaque fragments of java code.
28-
+ */
27+
* Trait representing an opaque fragments of java code.
28+
*/
2929
trait JavaCode {
3030
def code: String
3131
override def toString: String = code
@@ -58,7 +58,14 @@ object JavaCode {
5858
* Create a local java variable.
5959
*/
6060
def variable(name: String, dataType: DataType): VariableValue = {
61-
VariableValue(name, CodeGenerator.javaClass(dataType))
61+
variable(name, CodeGenerator.javaClass(dataType))
62+
}
63+
64+
/**
65+
* Create a local java variable.
66+
*/
67+
def variable(name: String, javaClass: Class[_]): VariableValue = {
68+
VariableValue(name, javaClass)
6269
}
6370

6471
/**
@@ -70,7 +77,14 @@ object JavaCode {
7077
* Create a global java variable.
7178
*/
7279
def global(name: String, dataType: DataType): GlobalValue = {
73-
GlobalValue(name, CodeGenerator.javaClass(dataType))
80+
global(name, CodeGenerator.javaClass(dataType))
81+
}
82+
83+
/**
84+
* Create a global java variable.
85+
*/
86+
def global(name: String, javaClass: Class[_]): GlobalValue = {
87+
GlobalValue(name, javaClass)
7488
}
7589

7690
/**
@@ -82,7 +96,14 @@ object JavaCode {
8296
* Create an expression fragment.
8397
*/
8498
def expression(code: String, dataType: DataType): SimpleExprValue = {
85-
SimpleExprValue(code, CodeGenerator.javaClass(dataType))
99+
expression(code, CodeGenerator.javaClass(dataType))
100+
}
101+
102+
/**
103+
* Create an expression fragment.
104+
*/
105+
def expression(code: String, javaClass: Class[_]): SimpleExprValue = {
106+
SimpleExprValue(code, javaClass)
86107
}
87108

88109
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
310310
toExprCode(s"${value}D")
311311
}
312312
case ByteType | ShortType =>
313-
toExprCode(s"(($javaType)$value)")
313+
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
314314
case TimestampType | LongType =>
315315
toExprCode(s"${value}L")
316316
case _ =>

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan {
111111

112112
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
113113
if (row != null) {
114-
ExprCode.forNonNullValue(VariableValue(row, classOf[UnsafeRow]))
114+
ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
115115
} else {
116116
if (colVars.nonEmpty) {
117117
val colExprs = output.zipWithIndex.map { case (attr, i) =>
@@ -129,7 +129,7 @@ trait CodegenSupport extends SparkPlan {
129129
ExprCode(code, FalseLiteral, ev.value)
130130
} else {
131131
// There are no columns
132-
ExprCode.forNonNullValue(VariableValue("unsafeRow", classOf[UnsafeRow]))
132+
ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
133133
}
134134
}
135135
}

0 commit comments

Comments
 (0)