Skip to content

Commit 0df6a99

Browse files
committed
[SPARK-16489][SQL] checkEvaluation should fail if expression reuses variable names
1 parent b4fbe14 commit 0df6a99

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,23 +132,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
132132
expression: Expression,
133133
expected: Any,
134134
inputRow: InternalRow = EmptyRow): Unit = {
135-
135+
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
136+
// some expression is reusing variable names across different instances.
137+
// This behavior is tested in ExpressionEvalHelperSuite.
136138
val plan = generateProject(
137-
GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
139+
GenerateUnsafeProjection.generate(
140+
Alias(expression, s"Optimized($expression)1")() ::
141+
Alias(expression, s"Optimized($expression)2")() :: Nil),
138142
expression)
139143

140144
val unsafeRow = plan(inputRow)
141145
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
142146

143147
if (expected == null) {
144148
if (!unsafeRow.isNullAt(0)) {
145-
val expectedRow = InternalRow(expected)
149+
val expectedRow = InternalRow(expected, expected)
146150
fail("Incorrect evaluation in unsafe mode: " +
147151
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
148152
}
149153
} else {
150-
val lit = InternalRow(expected)
151-
val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit)
154+
val lit = InternalRow(expected, expected)
155+
val expectedRow =
156+
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
152157
if (unsafeRow != expectedRow) {
153158
fail("Incorrect evaluation in unsafe mode: " +
154159
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23+
import org.apache.spark.sql.types.{DataType, IntegerType}
24+
25+
/**
 * Tests for the [[ExpressionEvalHelper]] test harness itself.
 *
 * Test harnesses deserve test cases of their own when they have behaviors
 * that are easy to break.
 */
class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16489 checkEvaluation should fail if expression reuses variable names") {
    // BadCodegenExpression always emits a local named `some_variable`, so the
    // evaluation check is expected to throw rather than pass.
    val failure = intercept[RuntimeException] {
      checkEvaluation(BadCodegenExpression(), 10)
    }

    // The failure message must name the offending variable.
    val message = failure.getMessage
    assert(message.contains("some_variable"))
  }
}
38+
39+
/**
 * An expression that deliberately generates bad code: the variable name "some_variable" is
 * hard-coded, so it is not unique across instances of the expression.
 */
case class BadCodegenExpression() extends LeafExpression {

  override def dataType: DataType = IntegerType

  override def nullable: Boolean = false

  /** Interpreted evaluation: always yields the constant 10, whatever the input row is. */
  override def eval(input: InternalRow): Any = 10

  /**
   * Emits code that declares the fixed-name local `some_variable` and then assigns the
   * constant 10 to the result variable named by `ev.value`.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The fixed name `some_variable` is intentional: it is what makes this expression "bad".
    val generated =
      s"""
         |int some_variable = 11;
         |int ${ev.value} = 10;
         |""".stripMargin
    ev.copy(code = generated)
  }
}

0 commit comments

Comments
 (0)